gateway: refresh oauth2 token when needed

refresh the oauth2 access token when needed and update the related linked
account.
This commit is contained in:
Simone Gotti 2019-04-29 15:42:10 +02:00
parent 9393830207
commit cbd971619b
5 changed files with 105 additions and 17 deletions

View File

@ -19,7 +19,6 @@ import (
"fmt" "fmt"
"path" "path"
"github.com/sorintlab/agola/internal/services/gateway/common"
"github.com/sorintlab/agola/internal/services/types" "github.com/sorintlab/agola/internal/services/types"
"github.com/sorintlab/agola/internal/util" "github.com/sorintlab/agola/internal/util"
@ -63,7 +62,7 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name) return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name)
} }
gitsource, err := common.GetGitSource(rs, la) gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "failed to create gitsource client") return nil, errors.Wrapf(err, "failed to create gitsource client")
} }
@ -105,11 +104,11 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
} }
c.log.Infof("project %s created, ID: %s", p.Name, p.ID) c.log.Infof("project %s created, ID: %s", p.Name, p.ID)
return p, c.SetupProject(ctx, rs, la, p) return p, c.SetupProject(ctx, rs, user, la, p)
} }
func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, la *types.LinkedAccount, project *types.Project) error { func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, user *types.User, la *types.LinkedAccount, project *types.Project) error {
gitsource, err := common.GetGitSource(rs, la) gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to create gitsource client") return errors.Wrapf(err, "failed to create gitsource client")
} }
@ -165,5 +164,5 @@ func (c *CommandHandler) ReconfigProject(ctx context.Context, projectRef string)
// TODO(sgotti) update project repo path if the remote let us query by repository id // TODO(sgotti) update project repo path if the remote let us query by repository id
return c.SetupProject(ctx, rs, la, p) return c.SetupProject(ctx, rs, user, la, p)
} }

View File

@ -29,6 +29,17 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const (
expireTimeRange time.Duration = 5 * time.Minute
)
func isAccessTokenExpired(expiresAt time.Time) bool {
if expiresAt.IsZero() {
return false
}
return expiresAt.Add(-expireTimeRange).Before(time.Now())
}
type CreateUserRequest struct { type CreateUserRequest struct {
UserName string UserName string
} }
@ -132,7 +143,7 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name)) return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name))
} }
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -169,6 +180,81 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
return la, nil return la, nil
} }
func (c *CommandHandler) UpdateUserLA(ctx context.Context, userName string, la *types.LinkedAccount) error {
user, resp, err := c.configstoreClient.GetUserByName(ctx, userName)
if err != nil {
return ErrFromRemote(resp, errors.Wrapf(err, "failed to get user %q", userName))
}
laFound := false
for _, ula := range user.LinkedAccounts {
if ula.ID == la.ID {
laFound = true
break
}
}
c.log.Infof("la: %s", util.Dump(la))
if !laFound {
return util.NewErrBadRequest(errors.Errorf("user %q doesn't have a linked account with id %q", userName, la.ID))
}
creq := &csapi.UpdateUserLARequest{
RemoteUserID: la.RemoteUserID,
RemoteUserName: la.RemoteUserName,
UserAccessToken: la.UserAccessToken,
Oauth2AccessToken: la.Oauth2AccessToken,
Oauth2RefreshToken: la.Oauth2RefreshToken,
Oauth2AccessTokenExpiresAt: la.Oauth2AccessTokenExpiresAt,
}
c.log.Infof("updating user %q linked account", userName)
la, resp, err = c.configstoreClient.UpdateUserLA(ctx, userName, la.ID, creq)
if err != nil {
return ErrFromRemote(resp, errors.Wrapf(err, "failed to update user"))
}
c.log.Infof("linked account %q for user %q updated", la.ID, userName)
return nil
}
// RefreshLinkedAccount refreshed the linked account oauth2 access token and update linked account in the configstore
func (c *CommandHandler) RefreshLinkedAccount(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (*types.LinkedAccount, error) {
switch rs.AuthType {
case types.RemoteSourceAuthTypeOauth2:
// refresh access token if expired
if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) {
userSource, err := common.GetOauth2Source(rs, "")
if err != nil {
return nil, err
}
token, err := userSource.RefreshOauth2Token(la.Oauth2RefreshToken)
if err != nil {
return nil, err
}
if la.Oauth2AccessToken != token.AccessToken {
la.Oauth2AccessToken = token.AccessToken
la.Oauth2RefreshToken = token.RefreshToken
la.Oauth2AccessTokenExpiresAt = token.Expiry
if err := c.UpdateUserLA(ctx, userName, la); err != nil {
return nil, errors.Wrapf(err, "failed to update linked account")
}
}
}
}
return la, nil
}
// GetGitSource is a wrapper around common.GetGitSource that will also refresh
// the oauth2 access token and update the linked account when needed
func (c *CommandHandler) GetGitSource(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (gitsource.GitSource, error) {
la, err := c.RefreshLinkedAccount(ctx, rs, userName, la)
if err != nil {
return nil, err
}
return common.GetGitSource(rs, la)
}
type RegisterUserRequest struct { type RegisterUserRequest struct {
UserName string UserName string
RemoteSourceName string RemoteSourceName string
@ -192,7 +278,7 @@ func (c *CommandHandler) RegisterUser(ctx context.Context, req *RegisterUserRequ
} }
c.log.Infof("rs: %s", util.Dump(rs)) c.log.Infof("rs: %s", util.Dump(rs))
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -252,7 +338,7 @@ func (c *CommandHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (
} }
c.log.Infof("rs: %s", util.Dump(rs)) c.log.Infof("rs: %s", util.Dump(rs))
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -343,7 +429,7 @@ func (c *CommandHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (
} }
c.log.Infof("rs: %s", util.Dump(rs)) c.log.Infof("rs: %s", util.Dump(rs))
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken) accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -17,11 +17,12 @@ package common
import ( import (
"fmt" "fmt"
"github.com/pkg/errors"
gitsource "github.com/sorintlab/agola/internal/gitsources" gitsource "github.com/sorintlab/agola/internal/gitsources"
"github.com/sorintlab/agola/internal/gitsources/gitea" "github.com/sorintlab/agola/internal/gitsources/gitea"
"github.com/sorintlab/agola/internal/gitsources/gitlab" "github.com/sorintlab/agola/internal/gitsources/gitlab"
"github.com/sorintlab/agola/internal/services/types" "github.com/sorintlab/agola/internal/services/types"
"github.com/pkg/errors"
) )
func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType { func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType {
@ -66,14 +67,14 @@ func newGitlab(rs *types.RemoteSource, accessToken string) (*gitlab.Client, erro
}) })
} }
func GetAccessToken(authType types.RemoteSourceAuthType, userAccessToken, oauth2AccessToken string) (string, error) { func GetAccessToken(rs *types.RemoteSource, userAccessToken, oauth2AccessToken string) (string, error) {
switch authType { switch rs.AuthType {
case types.RemoteSourceAuthTypePassword: case types.RemoteSourceAuthTypePassword:
return userAccessToken, nil return userAccessToken, nil
case types.RemoteSourceAuthTypeOauth2: case types.RemoteSourceAuthTypeOauth2:
return oauth2AccessToken, nil return oauth2AccessToken, nil
default: default:
return "", errors.Errorf("invalid remote source auth type %q", authType) return "", errors.Errorf("invalid remote source auth type %q", rs.AuthType)
} }
} }
@ -81,7 +82,7 @@ func GetGitSource(rs *types.RemoteSource, la *types.LinkedAccount) (gitsource.Gi
var accessToken string var accessToken string
if la != nil { if la != nil {
var err error var err error
accessToken, err = GetAccessToken(rs.AuthType, la.UserAccessToken, la.Oauth2AccessToken) accessToken, err = GetAccessToken(rs, la.UserAccessToken, la.Oauth2AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -145,7 +145,7 @@ func (g *Gateway) Run(ctx context.Context) error {
corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"}) corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"})
corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions)
webhooksHandler := &webhooksHandler{log: log, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL} webhooksHandler := &webhooksHandler{log: log, ch: g.ch, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL}
projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient) projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient)
projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient) projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient)

View File

@ -26,6 +26,7 @@ import (
"github.com/sorintlab/agola/internal/gitsources/agolagit" "github.com/sorintlab/agola/internal/gitsources/agolagit"
"github.com/sorintlab/agola/internal/runconfig" "github.com/sorintlab/agola/internal/runconfig"
csapi "github.com/sorintlab/agola/internal/services/configstore/api" csapi "github.com/sorintlab/agola/internal/services/configstore/api"
"github.com/sorintlab/agola/internal/services/gateway/command"
"github.com/sorintlab/agola/internal/services/gateway/common" "github.com/sorintlab/agola/internal/services/gateway/common"
rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api" rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api"
rstypes "github.com/sorintlab/agola/internal/services/runservice/types" rstypes "github.com/sorintlab/agola/internal/services/runservice/types"
@ -107,6 +108,7 @@ func genGroup(baseGroupType GroupType, baseGroupID string, webhookData *types.We
type webhooksHandler struct { type webhooksHandler struct {
log *zap.SugaredLogger log *zap.SugaredLogger
ch *command.CommandHandler
configstoreClient *csapi.Client configstoreClient *csapi.Client
runserviceClient *rsapi.Client runserviceClient *rsapi.Client
apiExposedURL string apiExposedURL string
@ -165,7 +167,7 @@ func (h *webhooksHandler) handleWebhook(r *http.Request) (int, string, error) {
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID) return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID)
} }
gitSource, err = common.GetGitSource(rs, la) gitSource, err = h.ch.GetGitSource(ctx, rs, user.UserName, la)
if err != nil { if err != nil {
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client") return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client")
} }