gambit/helper/storage.go
2023-07-06 10:10:29 +01:00

91 lines
1.5 KiB
Go

package helper
import (
"errors"
"math"
"golang.org/x/exp/constraints"
)
// CountReward holds parallel per-index statistics: how many rewards have
// been recorded for each index and the running mean of those rewards.
// ResetTo keeps both slices the same length; Size reports len(Counts).
type CountReward struct {
	// Counts[a] is the number of rewards recorded for index a by Update.
	Counts []int
	// Rewards[a] is the mean of the rewards recorded for index a by Update.
	Rewards []float64
}
// ResetTo resizes c to track size indices and zeroes every count and reward.
// Existing backing arrays are reused whenever their capacity is large enough,
// so repeated resets (including shrink-then-grow cycles) do not reallocate.
// It panics if size is negative.
func (c *CountReward) ResetTo(size int) {
	c.Counts = resetSlice(c.Counts, size)
	c.Rewards = resetSlice(c.Rewards, size)
}

// resetSlice returns s resized to length size with every element set to the
// zero value of T, allocating a new array only when cap(s) < size.
func resetSlice[T any](s []T, size int) []T {
	// Compare against cap, not len: reslicing up to cap is legal, and the
	// elements between len and cap are zeroed below anyway.
	if cap(s) < size {
		return make([]T, size)
	}
	s = s[:size]
	var zero T
	for i := range s {
		s[i] = zero
	}
	return s
}
// CountMax returns the largest value in c.Counts.
// When c.Counts is empty it returns math.MinInt as a sentinel.
func (c *CountReward) CountMax() int {
	best := math.MinInt
	for idx := range c.Counts {
		if c.Counts[idx] > best {
			best = c.Counts[idx]
		}
	}
	return best
}
// RewardMax returns the largest value in c.Rewards.
// When c.Rewards is empty it returns -Inf. NaN entries never win the
// comparison and are therefore effectively ignored.
func (c *CountReward) RewardMax() float64 {
	best := math.Inf(-1)
	for idx := range c.Rewards {
		if r := c.Rewards[idx]; r > best {
			best = r
		}
	}
	return best
}
var (
	// ErrActionOutOfRange is returned by Update when a is not in [0, Size()).
	ErrActionOutOfRange = errors.New("helper: action index out of range")
	// ErrInvalidReward is returned by Update when r is negative, NaN or infinite.
	ErrInvalidReward = errors.New("helper: reward must be finite and non-negative")
)

// Update records one reward r for index a: it increments Counts[a] and folds
// r into the running mean stored in Rewards[a].
//
// It returns ErrActionOutOfRange if a is outside [0, Size()) and
// ErrInvalidReward if r is negative, NaN or ±Inf; in both cases c is left
// unchanged.
func (c *CountReward) Update(a int, r float64) error {
	if a < 0 || a >= c.Size() {
		return ErrActionOutOfRange
	}
	// NaN compares false against everything, so `r < 0` alone would let it
	// through; NaN or +Inf would then poison the running mean permanently.
	if r < 0 || math.IsNaN(r) || math.IsInf(r, 0) {
		return ErrInvalidReward
	}
	c.Counts[a]++
	n := float64(c.Counts[a])
	// Incremental mean m_n = m_{n-1} + (r - m_{n-1})/n. Unlike
	// (m*(n-1) + r)/n it never forms the large product m*(n-1).
	c.Rewards[a] += (r - c.Rewards[a]) / n
	return nil
}
// CountSum returns the total of all entries in c.Counts (0 when empty).
func (c *CountReward) CountSum() int {
	total := sum(c.Counts)
	return total
}
// RewardSum returns the total of all entries in c.Rewards (0 when empty).
func (c *CountReward) RewardSum() float64 {
	total := sum(c.Rewards)
	return total
}
// Size reports how many indices c tracks, defined as len(c.Counts).
func (c *CountReward) Size() int {
	n := len(c.Counts)
	return n
}
// Count copies c.Counts into res. Only min(len(res), len(c.Counts))
// elements are written; any remaining elements of res are left untouched.
func (c *CountReward) Count(res []int) {
	src := c.Counts
	copy(res, src)
}
// Reward copies c.Rewards into res. Only min(len(res), len(c.Rewards))
// elements are written; any remaining elements of res are left untouched.
func (c *CountReward) Reward(res []float64) {
	src := c.Rewards
	copy(res, src)
}
// sum returns the total of the elements of n, added left to right starting
// from the zero value of T. A nil or empty slice yields 0.
func sum[T Numeric](n []T) T {
	var total T
	for idx := 0; idx < len(n); idx++ {
		total += n[idx]
	}
	return total
}
// Numeric is the type-set constraint of all integer and floating-point
// types from golang.org/x/exp/constraints. It lets the generic sum helper
// accept both []int and []float64.
type Numeric interface {
	constraints.Float | constraints.Integer
}