From 7eb3e00b351e41445a97bbf246134fdef84f0ce8 Mon Sep 17 00:00:00 2001 From: jvoisin Date: Mon, 21 Dec 2020 19:39:39 +0100 Subject: [PATCH 1/3] Use a couple of defer in internal/home/auth.go --- internal/home/auth.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/internal/home/auth.go b/internal/home/auth.go index 00407fa0..941c97ab 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -230,16 +230,15 @@ func (a *Auth) CheckSession(sess string) int { update := false a.lock.Lock() + defer a.lock.Unlock() s, ok := a.sessions[sess] if !ok { - a.lock.Unlock() return -1 } if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) - a.lock.Unlock() return 1 } @@ -250,8 +249,6 @@ func (a *Auth) CheckSession(sess string) int { s.expire = newExpire } - a.lock.Unlock() - if update { key, _ := hex.DecodeString(sess) if a.storeSession(key, s) { @@ -517,18 +514,16 @@ func (a *Auth) GetCurrentUser(r *http.Request) User { } a.lock.Lock() + defer a.lock.Unlock() s, ok := a.sessions[cookie.Value] if !ok { - a.lock.Unlock() return User{} } for _, u := range a.users { if u.Name == s.userName { - a.lock.Unlock() return u } } - a.lock.Unlock() return User{} } From 925c5df801dc940239930034cdee3b683304b860 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 22 Dec 2020 21:05:12 +0300 Subject: [PATCH 2/3] home: improve checkSession --- internal/home/auth.go | 41 +++++++++++++++++++++++++------------- internal/home/auth_test.go | 14 ++++++------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/internal/home/auth.go b/internal/home/auth.go index 941c97ab..01f89a26 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool { // Auth - global object type Auth struct { db *bbolt.DB - sessions map[string]*session // session name -> session data - lock sync.Mutex + sessions map[string]*session users []User - sessionTTL uint32 // in seconds + lock sync.Mutex + sessionTTL uint32 } // User object @@ -223,23 +223,35 @@ func (a *Auth) removeSession(sess []byte) { log.Debug("Auth: removed session from DB") } -// CheckSession - check if session is valid -// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired -func (a *Auth) CheckSession(sess string) int { +// checkSessionResult is the result of checking a session. +type checkSessionResult int + +// checkSessionResult constants. +const ( + checkSessionOK checkSessionResult = 0 + checkSessionNotFound checkSessionResult = -1 + checkSessionExpired checkSessionResult = 1 +) + +// checkSession checks if the session is valid. +func (a *Auth) checkSession(sess string) (res checkSessionResult) { now := uint32(time.Now().UTC().Unix()) update := false a.lock.Lock() defer a.lock.Unlock() + s, ok := a.sessions[sess] if !ok { - return -1 + return checkSessionNotFound } + if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) - return 1 + + return checkSessionExpired } newExpire := now + a.sessionTTL @@ -256,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int { } } - return 0 + return checkSessionOK } // RemoveSession - remove session @@ -389,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool) ok = true } else if err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { ok = true } else if r < 0 { log.Debug("Auth: invalid cookie value: %s", cookie) @@ -431,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re authRequired := Context.auth != nil && Context.auth.AuthRequired() cookie, err := r.Cookie(sessionCookieName) if authRequired && err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { w.Header().Set("Location", "/") w.WriteHeader(http.StatusFound) + return - } else if r < 0 { + } else if r == checkSessionNotFound { log.Debug("Auth: invalid cookie value: %s", cookie) } } diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 25db2dd6..0998a2a6 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -38,7 +38,7 @@ func TestAuth(t *testing.T) { user := User{Name: "name"} a.UserAdd(&user, "password") - assert.True(t, a.CheckSession("notfound") == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession("notfound")) a.RemoveSession("notfound") sess, err := getSession(&users[0]) @@ -49,13 +49,13 @@ func TestAuth(t *testing.T) { // check expiration s.expire = uint32(now) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 1) + assert.Equal(t, checkSessionExpired, a.checkSession(sessStr)) // add session with TTL = 2 sec s = session{} s.expire = uint32(time.Now().UTC().Unix() + 2) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 0) + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) a.Close() @@ -63,8 +63,8 @@ func TestAuth(t *testing.T) { a = InitAuth(fn, users, 60) // the session is still alive - assert.True(t, a.CheckSession(sessStr) == 0) - // reset our expiration time because CheckSession() has just updated it + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) + // reset our expiration time because checkSession() has just updated it s.expire = uint32(time.Now().UTC().Unix() + 2) a.storeSession(sess, &s) a.Close() @@ -76,7 +76,7 @@ func TestAuth(t *testing.T) { // load and remove expired sessions a = InitAuth(fn, users, 60) - assert.True(t, a.CheckSession(sessStr) == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr)) a.Close() os.Remove(fn) @@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) { Context.auth = InitAuth(fn, users, 60) handlerCalled := false - handler := func(w http.ResponseWriter, r *http.Request) { + handler := func(_ http.ResponseWriter, _ *http.Request) { handlerCalled = true } handler2 := optionalAuth(handler) From 1c754788f9139ed9741cf01c6d94bcced6909b8c Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 22 Dec 2020 21:09:53 +0300 Subject: [PATCH 3/3] home: improve getCurrentUser --- internal/home/auth.go | 14 +++++++++----- internal/home/control.go | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/home/auth.go b/internal/home/auth.go index 01f89a26..dce17241 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -513,30 +513,34 @@ func (a *Auth) UserFind(login, password string) User { return User{} } -// GetCurrentUser - get the current user -func (a *Auth) GetCurrentUser(r *http.Request) User { +// getCurrentUser returns the current user. It returns an empty User if the +// user is not found. +func (a *Auth) getCurrentUser(r *http.Request) User { cookie, err := r.Cookie(sessionCookieName) if err != nil { - // there's no Cookie, check Basic authentication + // There's no Cookie, check Basic authentication. user, pass, ok := r.BasicAuth() if ok { - u := Context.auth.UserFind(user, pass) - return u + return Context.auth.UserFind(user, pass) } + return User{} } a.lock.Lock() defer a.lock.Unlock() + s, ok := a.sessions[cookie.Value] if !ok { return User{} } + for _, u := range a.users { if u.Name == s.userName { return u } } + return User{} } diff --git a/internal/home/control.go b/internal/home/control.go index 3443515a..616557a8 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -89,7 +89,7 @@ type profileJSON struct { func handleGetProfile(w http.ResponseWriter, r *http.Request) { pj := profileJSON{} - u := Context.auth.GetCurrentUser(r) + u := Context.auth.getCurrentUser(r) pj.Name = u.Name data, err := json.Marshal(pj)