ok
This commit is contained in:
commit
ef5261b577
54
algo/ucb.go
Normal file
54
algo/ucb.go
Normal file
@ -0,0 +1,54 @@
|
||||
package ucb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"tuxpa.in/a/gambit/helper"
|
||||
)
|
||||
|
||||
type UCB struct {
|
||||
cr helper.CountReward
|
||||
}
|
||||
|
||||
func (ucb *UCB) Select(r float64) int {
|
||||
a := len(ucb.cr.Counts)
|
||||
for _, v := range ucb.cr.Counts {
|
||||
if v == 0 {
|
||||
return a
|
||||
}
|
||||
}
|
||||
sz := len(ucb.cr.Counts)
|
||||
var res float64
|
||||
for _, v := range ucb.cr.Counts {
|
||||
ans := math.Sqrt((2.0 * math.Log(float64(sz))) / float64(v))
|
||||
if ans > res {
|
||||
res = ans
|
||||
}
|
||||
}
|
||||
return int(res)
|
||||
}
|
||||
|
||||
func (u *UCB) Update(a int, r float64) error {
|
||||
if a < 0 || a >= len(u.cr.Rewards) || r < 0 {
|
||||
return errors.New("TODO")
|
||||
}
|
||||
u.cr.Counts[a]++
|
||||
dec := float64(u.cr.Counts[a])
|
||||
u.cr.Rewards[a] = (u.cr.Rewards[a]*(dec-1) + r) / dec
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UCB) Reset(n int) error {
|
||||
u.cr.ResetTo(n)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (u *UCB) Count(res *[]int) {
|
||||
u.cr.Count(res)
|
||||
}
|
||||
|
||||
func (u *UCB) Reward(res *[]float64) {
|
||||
u.cr.Reward(res)
|
||||
}
|
10
gambit.go
Normal file
10
gambit.go
Normal file
@ -0,0 +1,10 @@
|
||||
package gambit
|
||||
|
||||
type Bandit interface {
|
||||
Select(r float64) int
|
||||
Update(a int, r float64) error
|
||||
Reset(n int) error
|
||||
|
||||
Count(res *[]int)
|
||||
Reward(res *[]float64)
|
||||
}
|
42
helper/mutex.go
Normal file
42
helper/mutex.go
Normal file
@ -0,0 +1,42 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"tuxpa.in/a/gambit"
|
||||
)
|
||||
|
||||
type Sync struct {
|
||||
gambit.Bandit
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *Sync) Select(r float64) int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Select(r)
|
||||
}
|
||||
|
||||
func (s *Sync) Update(a int, r float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Update(a, r)
|
||||
}
|
||||
|
||||
func (s *Sync) Reset(n int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.Reset(n)
|
||||
}
|
||||
|
||||
func (s *Sync) Count(res *[]int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Count(res)
|
||||
}
|
||||
|
||||
func (s *Sync) Reward(res *[]float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Reward(res)
|
||||
}
|
43
helper/storage.go
Normal file
43
helper/storage.go
Normal file
@ -0,0 +1,43 @@
|
||||
package helper
|
||||
|
||||
type CountReward struct {
|
||||
Counts []int
|
||||
Rewards []float64
|
||||
}
|
||||
|
||||
func (c *CountReward) ResetTo(size int) {
|
||||
if len(c.Counts) > size {
|
||||
c.Counts = make([]int, size)
|
||||
}
|
||||
c.Counts = c.Counts[:size]
|
||||
|
||||
if len(c.Rewards) > size {
|
||||
c.Rewards = make([]float64, size)
|
||||
}
|
||||
c.Rewards = c.Rewards[:size]
|
||||
}
|
||||
|
||||
func (c *CountReward) Count(res *[]int) {
|
||||
if res == nil {
|
||||
r := make([]int, len(c.Counts))
|
||||
res = &r
|
||||
}
|
||||
if len(c.Counts) < len(*res) {
|
||||
*res = append(*res, len(c.Counts)-len(*res))
|
||||
}
|
||||
|
||||
(*res) = (*res)[:len(c.Counts)]
|
||||
copy(*res, c.Counts)
|
||||
}
|
||||
|
||||
func (c *CountReward) Reward(res *[]float64) {
|
||||
if res == nil {
|
||||
r := make([]float64, len(c.Rewards))
|
||||
res = &r
|
||||
}
|
||||
if len(c.Rewards) < len(*res) {
|
||||
*res = append(*res, make([]float64, len(c.Rewards)-len(*res))...)
|
||||
}
|
||||
(*res) = (*res)[:len(c.Rewards)]
|
||||
copy(*res, c.Rewards)
|
||||
}
|
Loading…
Reference in New Issue
Block a user