From b486de5d8cc7adda2bab23375afdb3950bbf487a Mon Sep 17 00:00:00 2001 From: a Date: Thu, 6 Jul 2023 09:59:17 +0100 Subject: [PATCH] add epislion greedy --- algo/epsilon_greedy.go | 39 +++++++++++++++++++ algo/mutex.go | 6 +++ algo/ucb.go | 12 +----- gambit.go | 6 ++- go.mod | 8 ++++ go.sum | 9 +++++ helper/storage.go | 87 +++++++++++++++++++++++++++--------------- 7 files changed, 125 insertions(+), 42 deletions(-) create mode 100644 algo/epsilon_greedy.go create mode 100644 go.sum diff --git a/algo/epsilon_greedy.go b/algo/epsilon_greedy.go new file mode 100644 index 0000000..d265bd4 --- /dev/null +++ b/algo/epsilon_greedy.go @@ -0,0 +1,39 @@ +package algo + +import ( + "lukechampine.com/frand" + "tuxpa.in/a/gambit/helper" +) + +type EpsilonGreedy struct { + Epsilon float64 + cr helper.CountReward +} + +func (u *EpsilonGreedy) Select(r float64) int { + if r > u.Epsilon { + return int(u.cr.RewardMax()) + } + return frand.Intn(u.cr.Size()) +} + +func (u *EpsilonGreedy) Update(a int, r float64) error { + return u.cr.Update(a, r) +} + +func (u *EpsilonGreedy) Reset(n int) error { + u.cr.ResetTo(n) + return nil +} + +func (u *EpsilonGreedy) Size() int { + return u.cr.Size() +} + +func (u *EpsilonGreedy) Count(res []int) { + u.cr.Count(res) +} + +func (u *EpsilonGreedy) Reward(res []float64) { + u.cr.Reward(res) +} diff --git a/algo/mutex.go b/algo/mutex.go index 42c58bd..e26d648 100644 --- a/algo/mutex.go +++ b/algo/mutex.go @@ -40,3 +40,9 @@ func (s *Sync) Reward(res *[]float64) { defer s.mu.Unlock() s.Reward(res) } + +func (s *Sync) Size() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.Size() +} diff --git a/algo/ucb.go b/algo/ucb.go index d4c0014..2136976 100644 --- a/algo/ucb.go +++ b/algo/ucb.go @@ -1,7 +1,6 @@ -package algo +package ucb import ( - "errors" "math" "tuxpa.in/a/gambit/helper" @@ -30,15 +29,8 @@ func (ucb *UCB) Select(r float64) int { } func (u *UCB) Update(a int, r float64) error { - if a < 0 || a >= len(u.cr.Rewards) || r < 0 { - return errors.New("TODO") - } - u.cr.Counts[a]++ - dec := float64(u.cr.Counts[a]) - u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec - return nil + return u.cr.Update(a, r) } - func (u *UCB) Reset(n int) error { u.cr.ResetTo(n) return nil diff --git a/gambit.go b/gambit.go index a6bfad3..4cb0f09 100644 --- a/gambit.go +++ b/gambit.go @@ -5,6 +5,8 @@ type Bandit interface { Update(a int, r float64) error Reset(n int) error - Count(res *[]int) - Reward(res *[]float64) + Size() int + + Count(res []int) + Reward(res []float64) } diff --git a/go.mod b/go.mod index 710d834..a187029 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module tuxpa.in/a/gambit go 1.19 + +require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df + +require ( + github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + golang.org/x/sys v0.1.0 // indirect + lukechampine.com/frand v1.4.2 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3fbc5a0 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw= +lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s= diff --git a/helper/storage.go b/helper/storage.go index f26748e..c0666b3 100644 --- a/helper/storage.go +++ b/helper/storage.go @@ -1,5 +1,12 @@ package helper +import ( + "errors" + "math" + + "golang.org/x/exp/constraints" +) + type CountReward struct { Counts []int Rewards []float64 @@ -22,42 +29,62 @@ func (c *CountReward) ResetTo(size int) { c.Rewards[idx] = 0 } } - -func (c *CountReward) Count(res *[]int) { - if res == nil { - r := make([]int, len(c.Counts)) - res = &r +func (c *CountReward) CountMax() (i int) { + xs := c.Counts + i = math.MinInt + for _, v := range xs { + if v > i { + i = v + } } - if len(c.Counts) < len(*res) { - *res = append(*res, len(c.Counts)-len(*res)) + return +} +func (c *CountReward) RewardMax() (i float64) { + xs := c.Rewards + i = math.Inf(-1) + for _, v := range xs { + if v > i { + i = v + } } - - (*res) = (*res)[:len(c.Counts)] - - copy(*res, c.Counts) + return } +func (c *CountReward) Update(a int, r float64) error { + if a < 0 || a >= c.Size() || r < 0 { + return errors.New("TODO") + } + c.Counts[a]++ + dec := float64(c.Counts[a]) + c.Rewards[a] = (c.Rewards[a]*(dec-1) + r) / dec + return nil +} func (c *CountReward) CountSum() (i int) { - for _, v := range c.Counts { - i = i + v - } - return i + return sum(c.Counts) } -func (c *CountReward) CountReward() (i int) { - for _, v := range c.Counts { - i = i + v - } - return i +func (c *CountReward) RewardSum() (i float64) { + return sum(c.Rewards) } -func (c *CountReward) Reward(res *[]float64) { - if res == nil { - r := make([]float64, len(c.Rewards)) - res = &r - } - if len(c.Rewards) < len(*res) { - *res = append(*res, make([]float64, len(c.Rewards)-len(*res))...) - } - (*res) = (*res)[:len(c.Rewards)] - copy(*res, c.Rewards) +func (c *CountReward) Size() int { + return len(c.Counts) +} + +func (c *CountReward) Count(res []int) { + copy(res, c.Counts) +} + +func (c *CountReward) Reward(res []float64) { + copy(res, c.Rewards) +} + +func sum[T Numeric](n []T) (i T) { + for _, v := range n { + i = i + v + } + return +} + +type Numeric interface { + constraints.Integer | constraints.Float }