// Package ucb implements the UCB1 (upper confidence bound) strategy for the
// multi-armed bandit problem.
package ucb

import (
	"errors"
	"math"

	"tuxpa.in/a/gambit/helper"
)

// UCB is a UCB1 bandit. The per-arm pull counts and running mean rewards are
// held in cr.
type UCB struct {
	cr helper.CountReward
}

// Select returns the index of the arm to pull next.
//
// Any arm that has never been pulled is returned first (lowest index wins), so
// every arm is tried at least once. After that, the arm maximizing
//
//	mean reward + sqrt(2 * ln(total pulls) / pulls of that arm)
//
// is returned, with ties going to the lowest index. Select returns -1 when the
// bandit has no arms.
//
// r is not used by UCB1.
// NOTE(review): r is presumably accepted to satisfy a bandit interface shared
// with randomized strategies — confirm against the callers.
func (u *UCB) Select(r float64) int {
	counts := u.cr.Counts
	if len(counts) == 0 {
		return -1
	}

	// Explore untried arms first; this also keeps the division below safe.
	var total float64
	for idx, c := range counts {
		if c == 0 {
			return idx
		}
		total += float64(c)
	}

	// The numerator of the exploration bonus is the same for every arm.
	logTerm := 2.0 * math.Log(total)

	// Assumes Rewards has at least as many entries as Counts.
	best, bestScore := 0, math.Inf(-1)
	for idx, c := range counts {
		score := u.cr.Rewards[idx] + math.Sqrt(logTerm/float64(c))
		if score > bestScore {
			best, bestScore = idx, score
		}
	}
	return best
}

// Update records reward r for arm a: it increments the arm's pull count and
// folds r into the arm's running mean reward.
//
// It returns an error, leaving the state untouched, when a is not a valid arm
// index or when r is negative.
func (u *UCB) Update(a int, r float64) error {
	if a < 0 || a >= len(u.cr.Counts) || a >= len(u.cr.Rewards) {
		return errors.New("ucb: arm index out of range")
	}
	if r < 0 {
		return errors.New("ucb: reward must be non-negative")
	}

	u.cr.Counts[a]++
	n := float64(u.cr.Counts[a])
	// Incremental mean: equivalent to (old*(n-1) + r) / n.
	u.cr.Rewards[a] += (r - u.cr.Rewards[a]) / n
	return nil
}

// Reset reinitializes the bandit's statistics by calling ResetTo(n) on the
// underlying helper.CountReward. It always returns nil.
// NOTE(review): n is presumably the number of arms — confirm against
// helper.CountReward.ResetTo.
func (u *UCB) Reset(n int) error {
	u.cr.ResetTo(n)
	return nil
}

// Count delegates to helper.CountReward.Count with res.
// NOTE(review): presumably this writes the per-arm pull counts into *res —
// confirm against the helper package.
func (u *UCB) Count(res *[]int) {
	u.cr.Count(res)
}

// Reward delegates to helper.CountReward.Reward with res.
// NOTE(review): presumably this writes the per-arm mean rewards into *res —
// confirm against the helper package.
func (u *UCB) Reward(res *[]float64) {
	u.cr.Reward(res)
}