From 5f5c2362135ab6f8f071d4c44adb859595de4c87 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 6 Jul 2023 10:48:20 +0100 Subject: [PATCH] ok --- algo/epsilon_greedy.go | 3 +++ example/simple/main.go | 5 ++--- gang.go | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/algo/epsilon_greedy.go b/algo/epsilon_greedy.go index d265bd4..8b32f4d 100644 --- a/algo/epsilon_greedy.go +++ b/algo/epsilon_greedy.go @@ -2,9 +2,12 @@ package algo import ( "lukechampine.com/frand" + "tuxpa.in/a/gambit" "tuxpa.in/a/gambit/helper" ) +var _ gambit.Bandit = (*EpsilonGreedy)(nil) + type EpsilonGreedy struct { Epsilon float64 cr helper.CountReward diff --git a/example/simple/main.go b/example/simple/main.go index 86873c1..ac36e51 100644 --- a/example/simple/main.go +++ b/example/simple/main.go @@ -17,13 +17,12 @@ func main() { n := 100 for i := 0; i < n; i++ { - b.Update( + g.Observe( // select a random arm - b.Select(frand.Float64()), + frand.Float64(), // and supply a random score float64(frand.Intn(4)), ) } - log.Println(g.AllocateSolution()) } diff --git a/gang.go b/gang.go index 6003ede..6e2ed75 100644 --- a/gang.go +++ b/gang.go @@ -15,3 +15,8 @@ func (g *Gang) AllocateSolution() ([]int, []float64) { g.b.Reward(b) return a, b } + +func (g *Gang) Observe(score float64, reward float64) { + g.b.Update(g.b.Select(score), reward) + +}