This commit is contained in:
a 2023-07-05 23:22:48 +01:00
commit ef5261b577
5 changed files with 152 additions and 0 deletions

54
algo/ucb.go Normal file
View File

@ -0,0 +1,54 @@
package ucb
import (
"errors"
"math"
"tuxpa.in/a/gambit/helper"
)
// UCB is a multi-armed bandit that chooses arms with the UCB1 rule: the arm
// whose mean observed reward plus an exploration bonus is largest, where the
// bonus shrinks as an arm accumulates plays.
//
// The zero value has no arms; call Reset before use. UCB performs no locking
// of its own.
type UCB struct {
	// cr holds, per arm, the number of plays and the running mean reward.
	cr helper.CountReward
}

// Select returns the index of the arm to play next.
//
// Any arm that has never been played is returned first (lowest index wins),
// so every arm is tried once before scores are compared. After that, the arm
// maximising
//
//	meanReward[i] + sqrt(2*ln(totalPlays)/plays[i])
//
// is returned; ties go to the lowest index. If the bandit has no arms,
// Select returns -1.
//
// r is not used by UCB; it is accepted so the method keeps the
// Select(float64) int signature.
func (u *UCB) Select(r float64) int {
	counts := u.cr.Counts
	if len(counts) == 0 {
		return -1
	}

	total := 0
	for i, c := range counts {
		if c == 0 {
			// Unplayed arm: its bonus would be infinite, so play it now.
			return i
		}
		total += c
	}

	// Loop-invariant numerator of the exploration bonus.
	numerator := 2.0 * math.Log(float64(total))

	best, bestScore := 0, math.Inf(-1)
	for i, c := range counts {
		score := u.cr.Rewards[i] + math.Sqrt(numerator/float64(c))
		if score > bestScore {
			best, bestScore = i, score
		}
	}
	return best
}

// Update records that arm a was played and yielded reward r, folding r into
// the arm's running mean.
//
// It returns an error, and leaves the state untouched, if a is not a valid
// arm index or if r is negative or NaN.
func (u *UCB) Update(a int, r float64) error {
	if a < 0 || a >= len(u.cr.Counts) || a >= len(u.cr.Rewards) {
		return errors.New("ucb: arm index out of range")
	}
	if r < 0 || math.IsNaN(r) {
		return errors.New("ucb: reward must be a non-negative number")
	}

	u.cr.Counts[a]++
	plays := float64(u.cr.Counts[a])
	// Incremental mean: the previous mean covered plays-1 observations.
	u.cr.Rewards[a] = (u.cr.Rewards[a]*(plays-1) + r) / plays
	return nil
}

// Reset re-dimensions the bandit to n arms by delegating to
// helper.CountReward.ResetTo. It returns an error if n is negative.
func (u *UCB) Reset(n int) error {
	if n < 0 {
		return errors.New("ucb: number of arms must not be negative")
	}
	u.cr.ResetTo(n)
	return nil
}

// Count copies the per-arm play counts into *res by delegating to
// helper.CountReward.Count.
func (u *UCB) Count(res *[]int) {
	u.cr.Count(res)
}

// Reward copies the per-arm mean rewards into *res by delegating to
// helper.CountReward.Reward.
func (u *UCB) Reward(res *[]float64) {
	u.cr.Reward(res)
}

10
gambit.go Normal file
View File

@ -0,0 +1,10 @@
package gambit
// Bandit is the method set shared by the bandit strategies in this module.
// Arms are identified by integer index.
type Bandit interface {
// Select returns the index of the arm to play.
// NOTE(review): the meaning of r is not defined here; presumably it is a
// caller-supplied random draw for randomized strategies — confirm.
Select(r float64) int
// Update reports reward r for arm a and returns an error if the
// implementation rejects the input.
Update(a int, r float64) error
// Reset re-initializes the bandit for n arms.
Reset(n int) error
// Count writes the per-arm counts into *res.
Count(res *[]int)
// Reward writes the per-arm reward values into *res.
Reward(res *[]float64)
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module tuxpa.in/a/gambit
go 1.19

42
helper/mutex.go Normal file
View File

@ -0,0 +1,42 @@
package helper
import (
"sync"
"tuxpa.in/a/gambit"
)
// Sync wraps a Bandit and serializes every call to it with a mutex.
//
// The embedded Bandit must be non-nil before any method is called. A Sync
// must not be copied after first use because it contains a mutex.
//
// Sync declares all five Bandit methods itself, so each one must reach the
// wrapped value through the embedded field explicitly (s.Bandit.X). An
// unqualified s.X resolves to Sync's own method, which would call itself and
// block forever trying to re-acquire the lock it already holds.
type Sync struct {
	gambit.Bandit
	mu sync.RWMutex
}

// Select calls the wrapped Bandit's Select while holding the lock.
func (s *Sync) Select(r float64) int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.Bandit.Select(r)
}

// Update calls the wrapped Bandit's Update while holding the lock.
func (s *Sync) Update(a int, r float64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.Bandit.Update(a, r)
}

// Reset calls the wrapped Bandit's Reset while holding the lock.
func (s *Sync) Reset(n int) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.Bandit.Reset(n)
}

// Count calls the wrapped Bandit's Count while holding the lock.
//
// The exclusive lock is used (rather than a read lock) because Sync cannot
// know whether an arbitrary implementation mutates state here.
func (s *Sync) Count(res *[]int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.Bandit.Count(res)
}

// Reward calls the wrapped Bandit's Reward while holding the lock.
//
// The exclusive lock is used (rather than a read lock) because Sync cannot
// know whether an arbitrary implementation mutates state here.
func (s *Sync) Reward(res *[]float64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.Bandit.Reward(res)
}

43
helper/storage.go Normal file
View File

@ -0,0 +1,43 @@
package helper
// CountReward is the per-arm bookkeeping shared by bandit implementations:
// one integer and one float64 per arm, kept in two parallel slices.
//
// The zero value holds no arms; call ResetTo to size it.
type CountReward struct {
	// Counts holds one counter per arm.
	Counts []int
	// Rewards holds one reward value per arm.
	Rewards []float64
}

// ResetTo sizes both slices to exactly size elements and sets every element
// to zero. Existing backing storage is reused when it is large enough;
// otherwise new storage is allocated.
//
// size must be non-negative; a negative size panics.
func (c *CountReward) ResetTo(size int) {
	if cap(c.Counts) < size {
		c.Counts = make([]int, size)
	} else {
		c.Counts = c.Counts[:size]
		// Reused storage may hold values from before the reset.
		for i := range c.Counts {
			c.Counts[i] = 0
		}
	}

	if cap(c.Rewards) < size {
		c.Rewards = make([]float64, size)
	} else {
		c.Rewards = c.Rewards[:size]
		for i := range c.Rewards {
			c.Rewards[i] = 0
		}
	}
}

// Count copies c.Counts into *res, setting len(*res) to len(c.Counts).
// The storage behind *res is reused when its capacity suffices and grown
// otherwise. A nil res is a no-op.
func (c *CountReward) Count(res *[]int) {
	if res == nil {
		return
	}
	*res = append((*res)[:0], c.Counts...)
}

// Reward copies c.Rewards into *res, setting len(*res) to len(c.Rewards).
// The storage behind *res is reused when its capacity suffices and grown
// otherwise. A nil res is a no-op.
func (c *CountReward) Reward(res *[]float64) {
	if res == nil {
		return
	}
	*res = append((*res)[:0], c.Rewards...)
}