diff --git a/internal/services/gateway/action/auth.go b/internal/services/gateway/action/auth.go index bc585d3..1bbbd36 100644 --- a/internal/services/gateway/action/auth.go +++ b/internal/services/gateway/action/auth.go @@ -17,44 +17,20 @@ package action import ( "context" - "agola.io/agola/internal/services/common" + scommon "agola.io/agola/internal/services/common" + "agola.io/agola/internal/services/gateway/common" cstypes "agola.io/agola/services/configstore/types" errors "golang.org/x/xerrors" ) -func (h *ActionHandler) CurrentUserID(ctx context.Context) string { - userIDVal := ctx.Value("userid") - if userIDVal == nil { - return "" - } - return userIDVal.(string) -} - -func (h *ActionHandler) IsUserLogged(ctx context.Context) bool { - return ctx.Value("userid") != nil -} - -func (h *ActionHandler) IsUserAdmin(ctx context.Context) bool { - isAdmin := false - isAdminVal := ctx.Value("admin") - if isAdminVal != nil { - isAdmin = isAdminVal.(bool) - } - return isAdmin -} - -func (h *ActionHandler) IsUserLoggedOrAdmin(ctx context.Context) bool { - return h.IsUserLogged(ctx) || h.IsUserAdmin(ctx) -} - func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, error) { - isAdmin := h.IsUserAdmin(ctx) + isAdmin := common.IsUserAdmin(ctx) if isAdmin { return true, nil } - userID := h.CurrentUserID(ctx) + userID := common.CurrentUserID(ctx) if userID == "" { return false, nil } @@ -77,12 +53,12 @@ func (h *ActionHandler) IsOrgOwner(ctx context.Context, orgID string) (bool, err } func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) { - isAdmin := h.IsUserAdmin(ctx) + isAdmin := common.IsUserAdmin(ctx) if isAdmin { return true, nil } - userID := h.CurrentUserID(ctx) + userID := common.CurrentUserID(ctx) if userID == "" { return false, nil } @@ -113,12 +89,12 @@ func (h *ActionHandler) IsProjectOwner(ctx context.Context, ownerType cstypes.Co } func (h *ActionHandler) IsProjectMember(ctx context.Context, ownerType cstypes.ConfigType, ownerID string) (bool, error) { - isAdmin := h.IsUserAdmin(ctx) + isAdmin := common.IsUserAdmin(ctx) if isAdmin { return true, nil } - userID := h.CurrentUserID(ctx) + userID := common.CurrentUserID(ctx) if userID == "" { return false, nil } @@ -170,7 +146,7 @@ func (h *ActionHandler) IsVariableOwner(ctx context.Context, parentType cstypes. } func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, error) { - groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup) + groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup) if err != nil { return false, err } @@ -179,7 +155,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e var ownerType cstypes.ConfigType var ownerID string switch groupType { - case common.GroupTypeProject: + case scommon.GroupTypeProject: p, resp, err := h.configstoreClient.GetProject(ctx, groupID) if err != nil { return false, ErrFromRemote(resp, err) @@ -187,7 +163,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e ownerType = p.OwnerType ownerID = p.OwnerID visibility = p.GlobalVisibility - case common.GroupTypeUser: + case scommon.GroupTypeUser: // user direct runs ownerType = cstypes.ConfigTypeUser ownerID = groupID @@ -208,7 +184,7 @@ func (h *ActionHandler) CanGetRun(ctx context.Context, runGroup string) (bool, e } func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (bool, error) { - groupType, groupID, err := common.GroupTypeIDFromRunGroup(runGroup) + groupType, groupID, err := scommon.GroupTypeIDFromRunGroup(runGroup) if err != nil { return false, err } @@ -216,14 +192,14 @@ func (h *ActionHandler) CanDoRunActions(ctx context.Context, runGroup string) (b var ownerType cstypes.ConfigType var ownerID string switch groupType { - case common.GroupTypeProject: + case scommon.GroupTypeProject: p, resp, err := h.configstoreClient.GetProject(ctx, groupID) if err != nil { return false, ErrFromRemote(resp, err) } ownerType = p.OwnerType ownerID = p.OwnerID - case common.GroupTypeUser: + case scommon.GroupTypeUser: // user direct runs ownerType = cstypes.ConfigTypeUser ownerID = groupID diff --git a/internal/services/gateway/action/org.go b/internal/services/gateway/action/org.go index c2135ca..4faaf4f 100644 --- a/internal/services/gateway/action/org.go +++ b/internal/services/gateway/action/org.go @@ -17,6 +17,7 @@ package action import ( "context" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" cstypes "agola.io/agola/services/configstore/types" @@ -87,7 +88,7 @@ type CreateOrgRequest struct { } func (h *ActionHandler) CreateOrg(ctx context.Context, req *CreateOrgRequest) (*cstypes.Organization, error) { - if !h.IsUserLoggedOrAdmin(ctx) { + if !common.IsUserLoggedOrAdmin(ctx) { return nil, errors.Errorf("user not logged in") } diff --git a/internal/services/gateway/action/project.go b/internal/services/gateway/action/project.go index 3f62d81..7f4be86 100644 --- a/internal/services/gateway/action/project.go +++ b/internal/services/gateway/action/project.go @@ -22,6 +22,7 @@ import ( "path" gitsource "agola.io/agola/internal/gitsources" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/services/types" "agola.io/agola/internal/util" csapitypes "agola.io/agola/services/configstore/api/types" @@ -61,7 +62,7 @@ type CreateProjectRequest struct { } func (h *ActionHandler) CreateProject(ctx context.Context, req *CreateProjectRequest) (*csapitypes.Project, error) { - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) if err != nil { @@ -227,7 +228,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, projectRef string, re } func (h *ActionHandler) ProjectUpdateRepoLinkedAccount(ctx context.Context, projectRef string) (*csapitypes.Project, error) { - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) if err != nil { @@ -428,7 +429,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er } func (h *ActionHandler) ProjectCreateRun(ctx context.Context, projectRef, branch, tag, refName, commitSHA string) error { - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) if err != nil { diff --git a/internal/services/gateway/action/remotesource.go b/internal/services/gateway/action/remotesource.go index b3239aa..4f2cbc0 100644 --- a/internal/services/gateway/action/remotesource.go +++ b/internal/services/gateway/action/remotesource.go @@ -17,6 +17,7 @@ package action import ( "context" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" cstypes "agola.io/agola/services/configstore/types" @@ -60,7 +61,7 @@ type CreateRemoteSourceRequest struct { } func (h *ActionHandler) CreateRemoteSource(ctx context.Context, req *CreateRemoteSourceRequest) (*cstypes.RemoteSource, error) { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return nil, errors.Errorf("user not admin") } @@ -134,7 +135,7 @@ type UpdateRemoteSourceRequest struct { } func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemoteSourceRequest) (*cstypes.RemoteSource, error) { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return nil, errors.Errorf("user not admin") } @@ -182,7 +183,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot } func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, rsRef string) error { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return errors.Errorf("user not admin") } diff --git a/internal/services/gateway/action/run.go b/internal/services/gateway/action/run.go index 38962f8..de0cadb 100644 --- a/internal/services/gateway/action/run.go +++ b/internal/services/gateway/action/run.go @@ -24,7 +24,8 @@ import ( "agola.io/agola/internal/config" gitsource "agola.io/agola/internal/gitsources" "agola.io/agola/internal/runconfig" - "agola.io/agola/internal/services/common" + scommon "agola.io/agola/internal/services/common" + "agola.io/agola/internal/services/gateway/common" itypes "agola.io/agola/internal/services/types" "agola.io/agola/internal/util" cstypes "agola.io/agola/services/configstore/types" @@ -268,7 +269,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe if !canDoRunAction { return util.NewErrForbidden(errors.Errorf("user not authorized")) } - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) if curUserID == "" { return util.NewErrBadRequest(errors.Errorf("no logged in user")) } @@ -285,7 +286,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe if rt.Annotations != nil { annotations = rt.Annotations } - approversAnnotation, ok := annotations[common.ApproversAnnotation] + approversAnnotation, ok := annotations[scommon.ApproversAnnotation] if ok { if err := json.Unmarshal([]byte(approversAnnotation), &approvers); err != nil { return errors.Errorf("failed to unmarshal run task approvers annotation: %w", err) @@ -304,7 +305,7 @@ func (h *ActionHandler) RunTaskAction(ctx context.Context, req *RunTaskActionsRe return errors.Errorf("failed to marshal run task approvers annotation: %w", err) } - annotations[common.ApproversAnnotation] = string(approversj) + annotations[scommon.ApproversAnnotation] = string(approversj) rsreq := &rsapitypes.RunTaskActionsRequest{ ActionType: rsapitypes.RunTaskActionTypeSetAnnotations, @@ -372,32 +373,32 @@ func (h *ActionHandler) CreateRuns(ctx context.Context, req *CreateRunRequest) e return util.NewErrBadRequest(errors.Errorf("empty message")) } - var baseGroupType common.GroupType + var baseGroupType scommon.GroupType var baseGroupID string - var groupType common.GroupType + var groupType scommon.GroupType var group string if req.RunType == itypes.RunTypeProject { - baseGroupType = common.GroupTypeProject + baseGroupType = scommon.GroupTypeProject baseGroupID = req.Project.ID } else { - baseGroupType = common.GroupTypeUser + baseGroupType = scommon.GroupTypeUser baseGroupID = req.User.ID } switch req.RefType { case itypes.RunRefTypeBranch: - groupType = common.GroupTypeBranch + groupType = scommon.GroupTypeBranch group = req.Branch case itypes.RunRefTypeTag: - groupType = common.GroupTypeTag + groupType = scommon.GroupTypeTag group = req.Tag case itypes.RunRefTypePullRequest: - groupType = common.GroupTypePullRequest + groupType = scommon.GroupTypePullRequest group = req.PullRequestID } - runGroup := common.GenRunGroup(baseGroupType, baseGroupID, groupType, group) + runGroup := scommon.GenRunGroup(baseGroupType, baseGroupID, groupType, group) gitURL, err := util.ParseGitURL(req.CloneURL) if err != nil { @@ -595,7 +596,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque } // remove overriden variables - pvars = common.FilterOverriddenVariables(pvars) + pvars = scommon.FilterOverriddenVariables(pvars) // get project secrets secrets, _, err := h.configstoreClient.GetProjectSecrets(ctx, req.Project.ID, true) @@ -611,7 +612,7 @@ func (h *ActionHandler) genRunVariables(ctx context.Context, req *CreateRunReque continue } // get the secret value referenced by the variable, it must be a secret at the same level or a lower level - secret := common.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets) + secret := scommon.GetVarValueMatchingSecret(varval, pvar.ParentPath, secrets) if secret != nil { varValue, ok := secret.Data[varval.SecretVar] if ok { diff --git a/internal/services/gateway/action/user.go b/internal/services/gateway/action/user.go index 730585f..7e6ff05 100644 --- a/internal/services/gateway/action/user.go +++ b/internal/services/gateway/action/user.go @@ -24,7 +24,8 @@ import ( gitsource "agola.io/agola/internal/gitsources" "agola.io/agola/internal/gitsources/agolagit" - "agola.io/agola/internal/services/common" + scommon "agola.io/agola/internal/services/common" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/services/types" "agola.io/agola/internal/util" csapitypes "agola.io/agola/services/configstore/api/types" @@ -46,7 +47,7 @@ func isAccessTokenExpired(expiresAt time.Time) bool { } func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.User, error) { - if !h.IsUserLoggedOrAdmin(ctx) { + if !common.IsUserLoggedOrAdmin(ctx) { return nil, errors.Errorf("user not logged in") } @@ -58,7 +59,7 @@ func (h *ActionHandler) GetUser(ctx context.Context, userRef string) (*cstypes.U } func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*csapitypes.UserOrgsResponse, error) { - if !h.IsUserLogged(ctx) { + if !common.IsUserLogged(ctx) { return nil, errors.Errorf("user not logged in") } @@ -76,7 +77,7 @@ type GetUsersRequest struct { } func (h *ActionHandler) GetUsers(ctx context.Context, req *GetUsersRequest) ([]*cstypes.User, error) { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return nil, errors.Errorf("user not logged in") } @@ -92,7 +93,7 @@ type CreateUserRequest struct { } func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) (*cstypes.User, error) { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return nil, errors.Errorf("user not admin") } @@ -123,17 +124,8 @@ type CreateUserTokenRequest struct { } func (h *ActionHandler) CreateUserToken(ctx context.Context, req *CreateUserTokenRequest) (string, error) { - var userID string - userIDVal := ctx.Value("userid") - if userIDVal != nil { - userID = userIDVal.(string) - } - - isAdmin := false - isAdminVal := ctx.Value("admin") - if isAdminVal != nil { - isAdmin = isAdminVal.(bool) - } + isAdmin := common.IsUserAdmin(ctx) + userID := common.CurrentUserID(ctx) userRef := req.UserRef user, resp, err := h.configstoreClient.GetUser(ctx, userRef) @@ -193,11 +185,11 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userRef, rs.Name)) } - accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } - userSource, err := common.GetUserSource(rs, accessToken) + userSource, err := scommon.GetUserSource(rs, accessToken) if err != nil { return nil, err } @@ -271,7 +263,7 @@ func (h *ActionHandler) RefreshLinkedAccount(ctx context.Context, rs *cstypes.Re case cstypes.RemoteSourceAuthTypeOauth2: // refresh access token if expired if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) { - userSource, err := common.GetOauth2Source(rs, "") + userSource, err := scommon.GetOauth2Source(rs, "") if err != nil { return nil, err } @@ -301,7 +293,7 @@ func (h *ActionHandler) GetGitSource(ctx context.Context, rs *cstypes.RemoteSour if err != nil { return nil, err } - return common.GetGitSource(rs, la) + return scommon.GetGitSource(rs, la) } type RegisterUserRequest struct { @@ -329,11 +321,11 @@ func (h *ActionHandler) RegisterUser(ctx context.Context, req *RegisterUserReque return nil, util.NewErrBadRequest(errors.Errorf("remote source user registration is disabled")) } - accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } - userSource, err := common.GetUserSource(rs, accessToken) + userSource, err := scommon.GetUserSource(rs, accessToken) if err != nil { return nil, err } @@ -391,11 +383,11 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (* return nil, util.NewErrBadRequest(errors.Errorf("remote source user login is disabled")) } - accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } - userSource, err := common.GetUserSource(rs, accessToken) + userSource, err := scommon.GetUserSource(rs, accessToken) if err != nil { return nil, err } @@ -451,7 +443,7 @@ func (h *ActionHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (* } // generate jwt token - token, err := common.GenerateLoginJWTToken(h.sd, user.ID) + token, err := scommon.GenerateLoginJWTToken(h.sd, user.ID) if err != nil { return nil, err } @@ -480,11 +472,11 @@ func (h *ActionHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (* return nil, errors.Errorf("failed to get remote source %q: %w", req.RemoteSourceName, ErrFromRemote(resp, err)) } - accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := scommon.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } - userSource, err := common.GetUserSource(rs, accessToken) + userSource, err := scommon.GetUserSource(rs, accessToken) if err != nil { return nil, err } @@ -523,7 +515,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource return nil, errors.Errorf("failed to get user %q: %w", req.UserRef, ErrFromRemote(resp, err)) } - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) // user must be already logged in the create a linked account and can create a // linked account only on itself. @@ -554,11 +546,11 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource switch rs.AuthType { case cstypes.RemoteSourceAuthTypeOauth2: - oauth2Source, err := common.GetOauth2Source(rs, "") + oauth2Source, err := scommon.GetOauth2Source(rs, "") if err != nil { return nil, errors.Errorf("failed to create git source: %w", err) } - token, err := common.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req) + token, err := scommon.GenerateOauth2JWTToken(h.sd, rs.Name, string(requestType), req) if err != nil { return nil, err } @@ -572,7 +564,7 @@ func (h *ActionHandler) HandleRemoteSourceAuth(ctx context.Context, remoteSource }, nil case cstypes.RemoteSourceAuthTypePassword: - passwordSource, err := common.GetPasswordSource(rs, "") + passwordSource, err := scommon.GetPasswordSource(rs, "") if err != nil { return nil, errors.Errorf("failed to create git source: %w", err) } @@ -751,7 +743,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st return nil, errors.Errorf("failed to get remote source %q: %w", remoteSourceName, ErrFromRemote(resp, err)) } - oauth2Source, err := common.GetOauth2Source(rs, "") + oauth2Source, err := scommon.GetOauth2Source(rs, "") if err != nil { return nil, errors.Errorf("failed to create oauth2 source: %w", err) } @@ -765,7 +757,7 @@ func (h *ActionHandler) HandleOauth2Callback(ctx context.Context, code, state st } func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error { - if !h.IsUserAdmin(ctx) { + if !common.IsUserAdmin(ctx) { return errors.Errorf("user not logged in") } @@ -777,12 +769,12 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error { } func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) error { - if !h.IsUserLoggedOrAdmin(ctx) { + if !common.IsUserLoggedOrAdmin(ctx) { return errors.Errorf("user not logged in") } - isAdmin := !h.IsUserAdmin(ctx) - curUserID := h.CurrentUserID(ctx) + isAdmin := common.IsUserAdmin(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, userRef) if err != nil { @@ -802,12 +794,12 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) } func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName string) error { - if !h.IsUserLoggedOrAdmin(ctx) { + if !common.IsUserLoggedOrAdmin(ctx) { return errors.Errorf("user not logged in") } - isAdmin := !h.IsUserAdmin(ctx) - curUserID := h.CurrentUserID(ctx) + isAdmin := common.IsUserAdmin(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, userRef) if err != nil { @@ -849,7 +841,7 @@ func (h *ActionHandler) UserCreateRun(ctx context.Context, req *UserCreateRunReq prRefRegexes = append(prRefRegexes, re) } - curUserID := h.CurrentUserID(ctx) + curUserID := common.CurrentUserID(ctx) user, resp, err := h.configstoreClient.GetUser(ctx, curUserID) if err != nil { diff --git a/internal/services/gateway/api/org.go b/internal/services/gateway/api/org.go index ddb9ec6..45346a3 100644 --- a/internal/services/gateway/api/org.go +++ b/internal/services/gateway/api/org.go @@ -20,6 +20,7 @@ import ( "strconv" "agola.io/agola/internal/services/gateway/action" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" cstypes "agola.io/agola/services/configstore/types" gwapitypes "agola.io/agola/services/gateway/api/types" @@ -41,11 +42,7 @@ func NewCreateOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *CreateOr func (h *CreateOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - var userID string - userIDVal := ctx.Value("userid") - if userIDVal != nil { - userID = userIDVal.(string) - } + userID := common.CurrentUserID(ctx) var req gwapitypes.CreateOrgRequest d := json.NewDecoder(r.Body) diff --git a/internal/services/gateway/api/projectgroup.go b/internal/services/gateway/api/projectgroup.go index 4110283..8b9e194 100644 --- a/internal/services/gateway/api/projectgroup.go +++ b/internal/services/gateway/api/projectgroup.go @@ -20,6 +20,7 @@ import ( "net/url" "agola.io/agola/internal/services/gateway/action" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" csapitypes "agola.io/agola/services/configstore/api/types" cstypes "agola.io/agola/services/configstore/types" @@ -49,12 +50,11 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req return } - userIDVal := ctx.Value("userid") - if userIDVal == nil { + userID := common.CurrentUserID(ctx) + if userID == "" { httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) return } - userID := userIDVal.(string) creq := &action.CreateProjectGroupRequest{ Name: req.Name, diff --git a/internal/services/gateway/api/remoterepo.go b/internal/services/gateway/api/remoterepo.go index a8aa724..9569cca 100644 --- a/internal/services/gateway/api/remoterepo.go +++ b/internal/services/gateway/api/remoterepo.go @@ -19,6 +19,7 @@ import ( gitsource "agola.io/agola/internal/gitsources" "agola.io/agola/internal/services/gateway/action" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" csclient "agola.io/agola/services/configstore/client" cstypes "agola.io/agola/services/configstore/types" @@ -53,12 +54,11 @@ func (h *UserRemoteReposHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques vars := mux.Vars(r) remoteSourceRef := vars["remotesourceref"] - userIDVal := ctx.Value("userid") - if userIDVal == nil { + userID := common.CurrentUserID(ctx) + if userID == "" { httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) return } - userID := userIDVal.(string) user, resp, err := h.configstoreClient.GetUser(ctx, userID) if httpErrorFromRemote(w, resp, err) { diff --git a/internal/services/gateway/api/user.go b/internal/services/gateway/api/user.go index 9016b93..e8c25d8 100644 --- a/internal/services/gateway/api/user.go +++ b/internal/services/gateway/api/user.go @@ -22,6 +22,7 @@ import ( "strconv" "agola.io/agola/internal/services/gateway/action" + "agola.io/agola/internal/services/gateway/common" "agola.io/agola/internal/util" csapitypes "agola.io/agola/services/configstore/api/types" cstypes "agola.io/agola/services/configstore/types" @@ -104,12 +105,11 @@ func NewCurrentUserHandler(logger *zap.Logger, ah *action.ActionHandler) *Curren func (h *CurrentUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userIDVal := ctx.Value("userid") - if userIDVal == nil { + userID := common.CurrentUserID(ctx) + if userID == "" { httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) return } - userID := userIDVal.(string) user, err := h.ah.GetUser(ctx, userID) if httpError(w, err) { @@ -604,14 +604,13 @@ func NewUserOrgsHandler(logger *zap.Logger, ah *action.ActionHandler) *UserOrgsH func (h *UserOrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userIDVal := ctx.Value("userid") - if userIDVal == nil { + userID := common.CurrentUserID(ctx) + if userID == "" { httpError(w, util.NewErrBadRequest(errors.Errorf("user not authenticated"))) return } - userRef := userIDVal.(string) - userOrgs, err := h.ah.GetUserOrgs(ctx, userRef) + userOrgs, err := h.ah.GetUserOrgs(ctx, userID) if httpError(w, err) { h.log.Errorf("err: %+v", err) return diff --git a/internal/services/gateway/common/common.go b/internal/services/gateway/common/common.go new file mode 100644 index 0000000..7933820 --- /dev/null +++ b/internal/services/gateway/common/common.go @@ -0,0 +1,36 @@ +package common + +import "context" + +type ContextKey int + +const ( + ContextKeyUserID ContextKey = iota + ContextKeyUsername + ContextKeyUserAdmin +) + +func CurrentUserID(ctx context.Context) string { + userIDVal := ctx.Value(ContextKeyUserID) + if userIDVal == nil { + return "" + } + return userIDVal.(string) +} + +func IsUserLogged(ctx context.Context) bool { + return ctx.Value(ContextKeyUserID) != nil +} + +func IsUserAdmin(ctx context.Context) bool { + isAdmin := false + isAdminVal := ctx.Value(ContextKeyUserAdmin) + if isAdminVal != nil { + isAdmin = isAdminVal.(bool) + } + return isAdmin +} + +func IsUserLoggedOrAdmin(ctx context.Context) bool { + return IsUserLogged(ctx) || IsUserAdmin(ctx) +} diff --git a/internal/services/gateway/handlers/auth.go b/internal/services/gateway/handlers/auth.go index 1d4c331..354e8a7 100644 --- a/internal/services/gateway/handlers/auth.go +++ b/internal/services/gateway/handlers/auth.go @@ -19,7 +19,8 @@ import ( "net/http" "strings" - "agola.io/agola/internal/services/common" + scommon "agola.io/agola/internal/services/common" + "agola.io/agola/internal/services/gateway/common" csclient "agola.io/agola/services/configstore/client" "github.com/golang-jwt/jwt/v4" @@ -35,12 +36,12 @@ type AuthHandler struct { configstoreClient *csclient.Client adminToken string - sd *common.TokenSigningData + sd *scommon.TokenSigningData required bool } -func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *common.TokenSigningData, required bool) func(http.Handler) http.Handler { +func NewAuthHandler(logger *zap.Logger, configstoreClient *csclient.Client, adminToken string, sd *scommon.TokenSigningData, required bool) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { return &AuthHandler{ log: logger.Sugar(), @@ -59,7 +60,7 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { tokenString, _ := TokenExtractor.ExtractToken(r) if h.adminToken != "" && tokenString != "" { if tokenString == h.adminToken { - ctx = context.WithValue(ctx, "admin", true) + ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true) h.next.ServeHTTP(w, r.WithContext(ctx)) return } else { @@ -74,11 +75,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // pass userid to handlers via context - ctx = context.WithValue(ctx, "userid", user.ID) - ctx = context.WithValue(ctx, "username", user.Name) + ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID) + ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name) if user.Admin { - ctx = context.WithValue(ctx, "admin", true) + ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true) } h.next.ServeHTTP(w, r.WithContext(ctx)) @@ -128,11 +129,11 @@ func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // pass userid and username to handlers via context - ctx = context.WithValue(ctx, "userid", user.ID) - ctx = context.WithValue(ctx, "username", user.Name) + ctx = context.WithValue(ctx, common.ContextKeyUserID, user.ID) + ctx = context.WithValue(ctx, common.ContextKeyUsername, user.Name) if user.Admin { - ctx = context.WithValue(ctx, "admin", true) + ctx = context.WithValue(ctx, common.ContextKeyUserAdmin, true) } h.next.ServeHTTP(w, r.WithContext(ctx))