add epsilon greedy
This commit is contained in:
parent
2e4f0e9a8d
commit
b486de5d8c
|
@ -0,0 +1,39 @@
|
|||
package algo
|
||||
|
||||
import (
|
||||
"lukechampine.com/frand"
|
||||
"tuxpa.in/a/gambit/helper"
|
||||
)
|
||||
|
||||
type EpsilonGreedy struct {
|
||||
Epsilon float64
|
||||
cr helper.CountReward
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Select(r float64) int {
|
||||
if r > u.Epsilon {
|
||||
return int(u.cr.RewardMax())
|
||||
}
|
||||
return frand.Intn(u.cr.Size())
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Update(a int, r float64) error {
|
||||
return u.cr.Update(a, r)
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Reset(n int) error {
|
||||
u.cr.ResetTo(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Size() int {
|
||||
return u.cr.Size()
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Count(res []int) {
|
||||
u.cr.Count(res)
|
||||
}
|
||||
|
||||
func (u *EpsilonGreedy) Reward(res []float64) {
|
||||
u.cr.Reward(res)
|
||||
}
|
|
@ -40,3 +40,9 @@ func (s *Sync) Reward(res *[]float64) {
|
|||
defer s.mu.Unlock()
|
||||
s.Reward(res)
|
||||
}
|
||||
|
||||
// Size locks s.mu, defers the unlock, and returns the result of s.Size().
//
// NOTE(review): inside this method s.Size() resolves to this very method — a
// method declared directly on *Sync is shallower than anything promoted from
// an embedded field — so the call recurses with no base case while s.mu is
// still held. If mu is a sync.Mutex (not visible here — confirm), the second
// Lock blocks forever. Presumably the intent was to delegate to the wrapped
// bandit's Size; fix once the Sync struct definition is checked.
func (s *Sync) Size() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Size()
|
||||
}
|
||||
|
|
12
algo/ucb.go
12
algo/ucb.go
|
@ -1,7 +1,6 @@
|
|||
package algo
|
||||
package ucb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"tuxpa.in/a/gambit/helper"
|
||||
|
@ -30,15 +29,8 @@ func (ucb *UCB) Select(r float64) int {
|
|||
}
|
||||
|
||||
func (u *UCB) Update(a int, r float64) error {
|
||||
if a < 0 || a >= len(u.cr.Rewards) || r < 0 {
|
||||
return errors.New("TODO")
|
||||
}
|
||||
u.cr.Counts[a]++
|
||||
dec := float64(u.cr.Counts[a])
|
||||
u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec
|
||||
return nil
|
||||
return u.cr.Update(a, r)
|
||||
}
|
||||
|
||||
func (u *UCB) Reset(n int) error {
|
||||
u.cr.ResetTo(n)
|
||||
return nil
|
||||
|
|
|
@ -5,6 +5,8 @@ type Bandit interface {
|
|||
Update(a int, r float64) error
|
||||
Reset(n int) error
|
||||
|
||||
Count(res *[]int)
|
||||
Reward(res *[]float64)
|
||||
Size() int
|
||||
|
||||
Count(res []int)
|
||||
Reward(res []float64)
|
||||
}
|
||||
|
|
8
go.mod
8
go.mod
|
@ -1,3 +1,11 @@
|
|||
module tuxpa.in/a/gambit
|
||||
|
||||
go 1.19
|
||||
|
||||
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
||||
|
||||
require (
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
||||
golang.org/x/sys v0.1.0 // indirect
|
||||
lukechampine.com/frand v1.4.2 // indirect
|
||||
)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME=
|
||||
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw=
|
||||
lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s=
|
|
@ -1,5 +1,12 @@
|
|||
package helper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type CountReward struct {
|
||||
Counts []int
|
||||
Rewards []float64
|
||||
|
@ -22,42 +29,62 @@ func (c *CountReward) ResetTo(size int) {
|
|||
c.Rewards[idx] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CountReward) Count(res *[]int) {
|
||||
if res == nil {
|
||||
r := make([]int, len(c.Counts))
|
||||
res = &r
|
||||
func (c *CountReward) CountMax() (i int) {
|
||||
xs := c.Counts
|
||||
i = math.MinInt
|
||||
for _, v := range xs {
|
||||
if v > i {
|
||||
i = v
|
||||
}
|
||||
}
|
||||
if len(c.Counts) < len(*res) {
|
||||
*res = append(*res, len(c.Counts)-len(*res))
|
||||
return
|
||||
}
|
||||
func (c *CountReward) RewardMax() (i float64) {
|
||||
xs := c.Rewards
|
||||
i = math.Inf(-1)
|
||||
for _, v := range xs {
|
||||
if v > i {
|
||||
i = v
|
||||
}
|
||||
}
|
||||
|
||||
(*res) = (*res)[:len(c.Counts)]
|
||||
|
||||
copy(*res, c.Counts)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CountReward) Update(a int, r float64) error {
|
||||
if a < 0 || a >= c.Size() || r < 0 {
|
||||
return errors.New("TODO")
|
||||
}
|
||||
c.Counts[a]++
|
||||
dec := float64(c.Counts[a])
|
||||
c.Rewards[a] = (c.Rewards[a]*(dec-1) + r) / dec
|
||||
return nil
|
||||
}
|
||||
func (c *CountReward) CountSum() (i int) {
|
||||
for _, v := range c.Counts {
|
||||
i = i + v
|
||||
}
|
||||
return i
|
||||
return sum(c.Counts)
|
||||
}
|
||||
func (c *CountReward) CountReward() (i int) {
|
||||
for _, v := range c.Counts {
|
||||
i = i + v
|
||||
}
|
||||
return i
|
||||
func (c *CountReward) RewardSum() (i float64) {
|
||||
return sum(c.Rewards)
|
||||
}
|
||||
|
||||
func (c *CountReward) Reward(res *[]float64) {
|
||||
if res == nil {
|
||||
r := make([]float64, len(c.Rewards))
|
||||
res = &r
|
||||
}
|
||||
if len(c.Rewards) < len(*res) {
|
||||
*res = append(*res, make([]float64, len(c.Rewards)-len(*res))...)
|
||||
}
|
||||
(*res) = (*res)[:len(c.Rewards)]
|
||||
copy(*res, c.Rewards)
|
||||
func (c *CountReward) Size() int {
|
||||
return len(c.Counts)
|
||||
}
|
||||
|
||||
func (c *CountReward) Count(res []int) {
|
||||
copy(res, c.Counts)
|
||||
}
|
||||
|
||||
func (c *CountReward) Reward(res []float64) {
|
||||
copy(res, c.Rewards)
|
||||
}
|
||||
|
||||
func sum[T Numeric](n []T) (i T) {
|
||||
for _, v := range n {
|
||||
i = i + v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Numeric interface {
|
||||
constraints.Integer | constraints.Float
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue