gateway: refresh oauth2 token when needed
refresh the oauth2 access token when needed and update the related linked account.
This commit is contained in:
parent
9393830207
commit
cbd971619b
|
@ -19,7 +19,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/sorintlab/agola/internal/services/gateway/common"
|
|
||||||
"github.com/sorintlab/agola/internal/services/types"
|
"github.com/sorintlab/agola/internal/services/types"
|
||||||
"github.com/sorintlab/agola/internal/util"
|
"github.com/sorintlab/agola/internal/util"
|
||||||
|
|
||||||
|
@ -63,7 +62,7 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
|
||||||
return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name)
|
return nil, errors.Errorf("user doesn't have a linked account for remote source %q", rs.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
gitsource, err := common.GetGitSource(rs, la)
|
gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "failed to create gitsource client")
|
return nil, errors.Wrapf(err, "failed to create gitsource client")
|
||||||
}
|
}
|
||||||
|
@ -105,11 +104,11 @@ func (c *CommandHandler) CreateProject(ctx context.Context, req *CreateProjectRe
|
||||||
}
|
}
|
||||||
c.log.Infof("project %s created, ID: %s", p.Name, p.ID)
|
c.log.Infof("project %s created, ID: %s", p.Name, p.ID)
|
||||||
|
|
||||||
return p, c.SetupProject(ctx, rs, la, p)
|
return p, c.SetupProject(ctx, rs, user, la, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, la *types.LinkedAccount, project *types.Project) error {
|
func (c *CommandHandler) SetupProject(ctx context.Context, rs *types.RemoteSource, user *types.User, la *types.LinkedAccount, project *types.Project) error {
|
||||||
gitsource, err := common.GetGitSource(rs, la)
|
gitsource, err := c.GetGitSource(ctx, rs, user.UserName, la)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "failed to create gitsource client")
|
return errors.Wrapf(err, "failed to create gitsource client")
|
||||||
}
|
}
|
||||||
|
@ -165,5 +164,5 @@ func (c *CommandHandler) ReconfigProject(ctx context.Context, projectRef string)
|
||||||
|
|
||||||
// TODO(sgotti) update project repo path if the remote let us query by repository id
|
// TODO(sgotti) update project repo path if the remote let us query by repository id
|
||||||
|
|
||||||
return c.SetupProject(ctx, rs, la, p)
|
return c.SetupProject(ctx, rs, user, la, p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,17 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
expireTimeRange time.Duration = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
func isAccessTokenExpired(expiresAt time.Time) bool {
|
||||||
|
if expiresAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return expiresAt.Add(-expireTimeRange).Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
type CreateUserRequest struct {
|
type CreateUserRequest struct {
|
||||||
UserName string
|
UserName string
|
||||||
}
|
}
|
||||||
|
@ -132,7 +143,7 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
|
||||||
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name))
|
return nil, util.NewErrBadRequest(errors.Errorf("user %q already have a linked account for remote source %q", userName, rs.Name))
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -169,6 +180,81 @@ func (c *CommandHandler) CreateUserLA(ctx context.Context, req *CreateUserLARequ
|
||||||
return la, nil
|
return la, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CommandHandler) UpdateUserLA(ctx context.Context, userName string, la *types.LinkedAccount) error {
|
||||||
|
user, resp, err := c.configstoreClient.GetUserByName(ctx, userName)
|
||||||
|
if err != nil {
|
||||||
|
return ErrFromRemote(resp, errors.Wrapf(err, "failed to get user %q", userName))
|
||||||
|
}
|
||||||
|
laFound := false
|
||||||
|
for _, ula := range user.LinkedAccounts {
|
||||||
|
if ula.ID == la.ID {
|
||||||
|
laFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.log.Infof("la: %s", util.Dump(la))
|
||||||
|
if !laFound {
|
||||||
|
return util.NewErrBadRequest(errors.Errorf("user %q doesn't have a linked account with id %q", userName, la.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
creq := &csapi.UpdateUserLARequest{
|
||||||
|
RemoteUserID: la.RemoteUserID,
|
||||||
|
RemoteUserName: la.RemoteUserName,
|
||||||
|
UserAccessToken: la.UserAccessToken,
|
||||||
|
Oauth2AccessToken: la.Oauth2AccessToken,
|
||||||
|
Oauth2RefreshToken: la.Oauth2RefreshToken,
|
||||||
|
Oauth2AccessTokenExpiresAt: la.Oauth2AccessTokenExpiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Infof("updating user %q linked account", userName)
|
||||||
|
la, resp, err = c.configstoreClient.UpdateUserLA(ctx, userName, la.ID, creq)
|
||||||
|
if err != nil {
|
||||||
|
return ErrFromRemote(resp, errors.Wrapf(err, "failed to update user"))
|
||||||
|
}
|
||||||
|
c.log.Infof("linked account %q for user %q updated", la.ID, userName)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLinkedAccount refreshed the linked account oauth2 access token and update linked account in the configstore
|
||||||
|
func (c *CommandHandler) RefreshLinkedAccount(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (*types.LinkedAccount, error) {
|
||||||
|
switch rs.AuthType {
|
||||||
|
case types.RemoteSourceAuthTypeOauth2:
|
||||||
|
// refresh access token if expired
|
||||||
|
if isAccessTokenExpired(la.Oauth2AccessTokenExpiresAt) {
|
||||||
|
userSource, err := common.GetOauth2Source(rs, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token, err := userSource.RefreshOauth2Token(la.Oauth2RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if la.Oauth2AccessToken != token.AccessToken {
|
||||||
|
la.Oauth2AccessToken = token.AccessToken
|
||||||
|
la.Oauth2RefreshToken = token.RefreshToken
|
||||||
|
la.Oauth2AccessTokenExpiresAt = token.Expiry
|
||||||
|
|
||||||
|
if err := c.UpdateUserLA(ctx, userName, la); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "failed to update linked account")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return la, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGitSource is a wrapper around common.GetGitSource that will also refresh
|
||||||
|
// the oauth2 access token and update the linked account when needed
|
||||||
|
func (c *CommandHandler) GetGitSource(ctx context.Context, rs *types.RemoteSource, userName string, la *types.LinkedAccount) (gitsource.GitSource, error) {
|
||||||
|
la, err := c.RefreshLinkedAccount(ctx, rs, userName, la)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return common.GetGitSource(rs, la)
|
||||||
|
}
|
||||||
|
|
||||||
type RegisterUserRequest struct {
|
type RegisterUserRequest struct {
|
||||||
UserName string
|
UserName string
|
||||||
RemoteSourceName string
|
RemoteSourceName string
|
||||||
|
@ -192,7 +278,7 @@ func (c *CommandHandler) RegisterUser(ctx context.Context, req *RegisterUserRequ
|
||||||
}
|
}
|
||||||
c.log.Infof("rs: %s", util.Dump(rs))
|
c.log.Infof("rs: %s", util.Dump(rs))
|
||||||
|
|
||||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -252,7 +338,7 @@ func (c *CommandHandler) LoginUser(ctx context.Context, req *LoginUserRequest) (
|
||||||
}
|
}
|
||||||
c.log.Infof("rs: %s", util.Dump(rs))
|
c.log.Infof("rs: %s", util.Dump(rs))
|
||||||
|
|
||||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -343,7 +429,7 @@ func (c *CommandHandler) Authorize(ctx context.Context, req *AuthorizeRequest) (
|
||||||
}
|
}
|
||||||
c.log.Infof("rs: %s", util.Dump(rs))
|
c.log.Infof("rs: %s", util.Dump(rs))
|
||||||
|
|
||||||
accessToken, err := common.GetAccessToken(rs.AuthType, req.UserAccessToken, req.Oauth2AccessToken)
|
accessToken, err := common.GetAccessToken(rs, req.UserAccessToken, req.Oauth2AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,12 @@ package common
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
gitsource "github.com/sorintlab/agola/internal/gitsources"
|
gitsource "github.com/sorintlab/agola/internal/gitsources"
|
||||||
"github.com/sorintlab/agola/internal/gitsources/gitea"
|
"github.com/sorintlab/agola/internal/gitsources/gitea"
|
||||||
"github.com/sorintlab/agola/internal/gitsources/gitlab"
|
"github.com/sorintlab/agola/internal/gitsources/gitlab"
|
||||||
"github.com/sorintlab/agola/internal/services/types"
|
"github.com/sorintlab/agola/internal/services/types"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType {
|
func SourceSupportedAuthTypes(rsType types.RemoteSourceType) []types.RemoteSourceAuthType {
|
||||||
|
@ -66,14 +67,14 @@ func newGitlab(rs *types.RemoteSource, accessToken string) (*gitlab.Client, erro
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccessToken(authType types.RemoteSourceAuthType, userAccessToken, oauth2AccessToken string) (string, error) {
|
func GetAccessToken(rs *types.RemoteSource, userAccessToken, oauth2AccessToken string) (string, error) {
|
||||||
switch authType {
|
switch rs.AuthType {
|
||||||
case types.RemoteSourceAuthTypePassword:
|
case types.RemoteSourceAuthTypePassword:
|
||||||
return userAccessToken, nil
|
return userAccessToken, nil
|
||||||
case types.RemoteSourceAuthTypeOauth2:
|
case types.RemoteSourceAuthTypeOauth2:
|
||||||
return oauth2AccessToken, nil
|
return oauth2AccessToken, nil
|
||||||
default:
|
default:
|
||||||
return "", errors.Errorf("invalid remote source auth type %q", authType)
|
return "", errors.Errorf("invalid remote source auth type %q", rs.AuthType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +82,7 @@ func GetGitSource(rs *types.RemoteSource, la *types.LinkedAccount) (gitsource.Gi
|
||||||
var accessToken string
|
var accessToken string
|
||||||
if la != nil {
|
if la != nil {
|
||||||
var err error
|
var err error
|
||||||
accessToken, err = GetAccessToken(rs.AuthType, la.UserAccessToken, la.Oauth2AccessToken)
|
accessToken, err = GetAccessToken(rs, la.UserAccessToken, la.Oauth2AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,7 +145,7 @@ func (g *Gateway) Run(ctx context.Context) error {
|
||||||
corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"})
|
corsAllowedOriginsOptions := ghandlers.AllowedOrigins([]string{"*"})
|
||||||
corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions)
|
corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions)
|
||||||
|
|
||||||
webhooksHandler := &webhooksHandler{log: log, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL}
|
webhooksHandler := &webhooksHandler{log: log, ch: g.ch, configstoreClient: g.configstoreClient, runserviceClient: g.runserviceClient, apiExposedURL: g.c.APIExposedURL}
|
||||||
|
|
||||||
projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient)
|
projectGroupHandler := api.NewProjectGroupHandler(logger, g.configstoreClient)
|
||||||
projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient)
|
projectGroupSubgroupsHandler := api.NewProjectGroupSubgroupsHandler(logger, g.configstoreClient)
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"github.com/sorintlab/agola/internal/gitsources/agolagit"
|
"github.com/sorintlab/agola/internal/gitsources/agolagit"
|
||||||
"github.com/sorintlab/agola/internal/runconfig"
|
"github.com/sorintlab/agola/internal/runconfig"
|
||||||
csapi "github.com/sorintlab/agola/internal/services/configstore/api"
|
csapi "github.com/sorintlab/agola/internal/services/configstore/api"
|
||||||
|
"github.com/sorintlab/agola/internal/services/gateway/command"
|
||||||
"github.com/sorintlab/agola/internal/services/gateway/common"
|
"github.com/sorintlab/agola/internal/services/gateway/common"
|
||||||
rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api"
|
rsapi "github.com/sorintlab/agola/internal/services/runservice/scheduler/api"
|
||||||
rstypes "github.com/sorintlab/agola/internal/services/runservice/types"
|
rstypes "github.com/sorintlab/agola/internal/services/runservice/types"
|
||||||
|
@ -107,6 +108,7 @@ func genGroup(baseGroupType GroupType, baseGroupID string, webhookData *types.We
|
||||||
|
|
||||||
type webhooksHandler struct {
|
type webhooksHandler struct {
|
||||||
log *zap.SugaredLogger
|
log *zap.SugaredLogger
|
||||||
|
ch *command.CommandHandler
|
||||||
configstoreClient *csapi.Client
|
configstoreClient *csapi.Client
|
||||||
runserviceClient *rsapi.Client
|
runserviceClient *rsapi.Client
|
||||||
apiExposedURL string
|
apiExposedURL string
|
||||||
|
@ -165,7 +167,7 @@ func (h *webhooksHandler) handleWebhook(r *http.Request) (int, string, error) {
|
||||||
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID)
|
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to get remote source %q", la.RemoteSourceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
gitSource, err = common.GetGitSource(rs, la)
|
gitSource, err = h.ch.GetGitSource(ctx, rs, user.UserName, la)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client")
|
return http.StatusInternalServerError, "", errors.Wrapf(err, "failed to create gitea client")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue