gateway: use dedicated auth context keys type

Use auth context keys dedicated type instead of strings and avoid code
duplication by moving shared code under a common package.
This commit is contained in:
Simone Gotti 2022-02-24 11:18:29 +01:00
parent a1997490e7
commit b7c797bde1
12 changed files with 132 additions and 127 deletions

View File

@ -17,44 +17,20 @@ package action
import (
"context"
"agola.io/agola/internal/services/common"
scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
cstypes "agola.io/agola/services/configstore/types"
errors "golang.org/x/xerrors"
)
func (h *ActionHandler) CurrentUserID(ctx context.Context) string {
userIDVal := ctx.Value("userid")
if userIDVal == nil {
return ""
}
return userIDVal.(string)
}
func (h *ActionHandler) IsUserLogged(ctx context.Context) bool {
return ctx.Value("userid") != nil
}
func (h *ActionHandler) IsUserAdmin(ctx context.Context) bool {
isAdmin := false
isAdminVal := ctx.Value("admin")
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
return isAdmin
}
func (h *ActionHandler) IsUserLoggedOrAdmin(ctx context.Context) bool {
return h.IsUserLogged(ctx) || h.IsUserAdmin(ctx)
}
func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx)
isAdmin := common.IsUserAdmin(ctx)
if isAdmin {
return true, nil
}
userID := h.CurrentUserID(ctx)
userID := common.CurrentUserID(ctx)
if userID == "" {
return false, nil
}
@ -77,12 +53,12 @@ func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, err
}
func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx)
isAdmin := common.IsUserAdmin(ctx)
if isAdmin {
return true, nil
}
userID := h.CurrentUserID(ctx)
userID := common.CurrentUserID(ctx)
if userID == "" {
return false, nil
}
@ -113,12 +89,12 @@ func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.Co
}
func (h *ActionHandler) IsProjectMember(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx)
isAdmin := common.IsUserAdmin(ctx)
if isAdmin {
return true, nil
}
userID := h.CurrentUserID(ctx)
userID := common.CurrentUserID(ctx)
if userID == "" {
return false, nil
}
@ -170,7 +146,7 @@ func (h *ActionHandler) IsVariableOwner(ctx context.Context, parentType cstypes.
}
func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, error) {
groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup)
groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup)
if err != nil {
return false, err
}
@ -179,7 +155,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
var ownerType cstypes.ConfigType
var ownerID string
switch groupType {
case common.GroupTypeProject:
case scommon.GroupTypeProject:
p, resp, err := h.configstoreClient.GetProject(ctx, groupID)
if err != nil {
return false, ErrFromRemote(resp, err)
@ -187,7 +163,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
ownerType = p.OwnerType
ownerID = p.OwnerID
visibility = p.GlobalVisibility
case common.GroupTypeUser:
case scommon.GroupTypeUser:
// user direct runs
ownerType = cstypes.ConfigTypeUser
ownerID = groupID
@ -208,7 +184,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
}
func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (bool, error) {
groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup)
groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup)
if err != nil {
return false, err
}
@ -216,14 +192,14 @@ func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (b
var ownerType cstypes.ConfigType
var ownerID string
switch groupType {
case common.GroupTypeProject:
case scommon.GroupTypeProject:
p, resp, err := h.configstoreClient.GetProject(ctx, groupID)
if err != nil {
return false, ErrFromRemote(resp, err)
}
ownerType = p.OwnerType
ownerID = p.OwnerID
case common.GroupTypeUser:
case scommon.GroupTypeUser:
// user direct runs
ownerType = cstypes.ConfigTypeUser
ownerID = groupID

View File

@ -17,6 +17,7 @@ package action
import (
"context"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types"
@ -87,7 +88,7 @@ type CreateOrgRequest struct {
}
func (h *ActionHandler) CreateOrg(ctx context.Context, req *CreateOrgRequest) (*cstypes.Organization, error) {
if !h.IsUserLoggedOrAdmin(ctx) {
if !common.IsUserLoggedOrAdmin(ctx) {
return nil, errors.Errorf("user not logged in")
}

View File

@ -22,6 +22,7 @@ import (
"path"
gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/services/types"
"agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types"
@ -61,7 +62,7 @@ type CreateProjectRequest struct {
}
func (h *ActionHandler) CreateProject(ctx context.Context, req *CreateProjectRequest) (*csapitypes.Project, error) {
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil {
@ -227,7 +228,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, projectRef string, re
}
func (h *ActionHandler) ProjectUpdateRepoLinkedAccount(ctx context.Context, projectRef string) (*csapitypes.Project, error) {
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil {
@ -428,7 +429,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er
}
func (h *ActionHandler) ProjectCreateRun(ctx context.Context, projectRef, branch, tag, refName, commitSHA string) error {
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil {

View File

@ -17,6 +17,7 @@ package action
import (
"context"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types"
@ -60,7 +61,7 @@ type CreateRemoteSourceRequest struct {
}
func (h *ActionHandler) CreateRemoteSource(ctx context.Context, req *CreateRemoteSourceRequest) (*cstypes.RemoteSource, error) {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin")
}
@ -134,7 +135,7 @@ type UpdateRemoteSourceRequest struct {
}
func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemoteSourceRequest) (*cstypes.RemoteSource, error) {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin")
}
@ -182,7 +183,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot
}
func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, rsRef string) error {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return errors.Errorf("user not admin")
}

View File

@ -24,7 +24,8 @@ import (
"agola.io/agola/internal/config"
gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/runconfig"
"agola.io/agola/internal/services/common"
scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
itypes "agola.io/agola/internal/services/types"
"agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types"
@ -268,7 +269,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
if !canDoRunAction {
return util.NewErrForbidden(errors.Errorf("user not authorized"))
}
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
if curUserID == "" {
return util.NewErrBadRequest(errors.Errorf("no logged in user"))
}
@ -285,7 +286,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
if rt.Annotations != nil {
annotations = rt.Annotations
}
approversAnnotation, ok := annotations[common.ApproversAnnotation]
approversAnnotation, ok := annotations[scommon.ApproversAnnotation]
if ok {
if err := json.Unmarshal([]byte(approversAnnotation), &approvers); err != nil {
return errors.Errorf("failed to unmarshal run task approvers annotation: %w", err)
@ -304,7 +305,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
return errors.Errorf("failed to marshal run task approvers annotation: %w", err)
}
annotations[common.ApproversAnnotation] = string(approversj)
annotations[scommon.ApproversAnnotation] = string(approversj)
rsreq := &rsapitypes.RunTaskActionsRequest{
ActionType: rsapitypes.RunTaskActionTypeSetAnnotations,
@ -372,32 +373,32 @@ func (h *ActionHandler) CreateRuns(ctx context.Context, req *CreateRunRequest) e
return util.NewErrBadRequest(errors.Errorf("empty message"))
}
var baseGroupType common.GroupType
var baseGroupType scommon.GroupType
var baseGroupID string
var groupType common.GroupType
var groupType scommon.GroupType
var group string
if req.RunType == itypes.RunTypeProject {
baseGroupType = common.GroupTypeProject
baseGroupType = scommon.GroupTypeProject
baseGroupID = req.Project.ID
} else {
baseGroupType = common.GroupTypeUser
baseGroupType = scommon.GroupTypeUser
baseGroupID = req.User.ID
}
switch req.RefType {
case itypes.RunRefTypeBranch:
groupType = common.GroupTypeBranch
groupType = scommon.GroupTypeBranch
group = req.Branch
case itypes.RunRefTypeTag:
groupType = common.GroupTypeTag
groupType = scommon.GroupTypeTag
group = req.Tag
case itypes.RunRefTypePullRequest:
groupType = common.GroupTypePullRequest
groupType = scommon.GroupTypePullRequest
group = req.PullRequestID
}
runGroup := common.GenRunGroup(baseGroupType, baseGroupID, groupType, group)
runGroup := scommon.GenRunGroup(baseGroupType, baseGroupID, groupType, group)
gitURL, err := util.ParseGitURL(req.CloneURL)
if err != nil {
@ -595,7 +596,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque
}
// remove overriden variables
pvars = common.FilterOverriddenVariables(pvars)
pvars = scommon.FilterOverriddenVariables(pvars)
// get project secrets
secrets, _, err := h.configstoreClient.GetProjectSecrets(ctx, req.Project.ID, true)
@ -611,7 +612,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque
continue
}
// get the secret value referenced by the variable, it must be a secret at the same level or a lower level
secret := common.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets)
secret := scommon.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets)
if secret != nil {
varValue, ok := secret.Data[varval.SecretVar]
if ok {

View File

@ -24,7 +24,8 @@ import (
gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/gitsources/agolagit"
"agola.io/agola/internal/services/common"
scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/services/types"
"agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types"
@ -46,7 +47,7 @@ func isAccessTokenExpired(expiresAt time.Time) bool {
}
func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.User, error) {
if !h.IsUserLoggedOrAdmin(ctx) {
if !common.IsUserLoggedOrAdmin(ctx) {
return nil, errors.Errorf("user not logged in")
}
@ -58,7 +59,7 @@ func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.U
}
func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*csapitypes.UserOrgsResponse, error) {
if !h.IsUserLogged(ctx) {
if !common.IsUserLogged(ctx) {
return nil, errors.Errorf("user not logged in")
}
@ -76,7 +77,7 @@ type GetUsersRequest struct {
}
func (h *ActionHandler) GetUsers(ctx context.Context, req *GetUsersRequest) ([]*cstypes.User, error) {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not logged in")
}
@ -92,7 +93,7 @@ type CreateUserRequest struct {
}
func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) (*cstypes.User, error) {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin")
}
@ -123,17 +124,8 @@ type CreateUserTokenRequest struct {
}
func (h *ActionHandler) CreateUserToken(ctx context.Context, req *CreateUserTokenRequest) (string, error) {
var userID string
userIDVal := ctx.Value("userid")
if userIDVal != nil {
userID = userIDVal.(string)
}
isAdmin := false
isAdminVal := ctx.Value("admin")
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
isAdmin := common.IsUserAdmin(ctx)
userID := common.CurrentUserID(ctx)
userRef := req.UserRef
user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
@ -193,11 +185,11 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userRef, rs.Name))
}
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil {
return nil, err
}
userSource, err := common.GetUserSource(rs, accessToken)
userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil {
return nil, err
}
@ -271,7 +263,7 @@ func (h *ActionHandler) RefreshLinkedAccount(ctx context.Context, rs *cstypes.Re
case cstypes.RemoteSourceAuthTypeOauth2:
// refresh access token if expired
if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) {
userSource, err := common.GetOauth2Source(rs, "")
userSource, err := scommon.GetOauth2Source(rs, "")
if err != nil {
return nil, err
}
@ -301,7 +293,7 @@ func (h *ActionHandler) GetGitSource(ctx context.Context, rs *cstypes.RemoteSour
if err != nil {
return nil, err
}
return common.GetGitSource(rs, la)
return scommon.GetGitSource(rs, la)
}
type RegisterUserRequest struct {
@ -329,11 +321,11 @@ func (h *ActionHandler) RegisterUser(ctx context.Context, req *RegisterUserReque
return nil, util.NewErrBadRequest(errors.Errorf("remote source user registration is disabled"))
}
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil {
return nil, err
}
userSource, err := common.GetUserSource(rs, accessToken)
userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil {
return nil, err
}
@ -391,11 +383,11 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (*
return nil, util.NewErrBadRequest(errors.Errorf("remote source user login is disabled"))
}
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil {
return nil, err
}
userSource, err := common.GetUserSource(rs, accessToken)
userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil {
return nil, err
}
@ -451,7 +443,7 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (*
}
// generate jwt token
token, err := common.GenerateLoginJWTToken(h.sd, user.ID)
token, err := scommon.GenerateLoginJWTToken(h.sd, user.ID)
if err != nil {
return nil, err
}
@ -480,11 +472,11 @@ func (h *ActionHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (*
return nil, errors.Errorf("failed to get remote source %q: %w", req.RemoteSourceName, ErrFromRemote(resp, err))
}
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil {
return nil, err
}
userSource, err := common.GetUserSource(rs, accessToken)
userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil {
return nil, err
}
@ -523,7 +515,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
return nil, errors.Errorf("failed to get user %q: %w", req.UserRef, ErrFromRemote(resp, err))
}
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
// user must be already logged in the create a linked account and can create a
// linked account only on itself.
@ -554,11 +546,11 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
switch rs.AuthType {
case cstypes.RemoteSourceAuthTypeOauth2:
oauth2Source, err := common.GetOauth2Source(rs, "")
oauth2Source, err := scommon.GetOauth2Source(rs, "")
if err != nil {
return nil, errors.Errorf("failed to create git source: %w", err)
}
token, err := common.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req)
token, err := scommon.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req)
if err != nil {
return nil, err
}
@ -572,7 +564,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
}, nil
case cstypes.RemoteSourceAuthTypePassword:
passwordSource, err := common.GetPasswordSource(rs, "")
passwordSource, err := scommon.GetPasswordSource(rs, "")
if err != nil {
return nil, errors.Errorf("failed to create git source: %w", err)
}
@ -751,7 +743,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st
return nil, errors.Errorf("failed to get remote source %q: %w", remoteSourceName, ErrFromRemote(resp, err))
}
oauth2Source, err := common.GetOauth2Source(rs, "")
oauth2Source, err := scommon.GetOauth2Source(rs, "")
if err != nil {
return nil, errors.Errorf("failed to create oauth2 source: %w", err)
}
@ -765,7 +757,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st
}
func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
if !h.IsUserAdmin(ctx) {
if !common.IsUserAdmin(ctx) {
return errors.Errorf("user not logged in")
}
@ -777,12 +769,12 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
}
func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) error {
if !h.IsUserLoggedOrAdmin(ctx) {
if !common.IsUserLoggedOrAdmin(ctx) {
return errors.Errorf("user not logged in")
}
isAdmin := !h.IsUserAdmin(ctx)
curUserID := h.CurrentUserID(ctx)
isAdmin := common.IsUserAdmin(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
if err != nil {
@ -802,12 +794,12 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string)
}
func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName string) error {
if !h.IsUserLoggedOrAdmin(ctx) {
if !common.IsUserLoggedOrAdmin(ctx) {
return errors.Errorf("user not logged in")
}
isAdmin := !h.IsUserAdmin(ctx)
curUserID := h.CurrentUserID(ctx)
isAdmin := common.IsUserAdmin(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
if err != nil {
@ -849,7 +841,7 @@ func (h *ActionHandler) UserCreateRun(ctx context.Context, req *UserCreateRunReq
prRefRegexes = append(prRefRegexes, re)
}
curUserID := h.CurrentUserID(ctx)
curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil {

View File

@ -20,6 +20,7 @@ import (
"strconv"
"agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types"
gwapitypes "agola.io/agola/services/gateway/api/types"
@ -41,11 +42,7 @@ func NewCreateOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *CreateOr
func (h *CreateOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var userID string
userIDVal := ctx.Value("userid")
if userIDVal != nil {
userID = userIDVal.(string)
}
userID := common.CurrentUserID(ctx)
var req gwapitypes.CreateOrgRequest
d := json.NewDecoder(r.Body)

View File

@ -20,6 +20,7 @@ import (
"net/url"
"agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types"
cstypes "agola.io/agola/services/configstore/types"
@ -49,12 +50,11 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return
}
userIDVal := ctx.Value("userid")
if userIDVal == nil {
userID := common.CurrentUserID(ctx)
if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return
}
userID := userIDVal.(string)
creq := &action.CreateProjectGroupRequest{
Name: req.Name,

View File

@ -19,6 +19,7 @@ import (
gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
csclient "agola.io/agola/services/configstore/client"
cstypes "agola.io/agola/services/configstore/types"
@ -53,12 +54,11 @@ func (h *UserRemoteReposHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
vars := mux.Vars(r)
remoteSourceRef := vars["remotesourceref"]
userIDVal := ctx.Value("userid")
if userIDVal == nil {
userID := common.CurrentUserID(ctx)
if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return
}
userID := userIDVal.(string)
user, resp, err := h.configstoreClient.GetUser(ctx, userID)
if httpErrorFromRemote(w, resp, err) {

View File

@ -22,6 +22,7 @@ import (
"strconv"
"agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types"
cstypes "agola.io/agola/services/configstore/types"
@ -104,12 +105,11 @@ func NewCurrentUserHandler(logger *zap.Logger, ah *action.ActionHandler) *Curren
func (h *CurrentUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userIDVal := ctx.Value("userid")
if userIDVal == nil {
userID := common.CurrentUserID(ctx)
if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return
}
userID := userIDVal.(string)
user, err := h.ah.GetUser(ctx, userID)
if httpError(w, err) {
@ -604,14 +604,13 @@ func NewUserOrgsHandler(logger *zap.Logger, ah *action.ActionHandler) *UserOrgsH
func (h *UserOrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userIDVal := ctx.Value("userid")
if userIDVal == nil {
userID := common.CurrentUserID(ctx)
if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return
}
userRef := userIDVal.(string)
userOrgs, err := h.ah.GetUserOrgs(ctx, userRef)
userOrgs, err := h.ah.GetUserOrgs(ctx, userID)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return

View File

@ -0,0 +1,36 @@
package common
import "context"
type ContextKey int
const (
ContextKeyUserID ContextKey = iota
ContextKeyUsername
ContextKeyUserAdmin
)
func CurrentUserID(ctx context.Context) string {
userIDVal := ctx.Value(ContextKeyUserID)
if userIDVal == nil {
return ""
}
return userIDVal.(string)
}
func IsUserLogged(ctx context.Context) bool {
return ctx.Value(ContextKeyUserID) != nil
}
func IsUserAdmin(ctx context.Context) bool {
isAdmin := false
isAdminVal := ctx.Value(ContextKeyUserAdmin)
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
return isAdmin
}
func IsUserLoggedOrAdmin(ctx context.Context) bool {
return IsUserLogged(ctx) || IsUserAdmin(ctx)
}

View File

@ -19,7 +19,8 @@ import (
"net/http"
"strings"
"agola.io/agola/internal/services/common"
scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
csclient "agola.io/agola/services/configstore/client"
"github.com/golang-jwt/jwt/v4"
@ -35,12 +36,12 @@ type AuthHandler struct {
configstoreClient *csclient.Client
adminToken string
sd *common.TokenSigningData
sd *scommon.TokenSigningData
required bool
}
func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *common.TokenSigningData, required bool) func(http.Handler) http.Handler {
func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *scommon.TokenSigningData, required bool) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return &AuthHandler{
log: logger.Sugar(),
@ -59,7 +60,7 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tokenString, _ := TokenExtractor.ExtractToken(r)
if h.adminToken != "" && tokenString != "" {
if tokenString == h.adminToken {
ctx = context.WithValue(ctx, "admin", true)
ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
h.next.ServeHTTP(w, r.WithContext(ctx))
return
} else {
@ -74,11 +75,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// pass userid to handlers via context
ctx = context.WithValue(ctx, "userid", user.ID)
ctx = context.WithValue(ctx, "username", user.Name)
ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID)
ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name)
if user.Admin {
ctx = context.WithValue(ctx, "admin", true)
ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
}
h.next.ServeHTTP(w, r.WithContext(ctx))
@ -128,11 +129,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// pass userid and username to handlers via context
ctx = context.WithValue(ctx, "userid", user.ID)
ctx = context.WithValue(ctx, "username", user.Name)
ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID)
ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name)
if user.Admin {
ctx = context.WithValue(ctx, "admin", true)
ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
}
h.next.ServeHTTP(w, r.WithContext(ctx))