This commit is contained in:
a 2023-07-06 10:10:29 +01:00
parent b486de5d8c
commit 2f1636a33c
5 changed files with 55 additions and 7 deletions

View File

@ -1,4 +1,4 @@
package ucb
package algo
import (
"math"
@ -37,10 +37,10 @@ func (u *UCB) Reset(n int) error {
}
func (u *UCB) Count(res *[]int) {
func (u *UCB) Count(res []int) {
u.cr.Count(res)
}
func (u *UCB) Reward(res *[]float64) {
func (u *UCB) Reward(res []float64) {
u.cr.Reward(res)
}

29
example/simple/main.go Normal file
View File

@ -0,0 +1,29 @@
package main
import (
"log"
"lukechampine.com/frand"
"tuxpa.in/a/gambit"
"tuxpa.in/a/gambit/algo"
)
// main is a small demo: it attaches an epsilon-greedy bandit to a Gang,
// feeds the bandit a fixed number of random updates, and logs whatever
// Gang.AllocateSolution returns afterwards.
func main() {
	// Number of select/update iterations to run.
	const rounds = 100

	gang := &gambit.Gang{}
	bandit := &algo.EpsilonGreedy{Epsilon: 0.1}
	bandit.Reset(4)
	gang.WithBandit(bandit)

	for round := 0; round < rounds; round++ {
		// Select an arm using a random probe value.
		arm := bandit.Select(frand.Float64())
		// Supply a random score drawn with frand.Intn(4).
		score := float64(frand.Intn(4))
		bandit.Update(arm, score)
	}

	counts, rewards := gang.AllocateSolution()
	log.Println(counts, rewards)
}

17
gang.go Normal file
View File

@ -0,0 +1,17 @@
package gambit
// Gang holds a single Bandit and reads back its Count and Reward data.
// The zero value is usable: it simply has no bandit attached yet.
type Gang struct {
	b Bandit // bandit queried by AllocateSolution; nil until WithBandit is called
}

// WithBandit sets b as the bandit used by g, replacing any bandit set earlier.
func (g *Gang) WithBandit(b Bandit) {
	g.b = b
}

// AllocateSolution allocates an []int and a []float64, each of length
// g.b.Size(), passes them to the bandit's Count and Reward methods
// respectively (which are expected to fill them in), and returns both.
//
// If no bandit has been attached with WithBandit, it returns nil, nil
// rather than panicking on the nil interface.
func (g *Gang) AllocateSolution() ([]int, []float64) {
	if g.b == nil {
		return nil, nil
	}
	// Query the size once so both slices are guaranteed the same length.
	size := g.b.Size()
	counts := make([]int, size)
	rewards := make([]float64, size)
	g.b.Count(counts)
	g.b.Reward(rewards)
	return counts, rewards
}

6
go.mod
View File

@ -2,10 +2,12 @@ module tuxpa.in/a/gambit
go 1.19
require golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
require (
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
lukechampine.com/frand v1.4.2
)
require (
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect
golang.org/x/sys v0.1.0 // indirect
lukechampine.com/frand v1.4.2 // indirect
)

View File

@ -13,7 +13,7 @@ type CountReward struct {
}
func (c *CountReward) ResetTo(size int) {
if len(c.Counts) > size {
if len(c.Counts) < size {
c.Counts = make([]int, size)
}
c.Counts = c.Counts[:size]
@ -21,7 +21,7 @@ func (c *CountReward) ResetTo(size int) {
c.Counts[idx] = 0
}
if len(c.Rewards) > size {
if len(c.Rewards) < size {
c.Rewards = make([]float64, size)
}
c.Rewards = c.Rewards[:size]