// Package helper provides CountReward, a fixed-size table that tracks, for
// each index, how many rewards have been recorded and their running mean.
package helper

import (
	"errors"
	"math"

	"golang.org/x/exp/constraints"
)

// Errors returned by (*CountReward).Update. Compare with errors.Is.
var (
	// ErrIndexOutOfRange reports an index outside the tracked slots.
	ErrIndexOutOfRange = errors.New("helper: index out of range")
	// ErrInvalidReward reports a negative or NaN reward.
	ErrInvalidReward = errors.New("helper: reward must be a non-negative number")
)

// CountReward stores one counter and one running mean reward per index.
// Counts[i] is the number of rewards recorded for index i by Update, and
// Rewards[i] is the arithmetic mean of those rewards.
type CountReward struct {
	Counts  []int
	Rewards []float64
}

// ResetTo resizes both slices to exactly size elements and zeroes them.
// Existing backing storage is reused when the current length already
// covers size; otherwise fresh storage is allocated.
// It panics if size is negative.
func (c *CountReward) ResetTo(size int) {
	c.Counts = zeroed(c.Counts, size)
	c.Rewards = zeroed(c.Rewards, size)
}

// CountMax returns the largest value in Counts, or math.MinInt when Counts
// is empty.
func (c *CountReward) CountMax() int {
	return maxOf(c.Counts, math.MinInt)
}

// RewardMax returns the largest value in Rewards, or negative infinity when
// Rewards is empty. NaN entries never compare greater and are skipped.
func (c *CountReward) RewardMax() float64 {
	return maxOf(c.Rewards, math.Inf(-1))
}

// Update records reward r for index a: it increments Counts[a] and folds r
// into the running mean held in Rewards[a].
//
// It returns ErrIndexOutOfRange when a does not address an element of both
// Counts and Rewards, and ErrInvalidReward when r is negative or NaN. On
// error the receiver is left unmodified.
func (c *CountReward) Update(a int, r float64) error {
	// Check against both slices: the fields are exported, so their lengths
	// are not guaranteed to match.
	if a < 0 || a >= len(c.Counts) || a >= len(c.Rewards) {
		return ErrIndexOutOfRange
	}
	// NaN fails every ordered comparison, so it has to be rejected
	// explicitly or it would poison the mean for good.
	if r < 0 || math.IsNaN(r) {
		return ErrInvalidReward
	}

	c.Counts[a]++
	n := float64(c.Counts[a])
	// Incremental mean: the previous mean covered n-1 samples.
	c.Rewards[a] = (c.Rewards[a]*(n-1) + r) / n
	return nil
}

// CountSum returns the sum of all elements of Counts.
func (c *CountReward) CountSum() int {
	return sum(c.Counts)
}

// RewardSum returns the sum of all elements of Rewards.
func (c *CountReward) RewardSum() float64 {
	return sum(c.Rewards)
}

// Size returns the number of tracked indices, defined as len(Counts).
func (c *CountReward) Size() int {
	return len(c.Counts)
}

// Count copies Counts into res. As with the built-in copy, only
// min(len(res), len(Counts)) elements are written.
func (c *CountReward) Count(res []int) {
	copy(res, c.Counts)
}

// Reward copies Rewards into res. As with the built-in copy, only
// min(len(res), len(Rewards)) elements are written.
func (c *CountReward) Reward(res []float64) {
	copy(res, c.Rewards)
}

// Numeric is the set of all integer and floating-point types.
type Numeric interface {
	constraints.Integer | constraints.Float
}

// sum returns the total of the elements of n; the zero value for empty input.
func sum[T Numeric](n []T) T {
	var total T
	for _, v := range n {
		total += v
	}
	return total
}

// maxOf returns the largest element of xs that exceeds floor, or floor
// itself when no element does (including when xs is empty).
func maxOf[T Numeric](xs []T, floor T) T {
	best := floor
	for _, v := range xs {
		if v > best {
			best = v
		}
	}
	return best
}

// zeroed returns a slice of exactly size zero-valued elements, reusing the
// storage of s when len(s) >= size. It panics if size is negative.
func zeroed[T Numeric](s []T, size int) []T {
	if len(s) < size {
		// make already yields zeroed memory.
		return make([]T, size)
	}
	s = s[:size]
	for i := range s {
		s[i] = 0
	}
	return s
}