110 lines
2.6 KiB
Go
110 lines
2.6 KiB
Go
package capability
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
|
|
"tuxpa.in/a/irc/pkg/ircv3"
|
|
)
|
|
|
|
// registrarCtxKey is the unexported type of the context key under which the
// capability middleware stores its Registrar. Because the type is private to
// this package, no other package can construct an equal key, unlike a key of
// the built-in type struct{}, which any caller could collide with.
type registrarCtxKey struct{}

// registrarKey is the context key for the Registrar; see registrarCtxKey.
var registrarKey registrarCtxKey
|
|
|
|
type Capability interface {
|
|
Handle(w ircv3.MessageWriter, e *ircv3.Event) bool
|
|
Middleware(next ircv3.Handler) ircv3.Handler
|
|
}
|
|
|
|
type Registrar struct {
|
|
enabled []Capability
|
|
|
|
discovered []string
|
|
negotiating int
|
|
}
|
|
|
|
func (cs *Registrar) Enable(c Capability) error {
|
|
cs.enabled = append(cs.enabled, c)
|
|
return nil
|
|
}
|
|
|
|
func (cs *Registrar) NewMiddleware() func(next ircv3.Handler) ircv3.Handler {
|
|
waiting := map[Capability]struct{}{}
|
|
for _, v := range cs.enabled {
|
|
waiting[v] = struct{}{}
|
|
}
|
|
return func(next ircv3.Handler) ircv3.Handler {
|
|
cur := next
|
|
for _, v := range cs.enabled {
|
|
cur = v.Middleware(next)
|
|
}
|
|
return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) {
|
|
e = e.WithContext(context.WithValue(e.Context(), registrarKey, *cs))
|
|
// reset signal, set negotiating to 1
|
|
if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" {
|
|
for k := range waiting {
|
|
delete(waiting, k)
|
|
}
|
|
for _, v := range cs.enabled {
|
|
waiting[v] = struct{}{}
|
|
}
|
|
cs.negotiating = 1
|
|
cs.discovered = nil
|
|
}
|
|
// done negotiating, so run the handle middleware
|
|
if cs.negotiating == 4 {
|
|
cur.Handle(w, e)
|
|
return
|
|
}
|
|
// increase negotiating stage when receive the CAP * LS response
|
|
if cs.negotiating == 2 && e.Msg.Command == "CAP" && e.Msg.Param(0) == "*" && e.Msg.Param(1) == "LS" {
|
|
for _, v := range strings.Fields(e.Msg.Param(2)) {
|
|
cs.discovered = append(cs.discovered, v)
|
|
}
|
|
cs.negotiating = 3
|
|
}
|
|
if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" {
|
|
w.WriteMessage(ircv3.NewMessage("CAP", "LS", "302"))
|
|
cs.negotiating = 2
|
|
}
|
|
// run all negotiation handlers
|
|
// this allows sasl auth to happen before CAP LS happens.
|
|
for v := range waiting {
|
|
ready := v.Handle(w, e)
|
|
if ready {
|
|
delete(waiting, v)
|
|
}
|
|
}
|
|
// not done negotiating yet, so dont run handler middleware
|
|
next.Handle(w, e)
|
|
|
|
if cs.negotiating == 3 && len(waiting) == 0 {
|
|
cs.negotiating = 4
|
|
w.WriteMessage(ircv3.NewMessage("CAP", "END"))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func NegotiatingState(ctx context.Context) int {
|
|
val, ok := ctx.Value(registrarKey).(*Registrar)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return val.negotiating
|
|
}
|
|
func IsDoneNegotiating(ctx context.Context) bool {
|
|
return NegotiatingState(ctx) >= 3
|
|
}
|
|
|
|
func HasCapability(ctx context.Context, c string) bool {
|
|
val, ok := ctx.Value(registrarKey).(*Registrar)
|
|
if !ok {
|
|
return false
|
|
}
|
|
for _, v := range val.discovered {
|
|
if v == c {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|