gateway: use dedicated auth context keys type

Use auth context keys dedicated type instead of strings and avoid code
duplication by moving shared code under a common package.
This commit is contained in:
Simone Gotti 2022-02-24 11:18:29 +01:00
parent a1997490e7
commit b7c797bde1
12 changed files with 132 additions and 127 deletions

View File

@ -17,44 +17,20 @@ package action
import ( import (
"context" "context"
"agola.io/agola/internal/services/common" scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
func (h *ActionHandler) CurrentUserID(ctx context.Context) string {
userIDVal := ctx.Value("userid")
if userIDVal == nil {
return ""
}
return userIDVal.(string)
}
func (h *ActionHandler) IsUserLogged(ctx context.Context) bool {
return ctx.Value("userid") != nil
}
func (h *ActionHandler) IsUserAdmin(ctx context.Context) bool {
isAdmin := false
isAdminVal := ctx.Value("admin")
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
return isAdmin
}
func (h *ActionHandler) IsUserLoggedOrAdmin(ctx context.Context) bool {
return h.IsUserLogged(ctx) || h.IsUserAdmin(ctx)
}
func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, error) { func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx) isAdmin := common.IsUserAdmin(ctx)
if isAdmin { if isAdmin {
return true, nil return true, nil
} }
userID := h.CurrentUserID(ctx) userID := common.CurrentUserID(ctx)
if userID == "" { if userID == "" {
return false, nil return false, nil
} }
@ -77,12 +53,12 @@ func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, err
} }
func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) { func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx) isAdmin := common.IsUserAdmin(ctx)
if isAdmin { if isAdmin {
return true, nil return true, nil
} }
userID := h.CurrentUserID(ctx) userID := common.CurrentUserID(ctx)
if userID == "" { if userID == "" {
return false, nil return false, nil
} }
@ -113,12 +89,12 @@ func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.Co
} }
func (h *ActionHandler) IsProjectMember(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) { func (h *ActionHandler) IsProjectMember(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) {
isAdmin := h.IsUserAdmin(ctx) isAdmin := common.IsUserAdmin(ctx)
if isAdmin { if isAdmin {
return true, nil return true, nil
} }
userID := h.CurrentUserID(ctx) userID := common.CurrentUserID(ctx)
if userID == "" { if userID == "" {
return false, nil return false, nil
} }
@ -170,7 +146,7 @@ func (h *ActionHandler) IsVariableOwner(ctx context.Context, parentType cstypes.
} }
func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, error) { func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, error) {
groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup) groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -179,7 +155,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
var ownerType cstypes.ConfigType var ownerType cstypes.ConfigType
var ownerID string var ownerID string
switch groupType { switch groupType {
case common.GroupTypeProject: case scommon.GroupTypeProject:
p, resp, err := h.configstoreClient.GetProject(ctx, groupID) p, resp, err := h.configstoreClient.GetProject(ctx, groupID)
if err != nil { if err != nil {
return false, ErrFromRemote(resp, err) return false, ErrFromRemote(resp, err)
@ -187,7 +163,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
ownerType = p.OwnerType ownerType = p.OwnerType
ownerID = p.OwnerID ownerID = p.OwnerID
visibility = p.GlobalVisibility visibility = p.GlobalVisibility
case common.GroupTypeUser: case scommon.GroupTypeUser:
// user direct runs // user direct runs
ownerType = cstypes.ConfigTypeUser ownerType = cstypes.ConfigTypeUser
ownerID = groupID ownerID = groupID
@ -208,7 +184,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e
} }
func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (bool, error) { func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (bool, error) {
groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup) groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -216,14 +192,14 @@ func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (b
var ownerType cstypes.ConfigType var ownerType cstypes.ConfigType
var ownerID string var ownerID string
switch groupType { switch groupType {
case common.GroupTypeProject: case scommon.GroupTypeProject:
p, resp, err := h.configstoreClient.GetProject(ctx, groupID) p, resp, err := h.configstoreClient.GetProject(ctx, groupID)
if err != nil { if err != nil {
return false, ErrFromRemote(resp, err) return false, ErrFromRemote(resp, err)
} }
ownerType = p.OwnerType ownerType = p.OwnerType
ownerID = p.OwnerID ownerID = p.OwnerID
case common.GroupTypeUser: case scommon.GroupTypeUser:
// user direct runs // user direct runs
ownerType = cstypes.ConfigTypeUser ownerType = cstypes.ConfigTypeUser
ownerID = groupID ownerID = groupID

View File

@ -17,6 +17,7 @@ package action
import ( import (
"context" "context"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -87,7 +88,7 @@ type CreateOrgRequest struct {
} }
func (h *ActionHandler) CreateOrg(ctx context.Context, req *CreateOrgRequest) (*cstypes.Organization, error) { func (h *ActionHandler) CreateOrg(ctx context.Context, req *CreateOrgRequest) (*cstypes.Organization, error) {
if !h.IsUserLoggedOrAdmin(ctx) { if !common.IsUserLoggedOrAdmin(ctx) {
return nil, errors.Errorf("user not logged in") return nil, errors.Errorf("user not logged in")
} }

View File

@ -22,6 +22,7 @@ import (
"path" "path"
gitsource "agola.io/agola/internal/gitsources" gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/services/types" "agola.io/agola/internal/services/types"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types" csapitypes "agola.io/agola/services/configstore/api/types"
@ -61,7 +62,7 @@ type CreateProjectRequest struct {
} }
func (h *ActionHandler) CreateProject(ctx context.Context, req *CreateProjectRequest) (*csapitypes.Project, error) { func (h *ActionHandler) CreateProject(ctx context.Context, req *CreateProjectRequest) (*csapitypes.Project, error) {
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil { if err != nil {
@ -227,7 +228,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, projectRef string, re
} }
func (h *ActionHandler) ProjectUpdateRepoLinkedAccount(ctx context.Context, projectRef string) (*csapitypes.Project, error) { func (h *ActionHandler) ProjectUpdateRepoLinkedAccount(ctx context.Context, projectRef string) (*csapitypes.Project, error) {
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil { if err != nil {
@ -428,7 +429,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er
} }
func (h *ActionHandler) ProjectCreateRun(ctx context.Context, projectRef, branch, tag, refName, commitSHA string) error { func (h *ActionHandler) ProjectCreateRun(ctx context.Context, projectRef, branch, tag, refName, commitSHA string) error {
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil { if err != nil {

View File

@ -17,6 +17,7 @@ package action
import ( import (
"context" "context"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -60,7 +61,7 @@ type CreateRemoteSourceRequest struct {
} }
func (h *ActionHandler) CreateRemoteSource(ctx context.Context, req *CreateRemoteSourceRequest) (*cstypes.RemoteSource, error) { func (h *ActionHandler) CreateRemoteSource(ctx context.Context, req *CreateRemoteSourceRequest) (*cstypes.RemoteSource, error) {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin") return nil, errors.Errorf("user not admin")
} }
@ -134,7 +135,7 @@ type UpdateRemoteSourceRequest struct {
} }
func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemoteSourceRequest) (*cstypes.RemoteSource, error) { func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemoteSourceRequest) (*cstypes.RemoteSource, error) {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin") return nil, errors.Errorf("user not admin")
} }
@ -182,7 +183,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot
} }
func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, rsRef string) error { func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, rsRef string) error {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return errors.Errorf("user not admin") return errors.Errorf("user not admin")
} }

View File

@ -24,7 +24,8 @@ import (
"agola.io/agola/internal/config" "agola.io/agola/internal/config"
gitsource "agola.io/agola/internal/gitsources" gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/runconfig" "agola.io/agola/internal/runconfig"
"agola.io/agola/internal/services/common" scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
itypes "agola.io/agola/internal/services/types" itypes "agola.io/agola/internal/services/types"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -268,7 +269,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
if !canDoRunAction { if !canDoRunAction {
return util.NewErrForbidden(errors.Errorf("user not authorized")) return util.NewErrForbidden(errors.Errorf("user not authorized"))
} }
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
if curUserID == "" { if curUserID == "" {
return util.NewErrBadRequest(errors.Errorf("no logged in user")) return util.NewErrBadRequest(errors.Errorf("no logged in user"))
} }
@ -285,7 +286,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
if rt.Annotations != nil { if rt.Annotations != nil {
annotations = rt.Annotations annotations = rt.Annotations
} }
approversAnnotation, ok := annotations[common.ApproversAnnotation] approversAnnotation, ok := annotations[scommon.ApproversAnnotation]
if ok { if ok {
if err := json.Unmarshal([]byte(approversAnnotation), &approvers); err != nil { if err := json.Unmarshal([]byte(approversAnnotation), &approvers); err != nil {
return errors.Errorf("failed to unmarshal run task approvers annotation: %w", err) return errors.Errorf("failed to unmarshal run task approvers annotation: %w", err)
@ -304,7 +305,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe
return errors.Errorf("failed to marshal run task approvers annotation: %w", err) return errors.Errorf("failed to marshal run task approvers annotation: %w", err)
} }
annotations[common.ApproversAnnotation] = string(approversj) annotations[scommon.ApproversAnnotation] = string(approversj)
rsreq := &rsapitypes.RunTaskActionsRequest{ rsreq := &rsapitypes.RunTaskActionsRequest{
ActionType: rsapitypes.RunTaskActionTypeSetAnnotations, ActionType: rsapitypes.RunTaskActionTypeSetAnnotations,
@ -372,32 +373,32 @@ func (h *ActionHandler) CreateRuns(ctx context.Context, req *CreateRunRequest) e
return util.NewErrBadRequest(errors.Errorf("empty message")) return util.NewErrBadRequest(errors.Errorf("empty message"))
} }
var baseGroupType common.GroupType var baseGroupType scommon.GroupType
var baseGroupID string var baseGroupID string
var groupType common.GroupType var groupType scommon.GroupType
var group string var group string
if req.RunType == itypes.RunTypeProject { if req.RunType == itypes.RunTypeProject {
baseGroupType = common.GroupTypeProject baseGroupType = scommon.GroupTypeProject
baseGroupID = req.Project.ID baseGroupID = req.Project.ID
} else { } else {
baseGroupType = common.GroupTypeUser baseGroupType = scommon.GroupTypeUser
baseGroupID = req.User.ID baseGroupID = req.User.ID
} }
switch req.RefType { switch req.RefType {
case itypes.RunRefTypeBranch: case itypes.RunRefTypeBranch:
groupType = common.GroupTypeBranch groupType = scommon.GroupTypeBranch
group = req.Branch group = req.Branch
case itypes.RunRefTypeTag: case itypes.RunRefTypeTag:
groupType = common.GroupTypeTag groupType = scommon.GroupTypeTag
group = req.Tag group = req.Tag
case itypes.RunRefTypePullRequest: case itypes.RunRefTypePullRequest:
groupType = common.GroupTypePullRequest groupType = scommon.GroupTypePullRequest
group = req.PullRequestID group = req.PullRequestID
} }
runGroup := common.GenRunGroup(baseGroupType, baseGroupID, groupType, group) runGroup := scommon.GenRunGroup(baseGroupType, baseGroupID, groupType, group)
gitURL, err := util.ParseGitURL(req.CloneURL) gitURL, err := util.ParseGitURL(req.CloneURL)
if err != nil { if err != nil {
@ -595,7 +596,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque
} }
// remove overriden variables // remove overriden variables
pvars = common.FilterOverriddenVariables(pvars) pvars = scommon.FilterOverriddenVariables(pvars)
// get project secrets // get project secrets
secrets, _, err := h.configstoreClient.GetProjectSecrets(ctx, req.Project.ID, true) secrets, _, err := h.configstoreClient.GetProjectSecrets(ctx, req.Project.ID, true)
@ -611,7 +612,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque
continue continue
} }
// get the secret value referenced by the variable, it must be a secret at the same level or a lower level // get the secret value referenced by the variable, it must be a secret at the same level or a lower level
secret := common.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets) secret := scommon.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets)
if secret != nil { if secret != nil {
varValue, ok := secret.Data[varval.SecretVar] varValue, ok := secret.Data[varval.SecretVar]
if ok { if ok {

View File

@ -24,7 +24,8 @@ import (
gitsource "agola.io/agola/internal/gitsources" gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/gitsources/agolagit" "agola.io/agola/internal/gitsources/agolagit"
"agola.io/agola/internal/services/common" scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/services/types" "agola.io/agola/internal/services/types"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types" csapitypes "agola.io/agola/services/configstore/api/types"
@ -46,7 +47,7 @@ func isAccessTokenExpired(expiresAt time.Time) bool {
} }
func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.User, error) { func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.User, error) {
if !h.IsUserLoggedOrAdmin(ctx) { if !common.IsUserLoggedOrAdmin(ctx) {
return nil, errors.Errorf("user not logged in") return nil, errors.Errorf("user not logged in")
} }
@ -58,7 +59,7 @@ func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.U
} }
func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*csapitypes.UserOrgsResponse, error) { func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*csapitypes.UserOrgsResponse, error) {
if !h.IsUserLogged(ctx) { if !common.IsUserLogged(ctx) {
return nil, errors.Errorf("user not logged in") return nil, errors.Errorf("user not logged in")
} }
@ -76,7 +77,7 @@ type GetUsersRequest struct {
} }
func (h *ActionHandler) GetUsers(ctx context.Context, req *GetUsersRequest) ([]*cstypes.User, error) { func (h *ActionHandler) GetUsers(ctx context.Context, req *GetUsersRequest) ([]*cstypes.User, error) {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not logged in") return nil, errors.Errorf("user not logged in")
} }
@ -92,7 +93,7 @@ type CreateUserRequest struct {
} }
func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) (*cstypes.User, error) { func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) (*cstypes.User, error) {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return nil, errors.Errorf("user not admin") return nil, errors.Errorf("user not admin")
} }
@ -123,17 +124,8 @@ type CreateUserTokenRequest struct {
} }
func (h *ActionHandler) CreateUserToken(ctx context.Context, req *CreateUserTokenRequest) (string, error) { func (h *ActionHandler) CreateUserToken(ctx context.Context, req *CreateUserTokenRequest) (string, error) {
var userID string isAdmin := common.IsUserAdmin(ctx)
userIDVal := ctx.Value("userid") userID := common.CurrentUserID(ctx)
if userIDVal != nil {
userID = userIDVal.(string)
}
isAdmin := false
isAdminVal := ctx.Value("admin")
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
userRef := req.UserRef userRef := req.UserRef
user, resp, err := h.configstoreClient.GetUser(ctx, userRef) user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
@ -193,11 +185,11 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userRef, rs.Name)) return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userRef, rs.Name))
} }
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userSource, err := common.GetUserSource(rs, accessToken) userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -271,7 +263,7 @@ func (h *ActionHandler) RefreshLinkedAccount(ctx context.Context, rs *cstypes.Re
case cstypes.RemoteSourceAuthTypeOauth2: case cstypes.RemoteSourceAuthTypeOauth2:
// refresh access token if expired // refresh access token if expired
if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) { if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) {
userSource, err := common.GetOauth2Source(rs, "") userSource, err := scommon.GetOauth2Source(rs, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -301,7 +293,7 @@ func (h *ActionHandler) GetGitSource(ctx context.Context, rs *cstypes.RemoteSour
if err != nil { if err != nil {
return nil, err return nil, err
} }
return common.GetGitSource(rs, la) return scommon.GetGitSource(rs, la)
} }
type RegisterUserRequest struct { type RegisterUserRequest struct {
@ -329,11 +321,11 @@ func (h *ActionHandler) RegisterUser(ctx context.Context, req *RegisterUserReque
return nil, util.NewErrBadRequest(errors.Errorf("remote source user registration is disabled")) return nil, util.NewErrBadRequest(errors.Errorf("remote source user registration is disabled"))
} }
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userSource, err := common.GetUserSource(rs, accessToken) userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -391,11 +383,11 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (*
return nil, util.NewErrBadRequest(errors.Errorf("remote source user login is disabled")) return nil, util.NewErrBadRequest(errors.Errorf("remote source user login is disabled"))
} }
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userSource, err := common.GetUserSource(rs, accessToken) userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -451,7 +443,7 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (*
} }
// generate jwt token // generate jwt token
token, err := common.GenerateLoginJWTToken(h.sd, user.ID) token, err := scommon.GenerateLoginJWTToken(h.sd, user.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -480,11 +472,11 @@ func (h *ActionHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (*
return nil, errors.Errorf("failed to get remote source %q: %w", req.RemoteSourceName, ErrFromRemote(resp, err)) return nil, errors.Errorf("failed to get remote source %q: %w", req.RemoteSourceName, ErrFromRemote(resp, err))
} }
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userSource, err := common.GetUserSource(rs, accessToken) userSource, err := scommon.GetUserSource(rs, accessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -523,7 +515,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
return nil, errors.Errorf("failed to get user %q: %w", req.UserRef, ErrFromRemote(resp, err)) return nil, errors.Errorf("failed to get user %q: %w", req.UserRef, ErrFromRemote(resp, err))
} }
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
// user must be already logged in the create a linked account and can create a // user must be already logged in the create a linked account and can create a
// linked account only on itself. // linked account only on itself.
@ -554,11 +546,11 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
switch rs.AuthType { switch rs.AuthType {
case cstypes.RemoteSourceAuthTypeOauth2: case cstypes.RemoteSourceAuthTypeOauth2:
oauth2Source, err := common.GetOauth2Source(rs, "") oauth2Source, err := scommon.GetOauth2Source(rs, "")
if err != nil { if err != nil {
return nil, errors.Errorf("failed to create git source: %w", err) return nil, errors.Errorf("failed to create git source: %w", err)
} }
token, err := common.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req) token, err := scommon.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -572,7 +564,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource
}, nil }, nil
case cstypes.RemoteSourceAuthTypePassword: case cstypes.RemoteSourceAuthTypePassword:
passwordSource, err := common.GetPasswordSource(rs, "") passwordSource, err := scommon.GetPasswordSource(rs, "")
if err != nil { if err != nil {
return nil, errors.Errorf("failed to create git source: %w", err) return nil, errors.Errorf("failed to create git source: %w", err)
} }
@ -751,7 +743,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st
return nil, errors.Errorf("failed to get remote source %q: %w", remoteSourceName, ErrFromRemote(resp, err)) return nil, errors.Errorf("failed to get remote source %q: %w", remoteSourceName, ErrFromRemote(resp, err))
} }
oauth2Source, err := common.GetOauth2Source(rs, "") oauth2Source, err := scommon.GetOauth2Source(rs, "")
if err != nil { if err != nil {
return nil, errors.Errorf("failed to create oauth2 source: %w", err) return nil, errors.Errorf("failed to create oauth2 source: %w", err)
} }
@ -765,7 +757,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st
} }
func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error { func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
if !h.IsUserAdmin(ctx) { if !common.IsUserAdmin(ctx) {
return errors.Errorf("user not logged in") return errors.Errorf("user not logged in")
} }
@ -777,12 +769,12 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
} }
func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) error { func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) error {
if !h.IsUserLoggedOrAdmin(ctx) { if !common.IsUserLoggedOrAdmin(ctx) {
return errors.Errorf("user not logged in") return errors.Errorf("user not logged in")
} }
isAdmin := !h.IsUserAdmin(ctx) isAdmin := common.IsUserAdmin(ctx)
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, userRef) user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
if err != nil { if err != nil {
@ -802,12 +794,12 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string)
} }
func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName string) error { func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName string) error {
if !h.IsUserLoggedOrAdmin(ctx) { if !common.IsUserLoggedOrAdmin(ctx) {
return errors.Errorf("user not logged in") return errors.Errorf("user not logged in")
} }
isAdmin := !h.IsUserAdmin(ctx) isAdmin := common.IsUserAdmin(ctx)
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, userRef) user, resp, err := h.configstoreClient.GetUser(ctx, userRef)
if err != nil { if err != nil {
@ -849,7 +841,7 @@ func (h *ActionHandler) UserCreateRun(ctx context.Context, req *UserCreateRunReq
prRefRegexes = append(prRefRegexes, re) prRefRegexes = append(prRefRegexes, re)
} }
curUserID := h.CurrentUserID(ctx) curUserID := common.CurrentUserID(ctx)
user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID)
if err != nil { if err != nil {

View File

@ -20,6 +20,7 @@ import (
"strconv" "strconv"
"agola.io/agola/internal/services/gateway/action" "agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
gwapitypes "agola.io/agola/services/gateway/api/types" gwapitypes "agola.io/agola/services/gateway/api/types"
@ -41,11 +42,7 @@ func NewCreateOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *CreateOr
func (h *CreateOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *CreateOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
var userID string userID := common.CurrentUserID(ctx)
userIDVal := ctx.Value("userid")
if userIDVal != nil {
userID = userIDVal.(string)
}
var req gwapitypes.CreateOrgRequest var req gwapitypes.CreateOrgRequest
d := json.NewDecoder(r.Body) d := json.NewDecoder(r.Body)

View File

@ -20,6 +20,7 @@ import (
"net/url" "net/url"
"agola.io/agola/internal/services/gateway/action" "agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types" csapitypes "agola.io/agola/services/configstore/api/types"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -49,12 +50,11 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return return
} }
userIDVal := ctx.Value("userid") userID := common.CurrentUserID(ctx)
if userIDVal == nil { if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return return
} }
userID := userIDVal.(string)
creq := &action.CreateProjectGroupRequest{ creq := &action.CreateProjectGroupRequest{
Name: req.Name, Name: req.Name,

View File

@ -19,6 +19,7 @@ import (
gitsource "agola.io/agola/internal/gitsources" gitsource "agola.io/agola/internal/gitsources"
"agola.io/agola/internal/services/gateway/action" "agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
csclient "agola.io/agola/services/configstore/client" csclient "agola.io/agola/services/configstore/client"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -53,12 +54,11 @@ func (h *UserRemoteReposHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
vars := mux.Vars(r) vars := mux.Vars(r)
remoteSourceRef := vars["remotesourceref"] remoteSourceRef := vars["remotesourceref"]
userIDVal := ctx.Value("userid") userID := common.CurrentUserID(ctx)
if userIDVal == nil { if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return return
} }
userID := userIDVal.(string)
user, resp, err := h.configstoreClient.GetUser(ctx, userID) user, resp, err := h.configstoreClient.GetUser(ctx, userID)
if httpErrorFromRemote(w, resp, err) { if httpErrorFromRemote(w, resp, err) {

View File

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"agola.io/agola/internal/services/gateway/action" "agola.io/agola/internal/services/gateway/action"
"agola.io/agola/internal/services/gateway/common"
"agola.io/agola/internal/util" "agola.io/agola/internal/util"
csapitypes "agola.io/agola/services/configstore/api/types" csapitypes "agola.io/agola/services/configstore/api/types"
cstypes "agola.io/agola/services/configstore/types" cstypes "agola.io/agola/services/configstore/types"
@ -104,12 +105,11 @@ func NewCurrentUserHandler(logger *zap.Logger, ah *action.ActionHandler) *Curren
func (h *CurrentUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *CurrentUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
userIDVal := ctx.Value("userid") userID := common.CurrentUserID(ctx)
if userIDVal == nil { if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return return
} }
userID := userIDVal.(string)
user, err := h.ah.GetUser(ctx, userID) user, err := h.ah.GetUser(ctx, userID)
if httpError(w, err) { if httpError(w, err) {
@ -604,14 +604,13 @@ func NewUserOrgsHandler(logger *zap.Logger, ah *action.ActionHandler) *UserOrgsH
func (h *UserOrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *UserOrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
userIDVal := ctx.Value("userid") userID := common.CurrentUserID(ctx)
if userIDVal == nil { if userID == "" {
httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated")))
return return
} }
userRef := userIDVal.(string)
userOrgs, err := h.ah.GetUserOrgs(ctx, userRef) userOrgs, err := h.ah.GetUserOrgs(ctx, userID)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return

View File

@ -0,0 +1,36 @@
package common
import "context"
type ContextKey int
const (
ContextKeyUserID ContextKey = iota
ContextKeyUsername
ContextKeyUserAdmin
)
func CurrentUserID(ctx context.Context) string {
userIDVal := ctx.Value(ContextKeyUserID)
if userIDVal == nil {
return ""
}
return userIDVal.(string)
}
func IsUserLogged(ctx context.Context) bool {
return ctx.Value(ContextKeyUserID) != nil
}
func IsUserAdmin(ctx context.Context) bool {
isAdmin := false
isAdminVal := ctx.Value(ContextKeyUserAdmin)
if isAdminVal != nil {
isAdmin = isAdminVal.(bool)
}
return isAdmin
}
func IsUserLoggedOrAdmin(ctx context.Context) bool {
return IsUserLogged(ctx) || IsUserAdmin(ctx)
}

View File

@ -19,7 +19,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"agola.io/agola/internal/services/common" scommon "agola.io/agola/internal/services/common"
"agola.io/agola/internal/services/gateway/common"
csclient "agola.io/agola/services/configstore/client" csclient "agola.io/agola/services/configstore/client"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
@ -35,12 +36,12 @@ type AuthHandler struct {
configstoreClient *csclient.Client configstoreClient *csclient.Client
adminToken string adminToken string
sd *common.TokenSigningData sd *scommon.TokenSigningData
required bool required bool
} }
func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *common.TokenSigningData, required bool) func(http.Handler) http.Handler { func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *scommon.TokenSigningData, required bool) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler {
return &AuthHandler{ return &AuthHandler{
log: logger.Sugar(), log: logger.Sugar(),
@ -59,7 +60,7 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
tokenString, _ := TokenExtractor.ExtractToken(r) tokenString, _ := TokenExtractor.ExtractToken(r)
if h.adminToken != "" && tokenString != "" { if h.adminToken != "" && tokenString != "" {
if tokenString == h.adminToken { if tokenString == h.adminToken {
ctx = context.WithValue(ctx, "admin", true) ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
h.next.ServeHTTP(w, r.WithContext(ctx)) h.next.ServeHTTP(w, r.WithContext(ctx))
return return
} else { } else {
@ -74,11 +75,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// pass userid to handlers via context // pass userid to handlers via context
ctx = context.WithValue(ctx, "userid", user.ID) ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID)
ctx = context.WithValue(ctx, "username", user.Name) ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name)
if user.Admin { if user.Admin {
ctx = context.WithValue(ctx, "admin", true) ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
} }
h.next.ServeHTTP(w, r.WithContext(ctx)) h.next.ServeHTTP(w, r.WithContext(ctx))
@ -128,11 +129,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// pass userid and username to handlers via context // pass userid and username to handlers via context
ctx = context.WithValue(ctx, "userid", user.ID) ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID)
ctx = context.WithValue(ctx, "username", user.Name) ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name)
if user.Admin { if user.Admin {
ctx = context.WithValue(ctx, "admin", true) ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true)
} }
h.next.ServeHTTP(w, r.WithContext(ctx)) h.next.ServeHTTP(w, r.WithContext(ctx))