add epislion greedy
This commit is contained in:
parent
2e4f0e9a8d
commit
b486de5d8c
|
@ -0,0 +1,39 @@
|
||||||
|
package algo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"lukechampine.com/frand"
|
||||||
|
"tuxpa.in/a/gambit/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EpsilonGreedy struct {
|
||||||
|
Epsilon float64
|
||||||
|
cr helper.CountReward
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Select(r float64) int {
|
||||||
|
if r > u.Epsilon {
|
||||||
|
return int(u.cr.RewardMax())
|
||||||
|
}
|
||||||
|
return frand.Intn(u.cr.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Update(a int, r float64) error {
|
||||||
|
return u.cr.Update(a, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Reset(n int) error {
|
||||||
|
u.cr.ResetTo(n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Size() int {
|
||||||
|
return u.cr.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Count(res []int) {
|
||||||
|
u.cr.Count(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *EpsilonGreedy) Reward(res []float64) {
|
||||||
|
u.cr.Reward(res)
|
||||||
|
}
|
|
@ -40,3 +40,9 @@ func (s *Sync) Reward(res *[]float64) {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.Reward(res)
|
s.Reward(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Sync) Size() int {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.Size()
|
||||||
|
}
|
||||||
|
|
12
algo/ucb.go
12
algo/ucb.go
|
@ -1,7 +1,6 @@
|
||||||
package algo
|
package ucb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"tuxpa.in/a/gambit/helper"
|
"tuxpa.in/a/gambit/helper"
|
||||||
|
@ -30,15 +29,8 @@ func (ucb *UCB) Select(r float64) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UCB) Update(a int, r float64) error {
|
func (u *UCB) Update(a int, r float64) error {
|
||||||
if a < 0 || a >= len(u.cr.Rewards) || r < 0 {
|
return u.cr.Update(a, r)
|
||||||
return errors.New("TODO")
|
|
||||||
}
|
|
||||||
u.cr.Counts[a]++
|
|
||||||
dec := float64(u.cr.Counts[a])
|
|
||||||
u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UCB) Reset(n int) error {
|
func (u *UCB) Reset(n int) error {
|
||||||
u.cr.ResetTo(n)
|
u.cr.ResetTo(n)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -5,6 +5,8 @@ type Bandit interface {
|
||||||
Update(a int, r float64) error
|
Update(a int, r float64) error
|
||||||
Reset(n int) error
|
Reset(n int) error
|
||||||
|
|
||||||
Count(res *[]int)
|
Size() int
|
||||||
Reward(res *[]float64)
|
|
||||||
|
Count(res []int)
|
||||||
|
Reward(res []float64)
|
||||||
}
|
}
|
||||||
|
|
8
go.mod
8
go.mod
|
@ -1,3 +1,11 @@
|
||||||
module tuxpa.in/a/gambit
|
module tuxpa.in/a/gambit
|
||||||
|
|
||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
|
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
|
||||||
|
golang.org/x/sys v0.1.0 // indirect
|
||||||
|
lukechampine.com/frand v1.4.2 // indirect
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||||
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
|
||||||
|
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME=
|
||||||
|
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
|
||||||
|
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw=
|
||||||
|
lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s=
|
|
@ -1,5 +1,12 @@
|
||||||
package helper
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
type CountReward struct {
|
type CountReward struct {
|
||||||
Counts []int
|
Counts []int
|
||||||
Rewards []float64
|
Rewards []float64
|
||||||
|
@ -22,42 +29,62 @@ func (c *CountReward) ResetTo(size int) {
|
||||||
c.Rewards[idx] = 0
|
c.Rewards[idx] = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (c *CountReward) CountMax() (i int) {
|
||||||
func (c *CountReward) Count(res *[]int) {
|
xs := c.Counts
|
||||||
if res == nil {
|
i = math.MinInt
|
||||||
r := make([]int, len(c.Counts))
|
for _, v := range xs {
|
||||||
res = &r
|
if v > i {
|
||||||
|
i = v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if len(c.Counts) < len(*res) {
|
return
|
||||||
*res = append(*res, len(c.Counts)-len(*res))
|
}
|
||||||
|
func (c *CountReward) RewardMax() (i float64) {
|
||||||
|
xs := c.Rewards
|
||||||
|
i = math.Inf(-1)
|
||||||
|
for _, v := range xs {
|
||||||
|
if v > i {
|
||||||
|
i = v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return
|
||||||
(*res) = (*res)[:len(c.Counts)]
|
|
||||||
|
|
||||||
copy(*res, c.Counts)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CountReward) Update(a int, r float64) error {
|
||||||
|
if a < 0 || a >= c.Size() || r < 0 {
|
||||||
|
return errors.New("TODO")
|
||||||
|
}
|
||||||
|
c.Counts[a]++
|
||||||
|
dec := float64(c.Counts[a])
|
||||||
|
c.Rewards[a] = (c.Rewards[a]*(dec-1) + r) / dec
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (c *CountReward) CountSum() (i int) {
|
func (c *CountReward) CountSum() (i int) {
|
||||||
for _, v := range c.Counts {
|
return sum(c.Counts)
|
||||||
i = i + v
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
}
|
||||||
func (c *CountReward) CountReward() (i int) {
|
func (c *CountReward) RewardSum() (i float64) {
|
||||||
for _, v := range c.Counts {
|
return sum(c.Rewards)
|
||||||
i = i + v
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CountReward) Reward(res *[]float64) {
|
func (c *CountReward) Size() int {
|
||||||
if res == nil {
|
return len(c.Counts)
|
||||||
r := make([]float64, len(c.Rewards))
|
}
|
||||||
res = &r
|
|
||||||
}
|
func (c *CountReward) Count(res []int) {
|
||||||
if len(c.Rewards) < len(*res) {
|
copy(res, c.Counts)
|
||||||
*res = append(*res, make([]float64, len(c.Rewards)-len(*res))...)
|
}
|
||||||
}
|
|
||||||
(*res) = (*res)[:len(c.Rewards)]
|
func (c *CountReward) Reward(res []float64) {
|
||||||
copy(*res, c.Rewards)
|
copy(res, c.Rewards)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sum[T Numeric](n []T) (i T) {
|
||||||
|
for _, v := range n {
|
||||||
|
i = i + v
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Numeric interface {
|
||||||
|
constraints.Integer | constraints.Float
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue