gateway: refresh oauth2 token when needed
refresh the oauth2 access token when needed and update the related linked account.
This commit is contained in:
parent
9393830207
commit
cbd971619b
|
@ -19,7 +19,6 @@ import (
|
|||
"fmt"
|
||||
"path"
|
||||
|
||||
"github.com/sorintlab/agola/internal/services/gateway/common"
|
||||
"github.com/sorintlab/agola/internal/services/types"
|
||||
"github.com/sorintlab/agola/internal/util"
|
||||
|
||||
|
@ -63,7 +62,7 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
|
|||
return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name)
|
||||
}
|
||||
|
||||
gitsource, err := common.GetGitSource(rs, la)
|
||||
gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to create gitsource client")
|
||||
}
|
||||
|
@ -105,11 +104,11 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
|
|||
}
|
||||
c.log.Infof("project %s created, ID: %s", p.Name, p.ID)
|
||||
|
||||
return p, c.SetupProject(ctx, rs, la, p)
|
||||
return p, c.SetupProject(ctx, rs, user, la, p)
|
||||
}
|
||||
|
||||
func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, la *types.LinkedAccount, project *types.Project) error {
|
||||
gitsource, err := common.GetGitSource(rs, la)
|
||||
func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, user *types.User, la *types.LinkedAccount, project *types.Project) error {
|
||||
gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create gitsource client")
|
||||
}
|
||||
|
@ -165,5 +164,5 @@ func (c *CommandHandler) ReconfigProject(ctx context.Context, projectRef string)
|
|||
|
||||
// TODO(sgotti) update project repo path if the remote let us query by repository id
|
||||
|
||||
return c.SetupProject(ctx, rs, la, p)
|
||||
return c.SetupProject(ctx, rs, user, la, p)
|
||||
}
|
||||
|
|
|
@ -29,6 +29,17 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
expireTimeRange time.Duration = 5 * time.Minute
|
||||
)
|
||||
|
||||
func isAccessTokenExpired(expiresAt time.Time) bool {
|
||||
if expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return expiresAt.Add(-expireTimeRange).Before(time.Now())
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
UserName string
|
||||
}
|
||||
|
@ -132,7 +143,7 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
|
|||
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name))
|
||||
}
|
||||
|
||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -169,6 +180,81 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
|
|||
return la, nil
|
||||
}
|
||||
|
||||
func (c *CommandHandler) UpdateUserLA(ctx context.Context, userName string, la *types.LinkedAccount) error {
|
||||
user, resp, err := c.configstoreClient.GetUserByName(ctx, userName)
|
||||
if err != nil {
|
||||
return ErrFromRemote(resp, errors.Wrapf(err, "failed to get user %q", userName))
|
||||
}
|
||||
laFound := false
|
||||
for _, ula := range user.LinkedAccounts {
|
||||
if ula.ID == la.ID {
|
||||
laFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
c.log.Infof("la: %s", util.Dump(la))
|
||||
if !laFound {
|
||||
return util.NewErrBadRequest(errors.Errorf("user %q doesn't have a linked account with id %q", userName, la.ID))
|
||||
}
|
||||
|
||||
creq := &csapi.UpdateUserLARequest{
|
||||
RemoteUserID: la.RemoteUserID,
|
||||
RemoteUserName: la.RemoteUserName,
|
||||
UserAccessToken: la.UserAccessToken,
|
||||
Oauth2AccessToken: la.Oauth2AccessToken,
|
||||
Oauth2RefreshToken: la.Oauth2RefreshToken,
|
||||
Oauth2AccessTokenExpiresAt: la.Oauth2AccessTokenExpiresAt,
|
||||
}
|
||||
|
||||
c.log.Infof("updating user %q linked account", userName)
|
||||
la, resp, err = c.configstoreClient.UpdateUserLA(ctx, userName, la.ID, creq)
|
||||
if err != nil {
|
||||
return ErrFromRemote(resp, errors.Wrapf(err, "failed to update user"))
|
||||
}
|
||||
c.log.Infof("linked account %q for user %q updated", la.ID, userName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshLinkedAccount refreshed the linked account oauth2 access token and update linked account in the configstore
|
||||
func (c *CommandHandler) RefreshLinkedAccount(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (*types.LinkedAccount, error) {
|
||||
switch rs.AuthType {
|
||||
case types.RemoteSourceAuthTypeOauth2:
|
||||
// refresh access token if expired
|
||||
if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) {
|
||||
userSource, err := common.GetOauth2Source(rs, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := userSource.RefreshOauth2Token(la.Oauth2RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if la.Oauth2AccessToken != token.AccessToken {
|
||||
la.Oauth2AccessToken = token.AccessToken
|
||||
la.Oauth2RefreshToken = token.RefreshToken
|
||||
la.Oauth2AccessTokenExpiresAt = token.Expiry
|
||||
|
||||
if err := c.UpdateUserLA(ctx, userName, la); err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to update linked account")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return la, nil
|
||||
}
|
||||
|
||||
// GetGitSource is a wrapper around common.GetGitSource that will also refresh
|
||||
// the oauth2 access token and update the linked account when needed
|
||||
func (c *CommandHandler) GetGitSource(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (gitsource.GitSource, error) {
|
||||
la, err := c.RefreshLinkedAccount(ctx, rs, userName, la)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return common.GetGitSource(rs, la)
|
||||
}
|
||||
|
||||
type RegisterUserRequest struct {
|
||||
UserName string
|
||||
RemoteSourceName string
|
||||
|
@ -192,7 +278,7 @@ func (c *CommandHandler) RegisterUser(ctx context.Context, req *RegisterUserRequ
|
|||
}
|
||||
c.log.Infof("rs: %s", util.Dump(rs))
|
||||
|
||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -252,7 +338,7 @@ func (c *CommandHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (
|
|||
}
|
||||
c.log.Infof("rs: %s", util.Dump(rs))
|
||||
|
||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -343,7 +429,7 @@ func (c *CommandHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (
|
|||
}
|
||||
c.log.Infof("rs: %s", util.Dump(rs))
|
||||
|
||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -17,11 +17,12 @@ package common
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
gitsource "github.com/sorintlab/agola/internal/gitsources"
|
||||
"github.com/sorintlab/agola/internal/gitsources/gitea"
|
||||
"github.com/sorintlab/agola/internal/gitsources/gitlab"
|
||||
"github.com/sorintlab/agola/internal/services/types"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType {
|
||||
|
@ -66,14 +67,14 @@ func newGitlab(rs *types.RemoteSource, accessToken string) (*gitlab.Client, erro
|
|||
})
|
||||
}
|
||||
|
||||
func GetAccessToken(authType types.RemoteSourceAuthType, userAccessToken, oauth2AccessToken string) (string, error) {
|
||||
switch authType {
|
||||
func GetAccessToken(rs *types.RemoteSource, userAccessToken, oauth2AccessToken string) (string, error) {
|
||||
switch rs.AuthType {
|
||||
case types.RemoteSourceAuthTypePassword:
|
||||
return userAccessToken, nil
|
||||
case types.RemoteSourceAuthTypeOauth2:
|
||||
return oauth2AccessToken, nil
|
||||
default:
|
||||
return "", errors.Errorf("invalid remote source auth type %q", authType)
|
||||
return "", errors.Errorf("invalid remote source auth type %q", rs.AuthType)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,7 +82,7 @@ func GetGitSource(rs *types.RemoteSource, la *types.LinkedAccount) (gitsource.Gi
|
|||
var accessToken string
|
||||
if la != nil {
|
||||
var err error
|
||||
accessToken, err = GetAccessToken(rs.AuthType, la.UserAccessToken, la.Oauth2AccessToken)
|
||||
accessToken, err = GetAccessToken(rs, la.UserAccessToken, la.Oauth2AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ func (g *Gateway) Run(ctx context.Context) error {
|
|||
corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"})
|
||||
corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions)
|
||||
|
||||
webhooksHandler := &webhooksHandler{log: log, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL}
|
||||
webhooksHandler := &webhooksHandler{log: log, ch: g.ch, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL}
|
||||
|
||||
projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient)
|
||||
projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient)
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/sorintlab/agola/internal/gitsources/agolagit"
|
||||
"github.com/sorintlab/agola/internal/runconfig"
|
||||
csapi "github.com/sorintlab/agola/internal/services/configstore/api"
|
||||
"github.com/sorintlab/agola/internal/services/gateway/command"
|
||||
"github.com/sorintlab/agola/internal/services/gateway/common"
|
||||
rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api"
|
||||
rstypes "github.com/sorintlab/agola/internal/services/runservice/types"
|
||||
|
@ -107,6 +108,7 @@ func genGroup(baseGroupType GroupType, baseGroupID string, webhookData *types.We
|
|||
|
||||
type webhooksHandler struct {
|
||||
log *zap.SugaredLogger
|
||||
ch *command.CommandHandler
|
||||
configstoreClient *csapi.Client
|
||||
runserviceClient *rsapi.Client
|
||||
apiExposedURL string
|
||||
|
@ -165,7 +167,7 @@ func (h *webhooksHandler) handleWebhook(r *http.Request) (int, string, error) {
|
|||
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID)
|
||||
}
|
||||
|
||||
gitSource, err = common.GetGitSource(rs, la)
|
||||
gitSource, err = h.ch.GetGitSource(ctx, rs, user.UserName, la)
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue