diff --git a/.travis.yml b/.travis.yml index 96da367..46129d7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,4 +7,4 @@ matrix: allow_failures: - go: tip script: - go test -v -race -cpu=1,2,4 ./... + go test -v -race -cpu=1,2,4 -bench . -benchmem ./... diff --git a/ctx.go b/ctx.go index 6157e29..04d8a11 100644 --- a/ctx.go +++ b/ctx.go @@ -5,7 +5,12 @@ import ( "io/ioutil" ) -var disabledLogger = New(ioutil.Discard).Level(Disabled) +var disabledLogger *Logger + +func init() { + l := New(ioutil.Discard).Level(Disabled) + disabledLogger = &l +} type ctxKey struct{} @@ -24,14 +29,18 @@ func (l Logger) WithContext(ctx context.Context) context.Context { *lp = l return ctx } + if l.level == Disabled { + // Do not store disabled logger. + return ctx + } return context.WithValue(ctx, ctxKey{}, &l) } // Ctx returns the Logger associated with the ctx. If no logger // is associated, a disabled logger is returned. -func Ctx(ctx context.Context) Logger { +func Ctx(ctx context.Context) *Logger { if l, ok := ctx.Value(ctxKey{}).(*Logger); ok { - return *l + return l } return disabledLogger } diff --git a/ctx_test.go b/ctx_test.go index b5a05da..942b723 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -11,7 +11,7 @@ func TestCtx(t *testing.T) { log := New(ioutil.Discard) ctx := log.WithContext(context.Background()) log2 := Ctx(ctx) - if !reflect.DeepEqual(log, log2) { + if !reflect.DeepEqual(log, *log2) { t.Error("Ctx did not return the expected logger") } @@ -19,12 +19,29 @@ func TestCtx(t *testing.T) { log = log.Level(InfoLevel) ctx = log.WithContext(ctx) log2 = Ctx(ctx) - if !reflect.DeepEqual(log, log2) { + if !reflect.DeepEqual(log, *log2) { t.Error("Ctx did not return the expected logger") } log2 = Ctx(context.Background()) - if !reflect.DeepEqual(log2, disabledLogger) { + if log2 != disabledLogger { t.Error("Ctx did not return the expected logger") } } + +func TestCtxDisabled(t *testing.T) { + ctx := disabledLogger.WithContext(context.Background()) + if ctx != context.Background() { + t.Error("WithContext stored a disabled logger") + } + + ctx = New(ioutil.Discard).WithContext(ctx) + if reflect.DeepEqual(Ctx(ctx), disabledLogger) { + t.Error("WithContext did not store logger") + } + + ctx = disabledLogger.WithContext(ctx) + if !reflect.DeepEqual(Ctx(ctx), disabledLogger) { + t.Error("WithContext did not update logger pointer with disabled logger") + } +} diff --git a/hlog/hlog.go b/hlog/hlog.go index 084ae45..bbe8b12 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -15,7 +15,7 @@ import ( // FromRequest gets the logger in the request's context. // This is a shortcut for log.Ctx(r.Context()) -func FromRequest(r *http.Request) zerolog.Logger { +func FromRequest(r *http.Request) *zerolog.Logger { return log.Ctx(r.Context()) } @@ -23,7 +23,10 @@ func FromRequest(r *http.Request) zerolog.Logger { func NewHandler(log zerolog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(log.WithContext(r.Context())) + // Create a copy of the logger (including internal context slice) + // to prevent data race when using UpdateContext. + l := log.With().Logger() + r = r.WithContext(l.WithContext(r.Context())) next.ServeHTTP(w, r) }) } @@ -35,8 +38,9 @@ func URLHandler(fieldKey string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, r.URL.String()).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, r.URL.String()) + }) next.ServeHTTP(w, r) }) } @@ -48,8 +52,9 @@ func MethodHandler(fieldKey string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, r.Method).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, r.Method) + }) next.ServeHTTP(w, r) }) } @@ -61,8 +66,9 @@ func RequestHandler(fieldKey string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, r.Method+" "+r.URL.String()).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, r.Method+" "+r.URL.String()) + }) next.ServeHTTP(w, r) }) } @@ -75,8 +81,9 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, host).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, host) + }) } next.ServeHTTP(w, r) }) @@ -90,8 +97,9 @@ func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if ua := r.Header.Get("User-Agent"); ua != "" { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, ua).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, ua) + }) } next.ServeHTTP(w, r) }) @@ -105,8 +113,9 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if ref := r.Header.Get("Referer"); ref != "" { log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, ref).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, ref) + }) } next.ServeHTTP(w, r) }) @@ -136,16 +145,18 @@ func IDFromRequest(r *http.Request) (id xid.ID, ok bool) { func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() id, ok := IDFromRequest(r) if !ok { id = xid.New() - ctx := context.WithValue(r.Context(), idKey{}, id) + ctx = context.WithValue(ctx, idKey{}, id) r = r.WithContext(ctx) } if fieldKey != "" { - log := zerolog.Ctx(r.Context()) - log = log.With().Str(fieldKey, id.String()).Logger() - r = r.WithContext(log.WithContext(r.Context())) + log := zerolog.Ctx(ctx) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, id.String()) + }) } if headerName != "" { w.Header().Set(headerName, id.String()) diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index 2654ecc..5ac960a 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -5,6 +5,7 @@ package hlog import ( "bytes" "fmt" + "io/ioutil" "net/http" "net/url" "testing" @@ -23,7 +24,7 @@ func TestNewHandler(t *testing.T) { lh := NewHandler(log) h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) - if !reflect.DeepEqual(l, log) { + if !reflect.DeepEqual(*l, log) { t.Fail() } })) @@ -38,12 +39,12 @@ func TestURLHandler(t *testing.T) { h := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"url":"/path?foo=bar"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"url":"/path?foo=bar"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestMethodHandler(t *testing.T) { @@ -54,12 +55,12 @@ func TestMethodHandler(t *testing.T) { h := MethodHandler("method")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"method":"POST"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"method":"POST"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestRequestHandler(t *testing.T) { @@ -71,12 +72,12 @@ func TestRequestHandler(t *testing.T) { h := RequestHandler("request")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"request":"POST /path?foo=bar"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"request":"POST /path?foo=bar"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestRemoteAddrHandler(t *testing.T) { @@ -87,12 +88,12 @@ func TestRemoteAddrHandler(t *testing.T) { h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"ip":"1.2.3.4"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"ip":"1.2.3.4"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestRemoteAddrHandlerIPv6(t *testing.T) { @@ -103,12 +104,12 @@ func TestRemoteAddrHandlerIPv6(t *testing.T) { h := RemoteAddrHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestUserAgentHandler(t *testing.T) { @@ -121,12 +122,12 @@ func TestUserAgentHandler(t *testing.T) { h := UserAgentHandler("ua")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"ua":"some user agent string"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"ua":"some user agent string"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestRefererHandler(t *testing.T) { @@ -139,12 +140,12 @@ func TestRefererHandler(t *testing.T) { h := RefererHandler("referer")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") - if want, got := `{"referer":"http://foo.com/bar"}`+"\n", out.String(); want != got { - t.Errorf("Invalid log output, got: %s, want: %s", got, want) - } })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) + if want, got := `{"referer":"http://foo.com/bar"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } } func TestRequestIDHandler(t *testing.T) { @@ -171,3 +172,66 @@ func TestRequestIDHandler(t *testing.T) { h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(httptest.NewRecorder(), r) } + +func TestCombinedHandlers(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, + } + h := MethodHandler("method")(RequestHandler("request")(URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })))) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"method":"POST","request":"POST /path?foo=bar","url":"/path?foo=bar"}`+"\n", out.String(); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func BenchmarkHandlers(b *testing.B) { + r := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/path", RawQuery: "foo=bar"}, + } + h1 := URLHandler("url")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h2 := MethodHandler("method")(RequestHandler("request")(h1)) + handlers := map[string]http.Handler{ + "Single": NewHandler(zerolog.New(ioutil.Discard))(h1), + "Combined": NewHandler(zerolog.New(ioutil.Discard))(h2), + "SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1), + "CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2), + } + for name := range handlers { + h := handlers[name] + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + h.ServeHTTP(nil, r) + } + }) + } +} + +func BenchmarkDataRace(b *testing.B) { + log := zerolog.New(nil).With(). + Str("foo", "bar"). + Logger() + lh := NewHandler(log) + h := lh(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("bar", "baz") + }) + l.Log().Msg("") + })) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + h.ServeHTTP(nil, &http.Request{}) + } + }) +} diff --git a/log.go b/log.go index 5b23a16..426882b 100644 --- a/log.go +++ b/log.go @@ -174,6 +174,20 @@ func (l Logger) With() Context { return Context{l} } +// UpdateContext updates the internal logger's context. +// +// Use this method with caution. If unsure, prefer the With method. +func (l *Logger) UpdateContext(update func(c Context) Context) { + if l == disabledLogger { + return + } + if cap(l.context) == 0 { + l.context = make([]byte, 1, 500) // first byte is timestamp flag + } + c := update(Context{*l}) + l.context = c.l.context +} + // Level creates a child logger with the minimum accepted level set to level. func (l Logger) Level(lvl Level) Logger { return Logger{ diff --git a/log/log.go b/log/log.go index 27c46d0..1036d85 100644 --- a/log/log.go +++ b/log/log.go @@ -86,6 +86,6 @@ func Log() *zerolog.Event { // Ctx returns the Logger associated with the ctx. If no logger // is associated, a disabled logger is returned. -func Ctx(ctx context.Context) zerolog.Logger { +func Ctx(ctx context.Context) *zerolog.Logger { return zerolog.Ctx(ctx) }