// Package capability drives the client side of IRC capability (CAP)
// negotiation as a middleware around ircv3 handlers.
package capability

import (
	"context"
	"strings"

	"tuxpa.in/a/irc/pkg/ircv3"
)

// registrarCtxKey is a dedicated key type so the context entry written by the
// Registrar can never collide with a struct{}{} key used by other code.
type registrarCtxKey struct{}

// registrarKey is the context key under which the active *Registrar is stored.
var registrarKey = registrarCtxKey{}

// Negotiation stages, as reported by NegotiatingState.
const (
	// negotiationIdle: no serve event has been seen (also reported when the
	// context carries no Registrar).
	negotiationIdle = 0
	// negotiationStarted: state was reset by a serve event; CAP LS not sent yet.
	negotiationStarted = 1
	// negotiationAwaitingLS: "CAP LS 302" was sent, waiting for the reply.
	negotiationAwaitingLS = 2
	// negotiationAwaitingCaps: the CAP LS reply was received, waiting for every
	// enabled capability to report ready.
	negotiationAwaitingCaps = 3
	// negotiationDone: "CAP END" was sent; events flow through the capability
	// middleware chain.
	negotiationDone = 4
)

// Capability is a single negotiable capability.
type Capability interface {
	// Handle is called for every event while negotiation is in progress. It
	// returns true once the capability no longer needs to take part in
	// negotiation; after that it is not called again until the next serve
	// event.
	Handle(w ircv3.MessageWriter, e *ircv3.Event) bool
	// Middleware wraps next; the wrapped chain is only used once negotiation
	// has finished.
	Middleware(next ircv3.Handler) ircv3.Handler
}

// Registrar tracks the enabled capabilities and the negotiation state.
type Registrar struct {
	enabled     []Capability
	discovered  []string
	negotiating int
}

// Enable registers c for negotiation. It currently always returns nil.
func (cs *Registrar) Enable(c Capability) error {
	cs.enabled = append(cs.enabled, c)
	return nil
}

// NewMiddleware returns a middleware that performs capability negotiation.
//
// On the "/EVENT_ON_SERVE" control event it resets its state and sends
// "CAP LS 302". Until negotiation completes, every event is offered to each
// capability that has not yet reported ready and is then passed to next
// directly. Once the CAP LS reply has been received and all capabilities are
// ready, "CAP END" is sent; from then on events are dispatched through the
// chained capability middlewares.
func (cs *Registrar) NewMiddleware() func(next ircv3.Handler) ircv3.Handler {
	waiting := make(map[Capability]struct{}, len(cs.enabled))
	for _, c := range cs.enabled {
		waiting[c] = struct{}{}
	}
	return func(next ircv3.Handler) ircv3.Handler {
		// Chain every capability's middleware. Wrapping cur (not next) is what
		// keeps earlier capabilities in the chain.
		cur := next
		for _, c := range cs.enabled {
			cur = c.Middleware(cur)
		}
		return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) {
			// Store the pointer: the accessors below assert *Registrar, and a
			// pointer lets handlers observe state updated during this event.
			e = e.WithContext(context.WithValue(e.Context(), registrarKey, cs))

			onServe := e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE"

			// reset signal: every enabled capability has to negotiate again
			if onServe {
				for c := range waiting {
					delete(waiting, c)
				}
				for _, c := range cs.enabled {
					waiting[c] = struct{}{}
				}
				cs.negotiating = negotiationStarted
				cs.discovered = nil
			}

			// done negotiating, so run the handle middleware
			if cs.negotiating == negotiationDone {
				cur.Handle(w, e)
				return
			}

			// collect the CAP * LS response
			if cs.negotiating == negotiationAwaitingLS && e.Msg.Command == "CAP" && e.Msg.Param(0) == "*" && e.Msg.Param(1) == "LS" {
				list, more := e.Msg.Param(2), false
				// With CAP LS 302 a long list is split across several lines;
				// every line but the last carries "*" before the list.
				if list == "*" {
					list, more = e.Msg.Param(3), true
				}
				cs.discovered = append(cs.discovered, strings.Fields(list)...)
				if !more {
					cs.negotiating = negotiationAwaitingCaps
				}
			}

			if onServe {
				w.WriteMessage(ircv3.NewMessage("CAP", "LS", "302"))
				cs.negotiating = negotiationAwaitingLS
			}

			// run all negotiation handlers
			// this allows sasl auth to happen before CAP LS happens.
			for c := range waiting {
				if c.Handle(w, e) {
					delete(waiting, c)
				}
			}

			// not done negotiating yet, so dont run handler middleware
			next.Handle(w, e)

			if cs.negotiating == negotiationAwaitingCaps && len(waiting) == 0 {
				cs.negotiating = negotiationDone
				w.WriteMessage(ircv3.NewMessage("CAP", "END"))
			}
		})
	}
}

// NegotiatingState returns the negotiation stage (0-4) of the Registrar
// attached to ctx, or 0 when ctx carries no Registrar.
func NegotiatingState(ctx context.Context) int {
	val, ok := ctx.Value(registrarKey).(*Registrar)
	if !ok {
		return negotiationIdle
	}
	return val.negotiating
}

// IsDoneNegotiating reports whether the negotiation stage is at least 3, i.e.
// the CAP LS reply has been received.
//
// NOTE(review): "CAP END" is only sent on entering stage 4; confirm that
// callers really want stage 3 to count as done.
func IsDoneNegotiating(ctx context.Context) bool {
	return NegotiatingState(ctx) >= negotiationAwaitingCaps
}

// HasCapability reports whether c was listed in the CAP LS reply recorded by
// the Registrar attached to ctx. c matches either a full token or the name
// part of a "name=value" token. It returns false when ctx carries no
// Registrar.
func HasCapability(ctx context.Context, c string) bool {
	val, ok := ctx.Value(registrarKey).(*Registrar)
	if !ok {
		return false
	}
	for _, v := range val.discovered {
		name := v
		if i := strings.IndexByte(v, '='); i >= 0 {
			name = v[:i]
		}
		if v == c || name == c {
			return true
		}
	}
	return false
}