Pull request: 2471 defer

Merge in DNS/adguard-home from 2471-defer to master

Updates #2471.

* commit '1c754788f9139ed9741cf01c6d94bcced6909b8c':
  home: improve getCurrentUser
  home: improve checkSession
  Use a couple of defer in internal/home/auth.go
This commit is contained in:
Ainar Garipov 2020-12-23 13:16:44 +03:00
commit e829e7a064
3 changed files with 46 additions and 34 deletions

View File

@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
// Auth - global object // Auth - global object
type Auth struct { type Auth struct {
db *bbolt.DB db *bbolt.DB
sessions map[string]*session // session name -> session data sessions map[string]*session
lock sync.Mutex
users []User users []User
sessionTTL uint32 // in seconds lock sync.Mutex
sessionTTL uint32
} }
// User object // User object
@ -223,24 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
log.Debug("Auth: removed session from DB") log.Debug("Auth: removed session from DB")
} }
// CheckSession - check if session is valid // checkSessionResult is the result of checking a session.
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired type checkSessionResult int
func (a *Auth) CheckSession(sess string) int {
// checkSessionResult constants.
const (
checkSessionOK checkSessionResult = 0
checkSessionNotFound checkSessionResult = -1
checkSessionExpired checkSessionResult = 1
)
// checkSession checks if the session is valid.
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
now := uint32(time.Now().UTC().Unix()) now := uint32(time.Now().UTC().Unix())
update := false update := false
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[sess] s, ok := a.sessions[sess]
if !ok { if !ok {
a.lock.Unlock() return checkSessionNotFound
return -1
} }
if s.expire <= now { if s.expire <= now {
delete(a.sessions, sess) delete(a.sessions, sess)
key, _ := hex.DecodeString(sess) key, _ := hex.DecodeString(sess)
a.removeSession(key) a.removeSession(key)
a.lock.Unlock()
return 1 return checkSessionExpired
} }
newExpire := now + a.sessionTTL newExpire := now + a.sessionTTL
@ -250,8 +261,6 @@ func (a *Auth) CheckSession(sess string) int {
s.expire = newExpire s.expire = newExpire
} }
a.lock.Unlock()
if update { if update {
key, _ := hex.DecodeString(sess) key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) { if a.storeSession(key, s) {
@ -259,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
} }
} }
return 0 return checkSessionOK
} }
// RemoveSession - remove session // RemoveSession - remove session
@ -392,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
ok = true ok = true
} else if err == nil { } else if err == nil {
r := Context.auth.CheckSession(cookie.Value) r := Context.auth.checkSession(cookie.Value)
if r == 0 { if r == checkSessionOK {
ok = true ok = true
} else if r < 0 { } else if r < 0 {
log.Debug("Auth: invalid cookie value: %s", cookie) log.Debug("Auth: invalid cookie value: %s", cookie)
@ -434,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
authRequired := Context.auth != nil && Context.auth.AuthRequired() authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil { if authRequired && err == nil {
r := Context.auth.CheckSession(cookie.Value) r := Context.auth.checkSession(cookie.Value)
if r == 0 { if r == checkSessionOK {
w.Header().Set("Location", "/") w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
return return
} else if r < 0 { } else if r == checkSessionNotFound {
log.Debug("Auth: invalid cookie value: %s", cookie) log.Debug("Auth: invalid cookie value: %s", cookie)
} }
} }
@ -503,32 +513,34 @@ func (a *Auth) UserFind(login, password string) User {
return User{} return User{}
} }
// GetCurrentUser - get the current user // getCurrentUser returns the current user. It returns an empty User if the
func (a *Auth) GetCurrentUser(r *http.Request) User { // user is not found.
func (a *Auth) getCurrentUser(r *http.Request) User {
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if err != nil { if err != nil {
// there's no Cookie, check Basic authentication // There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth() user, pass, ok := r.BasicAuth()
if ok { if ok {
u := Context.auth.UserFind(user, pass) return Context.auth.UserFind(user, pass)
return u
} }
return User{} return User{}
} }
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[cookie.Value] s, ok := a.sessions[cookie.Value]
if !ok { if !ok {
a.lock.Unlock()
return User{} return User{}
} }
for _, u := range a.users { for _, u := range a.users {
if u.Name == s.userName { if u.Name == s.userName {
a.lock.Unlock()
return u return u
} }
} }
a.lock.Unlock()
return User{} return User{}
} }

View File

@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
user := User{Name: "name"} user := User{Name: "name"}
a.UserAdd(&user, "password") a.UserAdd(&user, "password")
assert.True(t, a.CheckSession("notfound") == -1) assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.RemoveSession("notfound") a.RemoveSession("notfound")
sess, err := getSession(&users[0]) sess, err := getSession(&users[0])
@ -49,13 +49,13 @@ func TestAuth(t *testing.T) {
// check expiration // check expiration
s.expire = uint32(now) s.expire = uint32(now)
a.addSession(sess, &s) a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 1) assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
// add session with TTL = 2 sec // add session with TTL = 2 sec
s = session{} s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2) s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s) a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 0) assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
a.Close() a.Close()
@ -63,8 +63,8 @@ func TestAuth(t *testing.T) {
a = InitAuth(fn, users, 60) a = InitAuth(fn, users, 60)
// the session is still alive // the session is still alive
assert.True(t, a.CheckSession(sessStr) == 0) assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because CheckSession() has just updated it // reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2) s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s) a.storeSession(sess, &s)
a.Close() a.Close()
@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {
// load and remove expired sessions // load and remove expired sessions
a = InitAuth(fn, users, 60) a = InitAuth(fn, users, 60)
assert.True(t, a.CheckSession(sessStr) == -1) assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
a.Close() a.Close()
os.Remove(fn) os.Remove(fn)
@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
Context.auth = InitAuth(fn, users, 60) Context.auth = InitAuth(fn, users, 60)
handlerCalled := false handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(_ http.ResponseWriter, _ *http.Request) {
handlerCalled = true handlerCalled = true
} }
handler2 := optionalAuth(handler) handler2 := optionalAuth(handler)

View File

@ -89,7 +89,7 @@ type profileJSON struct {
func handleGetProfile(w http.ResponseWriter, r *http.Request) { func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{} pj := profileJSON{}
u := Context.auth.GetCurrentUser(r) u := Context.auth.getCurrentUser(r)
pj.Name = u.Name pj.Name = u.Name
data, err := json.Marshal(pj) data, err := json.Marshal(pj)