Pull request: 2471 defer
Merge in DNS/adguard-home from 2471-defer to master Updates #2471. * commit '1c754788f9139ed9741cf01c6d94bcced6909b8c': home: improve getCurrentUser home: improve checkSession Use a couple of defer in internal/home/auth.go
This commit is contained in:
commit
e829e7a064
@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
|
|||||||
// Auth - global object
|
// Auth - global object
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
db *bbolt.DB
|
db *bbolt.DB
|
||||||
sessions map[string]*session // session name -> session data
|
sessions map[string]*session
|
||||||
lock sync.Mutex
|
|
||||||
users []User
|
users []User
|
||||||
sessionTTL uint32 // in seconds
|
lock sync.Mutex
|
||||||
|
sessionTTL uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// User object
|
// User object
|
||||||
@ -223,24 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
|
|||||||
log.Debug("Auth: removed session from DB")
|
log.Debug("Auth: removed session from DB")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckSession - check if session is valid
|
// checkSessionResult is the result of checking a session.
|
||||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
type checkSessionResult int
|
||||||
func (a *Auth) CheckSession(sess string) int {
|
|
||||||
|
// checkSessionResult constants.
|
||||||
|
const (
|
||||||
|
checkSessionOK checkSessionResult = 0
|
||||||
|
checkSessionNotFound checkSessionResult = -1
|
||||||
|
checkSessionExpired checkSessionResult = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkSession checks if the session is valid.
|
||||||
|
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||||
now := uint32(time.Now().UTC().Unix())
|
now := uint32(time.Now().UTC().Unix())
|
||||||
update := false
|
update := false
|
||||||
|
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
s, ok := a.sessions[sess]
|
s, ok := a.sessions[sess]
|
||||||
if !ok {
|
if !ok {
|
||||||
a.lock.Unlock()
|
return checkSessionNotFound
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.expire <= now {
|
if s.expire <= now {
|
||||||
delete(a.sessions, sess)
|
delete(a.sessions, sess)
|
||||||
key, _ := hex.DecodeString(sess)
|
key, _ := hex.DecodeString(sess)
|
||||||
a.removeSession(key)
|
a.removeSession(key)
|
||||||
a.lock.Unlock()
|
|
||||||
return 1
|
return checkSessionExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
newExpire := now + a.sessionTTL
|
newExpire := now + a.sessionTTL
|
||||||
@ -250,8 +261,6 @@ func (a *Auth) CheckSession(sess string) int {
|
|||||||
s.expire = newExpire
|
s.expire = newExpire
|
||||||
}
|
}
|
||||||
|
|
||||||
a.lock.Unlock()
|
|
||||||
|
|
||||||
if update {
|
if update {
|
||||||
key, _ := hex.DecodeString(sess)
|
key, _ := hex.DecodeString(sess)
|
||||||
if a.storeSession(key, s) {
|
if a.storeSession(key, s) {
|
||||||
@ -259,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return checkSessionOK
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSession - remove session
|
// RemoveSession - remove session
|
||||||
@ -392,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
|||||||
ok = true
|
ok = true
|
||||||
|
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
r := Context.auth.CheckSession(cookie.Value)
|
r := Context.auth.checkSession(cookie.Value)
|
||||||
if r == 0 {
|
if r == checkSessionOK {
|
||||||
ok = true
|
ok = true
|
||||||
} else if r < 0 {
|
} else if r < 0 {
|
||||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||||
@ -434,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
|||||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
if authRequired && err == nil {
|
if authRequired && err == nil {
|
||||||
r := Context.auth.CheckSession(cookie.Value)
|
r := Context.auth.checkSession(cookie.Value)
|
||||||
if r == 0 {
|
if r == checkSessionOK {
|
||||||
w.Header().Set("Location", "/")
|
w.Header().Set("Location", "/")
|
||||||
w.WriteHeader(http.StatusFound)
|
w.WriteHeader(http.StatusFound)
|
||||||
|
|
||||||
return
|
return
|
||||||
} else if r < 0 {
|
} else if r == checkSessionNotFound {
|
||||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -503,32 +513,34 @@ func (a *Auth) UserFind(login, password string) User {
|
|||||||
return User{}
|
return User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentUser - get the current user
|
// getCurrentUser returns the current user. It returns an empty User if the
|
||||||
func (a *Auth) GetCurrentUser(r *http.Request) User {
|
// user is not found.
|
||||||
|
func (a *Auth) getCurrentUser(r *http.Request) User {
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// there's no Cookie, check Basic authentication
|
// There's no Cookie, check Basic authentication.
|
||||||
user, pass, ok := r.BasicAuth()
|
user, pass, ok := r.BasicAuth()
|
||||||
if ok {
|
if ok {
|
||||||
u := Context.auth.UserFind(user, pass)
|
return Context.auth.UserFind(user, pass)
|
||||||
return u
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return User{}
|
return User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
s, ok := a.sessions[cookie.Value]
|
s, ok := a.sessions[cookie.Value]
|
||||||
if !ok {
|
if !ok {
|
||||||
a.lock.Unlock()
|
|
||||||
return User{}
|
return User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, u := range a.users {
|
for _, u := range a.users {
|
||||||
if u.Name == s.userName {
|
if u.Name == s.userName {
|
||||||
a.lock.Unlock()
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
a.lock.Unlock()
|
|
||||||
return User{}
|
return User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
|
|||||||
user := User{Name: "name"}
|
user := User{Name: "name"}
|
||||||
a.UserAdd(&user, "password")
|
a.UserAdd(&user, "password")
|
||||||
|
|
||||||
assert.True(t, a.CheckSession("notfound") == -1)
|
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||||
a.RemoveSession("notfound")
|
a.RemoveSession("notfound")
|
||||||
|
|
||||||
sess, err := getSession(&users[0])
|
sess, err := getSession(&users[0])
|
||||||
@ -49,13 +49,13 @@ func TestAuth(t *testing.T) {
|
|||||||
// check expiration
|
// check expiration
|
||||||
s.expire = uint32(now)
|
s.expire = uint32(now)
|
||||||
a.addSession(sess, &s)
|
a.addSession(sess, &s)
|
||||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||||
|
|
||||||
// add session with TTL = 2 sec
|
// add session with TTL = 2 sec
|
||||||
s = session{}
|
s = session{}
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
a.addSession(sess, &s)
|
a.addSession(sess, &s)
|
||||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
|
|
||||||
a.Close()
|
a.Close()
|
||||||
|
|
||||||
@ -63,8 +63,8 @@ func TestAuth(t *testing.T) {
|
|||||||
a = InitAuth(fn, users, 60)
|
a = InitAuth(fn, users, 60)
|
||||||
|
|
||||||
// the session is still alive
|
// the session is still alive
|
||||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
// reset our expiration time because CheckSession() has just updated it
|
// reset our expiration time because checkSession() has just updated it
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
a.storeSession(sess, &s)
|
a.storeSession(sess, &s)
|
||||||
a.Close()
|
a.Close()
|
||||||
@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {
|
|||||||
|
|
||||||
// load and remove expired sessions
|
// load and remove expired sessions
|
||||||
a = InitAuth(fn, users, 60)
|
a = InitAuth(fn, users, 60)
|
||||||
assert.True(t, a.CheckSession(sessStr) == -1)
|
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||||
|
|
||||||
a.Close()
|
a.Close()
|
||||||
os.Remove(fn)
|
os.Remove(fn)
|
||||||
@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||||||
Context.auth = InitAuth(fn, users, 60)
|
Context.auth = InitAuth(fn, users, 60)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
handlerCalled = true
|
handlerCalled = true
|
||||||
}
|
}
|
||||||
handler2 := optionalAuth(handler)
|
handler2 := optionalAuth(handler)
|
||||||
|
@ -89,7 +89,7 @@ type profileJSON struct {
|
|||||||
|
|
||||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
pj := profileJSON{}
|
pj := profileJSON{}
|
||||||
u := Context.auth.GetCurrentUser(r)
|
u := Context.auth.getCurrentUser(r)
|
||||||
pj.Name = u.Name
|
pj.Name = u.Name
|
||||||
|
|
||||||
data, err := json.Marshal(pj)
|
data, err := json.Marshal(pj)
|
||||||
|
Loading…
Reference in New Issue
Block a user