diff --git a/internal/services/gateway/command/project.go b/internal/services/gateway/command/project.go index 337051c..6ce5455 100644 --- a/internal/services/gateway/command/project.go +++ b/internal/services/gateway/command/project.go @@ -19,7 +19,6 @@ import ( "fmt" "path" - "github.com/sorintlab/agola/internal/services/gateway/common" "github.com/sorintlab/agola/internal/services/types" "github.com/sorintlab/agola/internal/util" @@ -63,7 +62,7 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name) } - gitsource, err := common.GetGitSource(rs, la) + gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la) if err != nil { return nil, errors.Wrapf(err, "failed to create gitsource client") } @@ -105,11 +104,11 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe } c.log.Infof("project %s created, ID: %s", p.Name, p.ID) - return p, c.SetupProject(ctx, rs, la, p) + return p, c.SetupProject(ctx, rs, user, la, p) } -func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, la *types.LinkedAccount, project *types.Project) error { - gitsource, err := common.GetGitSource(rs, la) +func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, user *types.User, la *types.LinkedAccount, project *types.Project) error { + gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la) if err != nil { return errors.Wrapf(err, "failed to create gitsource client") } @@ -165,5 +164,5 @@ func (c *CommandHandler) ReconfigProject(ctx context.Context, projectRef string) // TODO(sgotti) update project repo path if the remote let us query by repository id - return c.SetupProject(ctx, rs, la, p) + return c.SetupProject(ctx, rs, user, la, p) } diff --git a/internal/services/gateway/command/user.go b/internal/services/gateway/command/user.go index d1a3e3a..483e701 100644 --- a/internal/services/gateway/command/user.go +++ b/internal/services/gateway/command/user.go @@ -29,6 +29,17 @@ import ( "github.com/pkg/errors" ) +const ( + expireTimeRange time.Duration = 5 * time.Minute +) + +func isAccessTokenExpired(expiresAt time.Time) bool { + if expiresAt.IsZero() { + return false + } + return expiresAt.Add(-expireTimeRange).Before(time.Now()) +} + type CreateUserRequest struct { UserName string } @@ -132,7 +143,7 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name)) } - accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } @@ -169,6 +180,81 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ return la, nil } +func (c *CommandHandler) UpdateUserLA(ctx context.Context, userName string, la *types.LinkedAccount) error { + user, resp, err := c.configstoreClient.GetUserByName(ctx, userName) + if err != nil { + return ErrFromRemote(resp, errors.Wrapf(err, "failed to get user %q", userName)) + } + laFound := false + for _, ula := range user.LinkedAccounts { + if ula.ID == la.ID { + laFound = true + break + } + } + c.log.Infof("la: %s", util.Dump(la)) + if !laFound { + return util.NewErrBadRequest(errors.Errorf("user %q doesn't have a linked account with id %q", userName, la.ID)) + } + + creq := &csapi.UpdateUserLARequest{ + RemoteUserID: la.RemoteUserID, + RemoteUserName: la.RemoteUserName, + UserAccessToken: la.UserAccessToken, + Oauth2AccessToken: la.Oauth2AccessToken, + Oauth2RefreshToken: la.Oauth2RefreshToken, + Oauth2AccessTokenExpiresAt: la.Oauth2AccessTokenExpiresAt, + } + + c.log.Infof("updating user %q linked account", userName) + la, resp, err = c.configstoreClient.UpdateUserLA(ctx, userName, la.ID, creq) + if err != nil { + return ErrFromRemote(resp, errors.Wrapf(err, "failed to update user")) + } + c.log.Infof("linked account %q for user %q updated", la.ID, userName) + + return nil +} + +// RefreshLinkedAccount refreshed the linked account oauth2 access token and update linked account in the configstore +func (c *CommandHandler) RefreshLinkedAccount(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (*types.LinkedAccount, error) { + switch rs.AuthType { + case types.RemoteSourceAuthTypeOauth2: + // refresh access token if expired + if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) { + userSource, err := common.GetOauth2Source(rs, "") + if err != nil { + return nil, err + } + token, err := userSource.RefreshOauth2Token(la.Oauth2RefreshToken) + if err != nil { + return nil, err + } + + if la.Oauth2AccessToken != token.AccessToken { + la.Oauth2AccessToken = token.AccessToken + la.Oauth2RefreshToken = token.RefreshToken + la.Oauth2AccessTokenExpiresAt = token.Expiry + + if err := c.UpdateUserLA(ctx, userName, la); err != nil { + return nil, errors.Wrapf(err, "failed to update linked account") + } + } + } + } + return la, nil +} + +// GetGitSource is a wrapper around common.GetGitSource that will also refresh +// the oauth2 access token and update the linked account when needed +func (c *CommandHandler) GetGitSource(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (gitsource.GitSource, error) { + la, err := c.RefreshLinkedAccount(ctx, rs, userName, la) + if err != nil { + return nil, err + } + return common.GetGitSource(rs, la) +} + type RegisterUserRequest struct { UserName string RemoteSourceName string @@ -192,7 +278,7 @@ func (c *CommandHandler) RegisterUser(ctx context.Context, req *RegisterUserRequ } c.log.Infof("rs: %s", util.Dump(rs)) - accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } @@ -252,7 +338,7 @@ func (c *CommandHandler) LoginUser(ctx context.Context, req *LoginUserRequest) ( } c.log.Infof("rs: %s", util.Dump(rs)) - accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } @@ -343,7 +429,7 @@ func (c *CommandHandler) Authorize(ctx context.Context, req *AuthorizeRequest) ( } c.log.Infof("rs: %s", util.Dump(rs)) - accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) + accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken) if err != nil { return nil, err } diff --git a/internal/services/gateway/common/gitsource.go b/internal/services/gateway/common/gitsource.go index f334cd7..0026970 100644 --- a/internal/services/gateway/common/gitsource.go +++ b/internal/services/gateway/common/gitsource.go @@ -17,11 +17,12 @@ package common import ( "fmt" - "github.com/pkg/errors" gitsource "github.com/sorintlab/agola/internal/gitsources" "github.com/sorintlab/agola/internal/gitsources/gitea" "github.com/sorintlab/agola/internal/gitsources/gitlab" "github.com/sorintlab/agola/internal/services/types" + + "github.com/pkg/errors" ) func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType { @@ -66,14 +67,14 @@ func newGitlab(rs *types.RemoteSource, accessToken string) (*gitlab.Client, erro }) } -func GetAccessToken(authType types.RemoteSourceAuthType, userAccessToken, oauth2AccessToken string) (string, error) { - switch authType { +func GetAccessToken(rs *types.RemoteSource, userAccessToken, oauth2AccessToken string) (string, error) { + switch rs.AuthType { case types.RemoteSourceAuthTypePassword: return userAccessToken, nil case types.RemoteSourceAuthTypeOauth2: return oauth2AccessToken, nil default: - return "", errors.Errorf("invalid remote source auth type %q", authType) + return "", errors.Errorf("invalid remote source auth type %q", rs.AuthType) } } @@ -81,7 +82,7 @@ func GetGitSource(rs *types.RemoteSource, la *types.LinkedAccount) (gitsource.Gi var accessToken string if la != nil { var err error - accessToken, err = GetAccessToken(rs.AuthType, la.UserAccessToken, la.Oauth2AccessToken) + accessToken, err = GetAccessToken(rs, la.UserAccessToken, la.Oauth2AccessToken) if err != nil { return nil, err } diff --git a/internal/services/gateway/gateway.go b/internal/services/gateway/gateway.go index f7c3140..c10b694 100644 --- a/internal/services/gateway/gateway.go +++ b/internal/services/gateway/gateway.go @@ -145,7 +145,7 @@ func (g *Gateway) Run(ctx context.Context) error { corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"}) corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) - webhooksHandler := &webhooksHandler{log: log, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL} + webhooksHandler := &webhooksHandler{log: log, ch: g.ch, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL} projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient) projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient) diff --git a/internal/services/gateway/webhook.go b/internal/services/gateway/webhook.go index 048d145..5a7a851 100644 --- a/internal/services/gateway/webhook.go +++ b/internal/services/gateway/webhook.go @@ -26,6 +26,7 @@ import ( "github.com/sorintlab/agola/internal/gitsources/agolagit" "github.com/sorintlab/agola/internal/runconfig" csapi "github.com/sorintlab/agola/internal/services/configstore/api" + "github.com/sorintlab/agola/internal/services/gateway/command" "github.com/sorintlab/agola/internal/services/gateway/common" rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api" rstypes "github.com/sorintlab/agola/internal/services/runservice/types" @@ -107,6 +108,7 @@ func genGroup(baseGroupType GroupType, baseGroupID string, webhookData *types.We type webhooksHandler struct { log *zap.SugaredLogger + ch *command.CommandHandler configstoreClient *csapi.Client runserviceClient *rsapi.Client apiExposedURL string @@ -165,7 +167,7 @@ func (h *webhooksHandler) handleWebhook(r *http.Request) (int, string, error) { return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID) } - gitSource, err = common.GetGitSource(rs, la) + gitSource, err = h.ch.GetGitSource(ctx, rs, user.UserName, la) if err != nil { return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client") }