diff --git a/AGHTechDoc.md b/AGHTechDoc.md
index adf483d0..d901a862 100644
--- a/AGHTechDoc.md
+++ b/AGHTechDoc.md
@@ -54,6 +54,7 @@ Contents:
* Log-in page
* API: Log in
* API: Log out
+ * API: Get current user info
## Relations between subsystems
@@ -1207,7 +1208,7 @@ YAML configuration:
Session DB file:
- session="..." expire=123456
+ session="..." user=name expire=123456
...
Session data is SHA(random()+name+password).
@@ -1270,3 +1271,20 @@ Response:
302 Found
Location: /login.html
Set-Cookie: session=...; Expires=Thu, 01 Jan 1970 00:00:00 GMT
+
+
+### API: Get current user info
+
+Request:
+
+ GET /control/profile
+
+Response:
+
+ 200 OK
+
+ {
+ "name":"..."
+ }
+
+If no client is configured then authentication is disabled and server sends an empty response.
diff --git a/client/src/actions/index.js b/client/src/actions/index.js
index 3087c47d..28c2a713 100644
--- a/client/src/actions/index.js
+++ b/client/src/actions/index.js
@@ -213,6 +213,21 @@ export const getClients = () => async (dispatch) => {
}
};
+export const getProfileRequest = createAction('GET_PROFILE_REQUEST');
+export const getProfileFailure = createAction('GET_PROFILE_FAILURE');
+export const getProfileSuccess = createAction('GET_PROFILE_SUCCESS');
+
+export const getProfile = () => async (dispatch) => {
+ dispatch(getProfileRequest());
+ try {
+ const profile = await apiClient.getProfile();
+ dispatch(getProfileSuccess(profile));
+ } catch (error) {
+ dispatch(addErrorToast({ error }));
+ dispatch(getProfileFailure());
+ }
+};
+
export const dnsStatusRequest = createAction('DNS_STATUS_REQUEST');
export const dnsStatusFailure = createAction('DNS_STATUS_FAILURE');
export const dnsStatusSuccess = createAction('DNS_STATUS_SUCCESS');
@@ -224,6 +239,7 @@ export const getDnsStatus = () => async (dispatch) => {
dispatch(dnsStatusSuccess(dnsStatus));
dispatch(getVersion());
dispatch(getTlsStatus());
+ dispatch(getProfile());
} catch (error) {
dispatch(addErrorToast({ error }));
dispatch(dnsStatusFailure());
diff --git a/client/src/api/Api.js b/client/src/api/Api.js
index c5ced2b8..470577a8 100644
--- a/client/src/api/Api.js
+++ b/client/src/api/Api.js
@@ -525,6 +525,14 @@ class Api {
};
return this.makeRequest(path, method, config);
}
+
+ // Profile
+ GET_PROFILE = { path: 'profile', method: 'GET' };
+
+ getProfile() {
+ const { path, method } = this.GET_PROFILE;
+ return this.makeRequest(path, method);
+ }
}
const apiClient = new Api();
diff --git a/client/src/components/Header/index.js b/client/src/components/Header/index.js
index 28fa0767..8d16e614 100644
--- a/client/src/components/Header/index.js
+++ b/client/src/components/Header/index.js
@@ -60,9 +60,11 @@ class Header extends Component {
/>
diff --git a/client/src/reducers/index.js b/client/src/reducers/index.js
index 589da42e..0e8ff407 100644
--- a/client/src/reducers/index.js
+++ b/client/src/reducers/index.js
@@ -189,6 +189,14 @@ const dashboard = handleActions(
processingDnsSettings: false,
};
},
+
+ [actions.getProfileRequest]: state => ({ ...state, processingProfile: true }),
+ [actions.getProfileFailure]: state => ({ ...state, processingProfile: false }),
+ [actions.getProfileSuccess]: (state, { payload }) => ({
+ ...state,
+ name: payload.name,
+ processingProfile: false,
+ }),
},
{
processing: true,
@@ -198,6 +206,7 @@ const dashboard = handleActions(
processingClients: true,
processingUpdate: false,
processingDnsSettings: true,
+ processingProfile: true,
upstreamDns: '',
bootstrapDns: '',
allServers: false,
@@ -209,6 +218,7 @@ const dashboard = handleActions(
dnsVersion: '',
clients: [],
autoClients: [],
+ name: '',
},
);
diff --git a/home/auth.go b/home/auth.go
index 52b62e70..98f2ccae 100644
--- a/home/auth.go
+++ b/home/auth.go
@@ -20,10 +20,44 @@ import (
const cookieTTL = 365 * 24 // in hours
const expireTime = 30 * 24 // in hours
+type session struct {
+ userName string
+ expire uint32 // expiration time (in seconds)
+}
+
+/*
+expire byte[4]
+name_len byte[2]
+name byte[]
+*/
+func (s *session) serialize() []byte {
+ var data []byte
+ data = make([]byte, 4+2+len(s.userName))
+ binary.BigEndian.PutUint32(data[0:4], s.expire)
+ binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
+ copy(data[6:], []byte(s.userName))
+ return data
+}
+
+func (s *session) deserialize(data []byte) bool {
+ if len(data) < 4+2 {
+ return false
+ }
+ s.expire = binary.BigEndian.Uint32(data[0:4])
+ nameLen := binary.BigEndian.Uint16(data[4:6])
+ data = data[6:]
+
+ if len(data) < int(nameLen) {
+ return false
+ }
+ s.userName = string(data)
+ return true
+}
+
// Auth - global object
type Auth struct {
db *bbolt.DB
- sessions map[string]uint32 // session -> expiration time (in seconds)
+ sessions map[string]*session // session name -> session data
lock sync.Mutex
users []User
}
@@ -37,7 +71,7 @@ type User struct {
// InitAuth - create a global object
func InitAuth(dbFilename string, users []User) *Auth {
a := Auth{}
- a.sessions = make(map[string]uint32)
+ a.sessions = make(map[string]*session)
rand.Seed(time.Now().UTC().Unix())
var err error
a.db, err = bbolt.Open(dbFilename, 0644, nil)
@@ -56,6 +90,10 @@ func (a *Auth) Close() {
_ = a.db.Close()
}
+func bucketName() []byte {
+ return []byte("sessions-2")
+}
+
// load sessions from file, remove expired sessions
func (a *Auth) loadSessions() {
tx, err := a.db.Begin(true)
@@ -67,16 +105,22 @@ func (a *Auth) loadSessions() {
_ = tx.Rollback()
}()
- bkt := tx.Bucket([]byte("sessions"))
+ bkt := tx.Bucket(bucketName())
if bkt == nil {
return
}
removed := 0
+
+ if tx.Bucket([]byte("sessions")) != nil {
+ _ = tx.DeleteBucket([]byte("sessions"))
+ removed = 1
+ }
+
now := uint32(time.Now().UTC().Unix())
forEach := func(k, v []byte) error {
- i := binary.BigEndian.Uint32(v)
- if i <= now {
+ s := session{}
+ if !s.deserialize(v) || s.expire <= now {
err = bkt.Delete(k)
if err != nil {
log.Error("Auth: bbolt.Delete: %s", err)
@@ -85,7 +129,8 @@ func (a *Auth) loadSessions() {
}
return nil
}
- a.sessions[hex.EncodeToString(k)] = i
+
+ a.sessions[hex.EncodeToString(k)] = &s
return nil
}
_ = bkt.ForEach(forEach)
@@ -99,11 +144,15 @@ func (a *Auth) loadSessions() {
}
// store session data in file
-func (a *Auth) storeSession(data []byte, expire uint32) {
+func (a *Auth) addSession(data []byte, s *session) {
a.lock.Lock()
- a.sessions[hex.EncodeToString(data)] = expire
+ a.sessions[hex.EncodeToString(data)] = s
a.lock.Unlock()
+ a.storeSession(data, s)
+}
+// store session data in file
+func (a *Auth) storeSession(data []byte, s *session) {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("Auth: bbolt.Begin: %s", err)
@@ -113,15 +162,12 @@ func (a *Auth) storeSession(data []byte, expire uint32) {
_ = tx.Rollback()
}()
- bkt, err := tx.CreateBucketIfNotExists([]byte("sessions"))
+ bkt, err := tx.CreateBucketIfNotExists(bucketName())
if err != nil {
log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err)
return
}
- var val []byte
- val = make([]byte, 4)
- binary.BigEndian.PutUint32(val, expire)
- err = bkt.Put(data, val)
+ err = bkt.Put(data, s.serialize())
if err != nil {
log.Error("Auth: bbolt.Put: %s", err)
return
@@ -147,7 +193,7 @@ func (a *Auth) removeSession(sess []byte) {
_ = tx.Rollback()
}()
- bkt := tx.Bucket([]byte("sessions"))
+ bkt := tx.Bucket(bucketName())
if bkt == nil {
log.Error("Auth: bbolt.Bucket")
return
@@ -174,12 +220,12 @@ func (a *Auth) CheckSession(sess string) int {
update := false
a.lock.Lock()
- expire, ok := a.sessions[sess]
+ s, ok := a.sessions[sess]
if !ok {
a.lock.Unlock()
return -1
}
- if expire <= now {
+ if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSession(key)
@@ -188,17 +234,17 @@ func (a *Auth) CheckSession(sess string) int {
}
newExpire := now + expireTime*60*60
- if expire/(24*60*60) != newExpire/(24*60*60) {
+ if s.expire/(24*60*60) != newExpire/(24*60*60) {
// update expiration time once a day
update = true
- a.sessions[sess] = newExpire
+ s.expire = newExpire
}
a.lock.Unlock()
if update {
key, _ := hex.DecodeString(sess)
- a.storeSession(key, expire)
+ a.storeSession(key, s)
}
return 0
@@ -238,8 +284,10 @@ func httpCookie(req loginJSON) string {
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
expstr += "GMT"
- expireSess := uint32(now.Unix()) + expireTime*60*60
- config.auth.storeSession(sess, expireSess)
+ s := session{}
+ s.userName = u.Name
+ s.expire = uint32(now.Unix()) + expireTime*60*60
+ config.auth.addSession(sess, &s)
return fmt.Sprintf("session=%s; Path=/; HttpOnly; Expires=%s", hex.EncodeToString(sess), expstr)
}
@@ -402,6 +450,34 @@ func (a *Auth) UserFind(login string, password string) User {
return User{}
}
+// GetCurrentUser - get the current user
+func (a *Auth) GetCurrentUser(r *http.Request) User {
+ cookie, err := r.Cookie("session")
+ if err != nil {
+ // there's no Cookie, check Basic authentication
+ user, pass, ok := r.BasicAuth()
+ if ok {
+ u := config.auth.UserFind(user, pass)
+ return u
+ }
+ }
+
+ a.lock.Lock()
+ s, ok := a.sessions[cookie.Value]
+ if !ok {
+ a.lock.Unlock()
+ return User{}
+ }
+ for _, u := range a.users {
+ if u.Name == s.userName {
+ a.lock.Unlock()
+ return u
+ }
+ }
+ a.lock.Unlock()
+ return User{}
+}
+
// GetUsers - get users
func (a *Auth) GetUsers() []User {
a.lock.Lock()
diff --git a/home/auth_test.go b/home/auth_test.go
index 2ae532fd..ed2c3e6a 100644
--- a/home/auth_test.go
+++ b/home/auth_test.go
@@ -28,6 +28,7 @@ func TestAuth(t *testing.T) {
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
a := InitAuth(fn, nil)
+ s := session{}
user := User{Name: "name"}
a.UserAdd(&user, "password")
@@ -38,12 +39,16 @@ func TestAuth(t *testing.T) {
sess := getSession(&users[0])
sessStr := hex.EncodeToString(sess)
+ now := time.Now().UTC().Unix()
// check expiration
- a.storeSession(sess, uint32(time.Now().UTC().Unix()))
+ s.expire = uint32(now)
+ a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 1)
// add session with TTL = 2 sec
- a.storeSession(sess, uint32(time.Now().UTC().Unix()+2))
+ s = session{}
+ s.expire = uint32(now + 2)
+ a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 0)
a.Close()
@@ -53,6 +58,9 @@ func TestAuth(t *testing.T) {
// the session is still alive
assert.True(t, a.CheckSession(sessStr) == 0)
+ // reset our expiration time because CheckSession() has just updated it
+ s.expire = uint32(now + 2)
+ a.storeSession(sess, &s)
a.Close()
u := a.UserFind("name", "password")
diff --git a/home/control.go b/home/control.go
index 1f2eb1fa..143f73fc 100644
--- a/home/control.go
+++ b/home/control.go
@@ -377,6 +377,23 @@ func checkDNS(input string, bootstrap []string) error {
return nil
}
+type profileJSON struct {
+ Name string `json:"name"`
+}
+
+func handleGetProfile(w http.ResponseWriter, r *http.Request) {
+ pj := profileJSON{}
+ u := config.auth.GetCurrentUser(r)
+ pj.Name = u.Name
+
+ data, err := json.Marshal(pj)
+ if err != nil {
+ httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
+ return
+ }
+ _, _ = w.Write(data)
+}
+
// --------------
// DNS-over-HTTPS
// --------------
@@ -416,6 +433,7 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/access/list", handleAccessList)
httpRegister(http.MethodPost, "/control/access/set", handleAccessSet)
+ httpRegister("GET", "/control/profile", handleGetProfile)
RegisterFilteringHandlers()
RegisterTLSHandlers()
diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md
index 7dc74883..281d58e5 100644
--- a/openapi/CHANGELOG.md
+++ b/openapi/CHANGELOG.md
@@ -1,6 +1,23 @@
# AdGuard Home API Change Log
+## v0.99.1: API changes
+
+### API: Get current user info: GET /control/profile
+
+Request:
+
+ GET /control/profile
+
+Response:
+
+ 200 OK
+
+ {
+ "name":"..."
+ }
+
+
## v0.99: incompatible API changes
* A note about web user authentication
diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml
index 3f1474cb..eca972b3 100644
--- a/openapi/openapi.yaml
+++ b/openapi/openapi.yaml
@@ -970,6 +970,18 @@ paths:
302:
description: OK
+ /profile:
+ get:
+ tags:
+ - global
+ operationId: getProfile
+ summary: ""
+ responses:
+ 200:
+ description: OK
+ schema:
+ $ref: "#/definitions/ProfileInfo"
+
definitions:
ServerStatus:
type: "object"
@@ -1559,6 +1571,14 @@ definitions:
description: "Network interfaces dictionary (key is the interface name)"
additionalProperties:
$ref: "#/definitions/NetInterface"
+
+ ProfileInfo:
+ type: "object"
+ description: "Information about the current user"
+ properties:
+ name:
+ type: "string"
+
Client:
type: "object"
description: "Client information"
diff --git a/stats/stats_unit.go b/stats/stats_unit.go
index 1d524bcc..3db14d5b 100644
--- a/stats/stats_unit.go
+++ b/stats/stats_unit.go
@@ -346,7 +346,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
return nil
}
- log.Tracef("Loading unit %d", id)
+ // log.Tracef("Loading unit %d", id)
var buf bytes.Buffer
buf.Write(bkt.Get([]byte{0}))