add epislion greedy

This commit is contained in:
a 2023-07-06 09:59:17 +01:00
parent 2e4f0e9a8d
commit b486de5d8c
7 changed files with 125 additions and 42 deletions

39
algo/epsilon_greedy.go Normal file
View File

@ -0,0 +1,39 @@
package algo
import (
"lukechampine.com/frand"
"tuxpa.in/a/gambit/helper"
)
type EpsilonGreedy struct {
Epsilon float64
cr helper.CountReward
}
func (u *EpsilonGreedy) Select(r float64) int {
if r > u.Epsilon {
return int(u.cr.RewardMax())
}
return frand.Intn(u.cr.Size())
}
func (u *EpsilonGreedy) Update(a int, r float64) error {
return u.cr.Update(a, r)
}
func (u *EpsilonGreedy) Reset(n int) error {
u.cr.ResetTo(n)
return nil
}
func (u *EpsilonGreedy) Size() int {
return u.cr.Size()
}
func (u *EpsilonGreedy) Count(res []int) {
u.cr.Count(res)
}
func (u *EpsilonGreedy) Reward(res []float64) {
u.cr.Reward(res)
}

View File

@ -40,3 +40,9 @@ func (s *Sync) Reward(res *[]float64) {
defer s.mu.Unlock()
s.Reward(res)
}
func (s *Sync) Size() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.Size()
}

View File

@ -1,7 +1,6 @@
package algo
package ucb
import (
"errors"
"math"
"tuxpa.in/a/gambit/helper"
@ -30,15 +29,8 @@ func (ucb *UCB) Select(r float64) int {
}
func (u *UCB) Update(a int, r float64) error {
if a < 0 || a >= len(u.cr.Rewards) || r < 0 {
return errors.New("TODO")
}
u.cr.Counts[a]++
dec := float64(u.cr.Counts[a])
u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec
return nil
return u.cr.Update(a, r)
}
func (u *UCB) Reset(n int) error {
u.cr.ResetTo(n)
return nil

View File

@ -5,6 +5,8 @@ type Bandit interface {
Update(a int, r float64) error
Reset(n int) error
Count(res *[]int)
Reward(res *[]float64)
Size() int
Count(res []int)
Reward(res []float64)
}

8
go.mod
View File

@ -1,3 +1,11 @@
module tuxpa.in/a/gambit
go 1.19
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
require (
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
golang.org/x/sys v0.1.0 // indirect
lukechampine.com/frand v1.4.2 // indirect
)

9
go.sum Normal file
View File

@ -0,0 +1,9 @@
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME=
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw=
lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s=

View File

@ -1,5 +1,12 @@
package helper
import (
"errors"
"math"
"golang.org/x/exp/constraints"
)
type CountReward struct {
Counts []int
Rewards []float64
@ -22,42 +29,62 @@ func (c *CountReward) ResetTo(size int) {
c.Rewards[idx] = 0
}
}
func (c *CountReward) Count(res *[]int) {
if res == nil {
r := make([]int, len(c.Counts))
res = &r
func (c *CountReward) CountMax() (i int) {
xs := c.Counts
i = math.MinInt
for _, v := range xs {
if v > i {
i = v
}
}
if len(c.Counts) < len(*res) {
*res = append(*res, len(c.Counts)-len(*res))
return
}
func (c *CountReward) RewardMax() (i float64) {
xs := c.Rewards
i = math.Inf(-1)
for _, v := range xs {
if v > i {
i = v
}
}
(*res) = (*res)[:len(c.Counts)]
copy(*res, c.Counts)
return
}
func (c *CountReward) Update(a int, r float64) error {
if a < 0 || a >= c.Size() || r < 0 {
return errors.New("TODO")
}
c.Counts[a]++
dec := float64(c.Counts[a])
c.Rewards[a] = (c.Rewards[a]*(dec-1) + r) / dec
return nil
}
func (c *CountReward) CountSum() (i int) {
for _, v := range c.Counts {
i = i + v
}
return i
return sum(c.Counts)
}
func (c *CountReward) CountReward() (i int) {
for _, v := range c.Counts {
i = i + v
}
return i
func (c *CountReward) RewardSum() (i float64) {
return sum(c.Rewards)
}
func (c *CountReward) Reward(res *[]float64) {
if res == nil {
r := make([]float64, len(c.Rewards))
res = &r
}
if len(c.Rewards) < len(*res) {
*res = append(*res, make([]float64, len(c.Rewards)-len(*res))...)
}
(*res) = (*res)[:len(c.Rewards)]
copy(*res, c.Rewards)
func (c *CountReward) Size() int {
return len(c.Counts)
}
func (c *CountReward) Count(res []int) {
copy(res, c.Counts)
}
func (c *CountReward) Reward(res []float64) {
copy(res, c.Rewards)
}
func sum[T Numeric](n []T) (i T) {
for _, v := range n {
i = i + v
}
return
}
type Numeric interface {
constraints.Integer | constraints.Float
}