home: improve checkSession
This commit is contained in:
parent
7eb3e00b35
commit
925c5df801
@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
|
|||||||
// Auth - global object
|
// Auth - global object
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
db *bbolt.DB
|
db *bbolt.DB
|
||||||
sessions map[string]*session // session name -> session data
|
sessions map[string]*session
|
||||||
lock sync.Mutex
|
|
||||||
users []User
|
users []User
|
||||||
sessionTTL uint32 // in seconds
|
lock sync.Mutex
|
||||||
|
sessionTTL uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// User object
|
// User object
|
||||||
@ -223,23 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
|
|||||||
log.Debug("Auth: removed session from DB")
|
log.Debug("Auth: removed session from DB")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckSession - check if session is valid
|
// checkSessionResult is the result of checking a session.
|
||||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
type checkSessionResult int
|
||||||
func (a *Auth) CheckSession(sess string) int {
|
|
||||||
|
// checkSessionResult constants.
|
||||||
|
const (
|
||||||
|
checkSessionOK checkSessionResult = 0
|
||||||
|
checkSessionNotFound checkSessionResult = -1
|
||||||
|
checkSessionExpired checkSessionResult = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkSession checks if the session is valid.
|
||||||
|
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||||
now := uint32(time.Now().UTC().Unix())
|
now := uint32(time.Now().UTC().Unix())
|
||||||
update := false
|
update := false
|
||||||
|
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
defer a.lock.Unlock()
|
defer a.lock.Unlock()
|
||||||
|
|
||||||
s, ok := a.sessions[sess]
|
s, ok := a.sessions[sess]
|
||||||
if !ok {
|
if !ok {
|
||||||
return -1
|
return checkSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.expire <= now {
|
if s.expire <= now {
|
||||||
delete(a.sessions, sess)
|
delete(a.sessions, sess)
|
||||||
key, _ := hex.DecodeString(sess)
|
key, _ := hex.DecodeString(sess)
|
||||||
a.removeSession(key)
|
a.removeSession(key)
|
||||||
return 1
|
|
||||||
|
return checkSessionExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
newExpire := now + a.sessionTTL
|
newExpire := now + a.sessionTTL
|
||||||
@ -256,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return checkSessionOK
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSession - remove session
|
// RemoveSession - remove session
|
||||||
@ -389,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
|||||||
ok = true
|
ok = true
|
||||||
|
|
||||||
} else if err == nil {
|
} else if err == nil {
|
||||||
r := Context.auth.CheckSession(cookie.Value)
|
r := Context.auth.checkSession(cookie.Value)
|
||||||
if r == 0 {
|
if r == checkSessionOK {
|
||||||
ok = true
|
ok = true
|
||||||
} else if r < 0 {
|
} else if r < 0 {
|
||||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||||
@ -431,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
|||||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
if authRequired && err == nil {
|
if authRequired && err == nil {
|
||||||
r := Context.auth.CheckSession(cookie.Value)
|
r := Context.auth.checkSession(cookie.Value)
|
||||||
if r == 0 {
|
if r == checkSessionOK {
|
||||||
w.Header().Set("Location", "/")
|
w.Header().Set("Location", "/")
|
||||||
w.WriteHeader(http.StatusFound)
|
w.WriteHeader(http.StatusFound)
|
||||||
|
|
||||||
return
|
return
|
||||||
} else if r < 0 {
|
} else if r == checkSessionNotFound {
|
||||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
|
|||||||
user := User{Name: "name"}
|
user := User{Name: "name"}
|
||||||
a.UserAdd(&user, "password")
|
a.UserAdd(&user, "password")
|
||||||
|
|
||||||
assert.True(t, a.CheckSession("notfound") == -1)
|
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||||
a.RemoveSession("notfound")
|
a.RemoveSession("notfound")
|
||||||
|
|
||||||
sess, err := getSession(&users[0])
|
sess, err := getSession(&users[0])
|
||||||
@ -49,13 +49,13 @@ func TestAuth(t *testing.T) {
|
|||||||
// check expiration
|
// check expiration
|
||||||
s.expire = uint32(now)
|
s.expire = uint32(now)
|
||||||
a.addSession(sess, &s)
|
a.addSession(sess, &s)
|
||||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||||
|
|
||||||
// add session with TTL = 2 sec
|
// add session with TTL = 2 sec
|
||||||
s = session{}
|
s = session{}
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
a.addSession(sess, &s)
|
a.addSession(sess, &s)
|
||||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
|
|
||||||
a.Close()
|
a.Close()
|
||||||
|
|
||||||
@ -63,8 +63,8 @@ func TestAuth(t *testing.T) {
|
|||||||
a = InitAuth(fn, users, 60)
|
a = InitAuth(fn, users, 60)
|
||||||
|
|
||||||
// the session is still alive
|
// the session is still alive
|
||||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
// reset our expiration time because CheckSession() has just updated it
|
// reset our expiration time because checkSession() has just updated it
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
a.storeSession(sess, &s)
|
a.storeSession(sess, &s)
|
||||||
a.Close()
|
a.Close()
|
||||||
@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {
|
|||||||
|
|
||||||
// load and remove expired sessions
|
// load and remove expired sessions
|
||||||
a = InitAuth(fn, users, 60)
|
a = InitAuth(fn, users, 60)
|
||||||
assert.True(t, a.CheckSession(sessStr) == -1)
|
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||||
|
|
||||||
a.Close()
|
a.Close()
|
||||||
os.Remove(fn)
|
os.Remove(fn)
|
||||||
@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
|
|||||||
Context.auth = InitAuth(fn, users, 60)
|
Context.auth = InitAuth(fn, users, 60)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
handlerCalled = true
|
handlerCalled = true
|
||||||
}
|
}
|
||||||
handler2 := optionalAuth(handler)
|
handler2 := optionalAuth(handler)
|
||||||
|
Loading…
Reference in New Issue
Block a user