From b8c2b4020a975750fe1a34c86de9e80620be6074 Mon Sep 17 00:00:00 2001 From: Simone Gotti Date: Thu, 25 Jul 2019 10:46:02 +0200 Subject: [PATCH] db: use context functions Use the go sql context functions (ExecContext, QueryContext etc...) The context is saved inside Tx so the library users should only pass it one time to the db.Do function. --- internal/db/create.go | 7 +-- internal/db/db.go | 30 ++++++------ internal/services/configstore/action/org.go | 10 ++-- .../services/configstore/action/project.go | 8 ++-- .../configstore/action/projectgroup.go | 10 ++-- .../configstore/action/remotesource.go | 6 +-- .../services/configstore/action/secret.go | 10 ++-- internal/services/configstore/action/user.go | 18 +++---- .../services/configstore/action/variable.go | 8 ++-- internal/services/configstore/api/org.go | 7 +-- internal/services/configstore/api/project.go | 18 +++---- .../services/configstore/api/projectgroup.go | 24 ++++++---- .../services/configstore/api/remotesource.go | 6 ++- internal/services/configstore/api/secret.go | 2 +- internal/services/configstore/api/user.go | 12 +++-- internal/services/configstore/api/variable.go | 2 +- .../services/configstore/configstore_test.go | 24 +++++----- .../services/configstore/readdb/readdb.go | 40 ++++++++-------- .../configstore/readdb/remotesource.go | 5 +- internal/services/gateway/gateway.go | 12 ++--- internal/services/runservice/action/action.go | 6 +-- internal/services/runservice/api/api.go | 9 ++-- internal/services/runservice/readdb/readdb.go | 48 +++++++++---------- 23 files changed, 170 insertions(+), 152 deletions(-) diff --git a/internal/db/create.go b/internal/db/create.go index 9687947..62a1bbf 100644 --- a/internal/db/create.go +++ b/internal/db/create.go @@ -15,6 +15,7 @@ package db import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" @@ -27,10 +28,10 @@ const dbVersionTableDDLTmpl = ` const dbVersion = 1 -func (db *DB) Create(stmts []string) error { +func (db *DB) Create(ctx context.Context, stmts []string) error { sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) - err := db.Do(func(tx *Tx) error { + err := db.Do(ctx, func(tx *Tx) error { if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil { return errors.Errorf("failed to create dbversion table: %w", err) } @@ -40,7 +41,7 @@ func (db *DB) Create(stmts []string) error { return err } - err = db.Do(func(tx *Tx) error { + err = db.Do(ctx, func(tx *Tx) error { var version sql.NullInt64 q, args, err := sb.Select("max(version)").From("dbversion").ToSql() if err != nil { diff --git a/internal/db/db.go b/internal/db/db.go index d3402dc..6ac695c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -145,16 +145,17 @@ func NewDB(dbType Type, dbConnString string) (*DB, error) { // sql driver doesn't support concurrent statements on the same // connection/transaction) type Tx struct { - db *DB - tx *sql.Tx + db *DB + tx *sql.Tx + ctx context.Context } func (db *DB) Close() error { return db.db.Close() } -func (db *DB) Conn() (*sql.Conn, error) { - return db.db.Conn(context.TODO()) +func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) { + return db.db.Conn(ctx) } func (db *DB) NewUnstartedTx() *Tx { @@ -163,19 +164,19 @@ func (db *DB) NewUnstartedTx() *Tx { } } -func (db *DB) NewTx() (*Tx, error) { +func (db *DB) NewTx(ctx context.Context) (*Tx, error) { tx := db.NewUnstartedTx() - if err := tx.Start(); err != nil { + if err := tx.Start(ctx); err != nil { return nil, err } return tx, nil } -func (db *DB) Do(f func(tx *Tx) error) error { +func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error { retries := 0 for { - err := db.do(f) + err := db.do(ctx, f) if err != nil { var sqerr sqlite3.Error if errors.As(err, &sqerr) { @@ -191,8 +192,8 @@ func (db *DB) Do(f func(tx *Tx) error) error { } } -func (db *DB) do(f func(tx *Tx) error) error { - tx, err := db.NewTx() +func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error { + tx, err := db.NewTx(ctx) if err != nil { return err } @@ -209,7 +210,7 @@ func (db *DB) do(f func(tx *Tx) error) error { return tx.Commit() } -func (tx *Tx) Start() error { +func (tx *Tx) Start(ctx context.Context) error { wtx, err := tx.db.db.Begin() if err != nil { return err @@ -221,6 +222,7 @@ func (tx *Tx) Start() error { } } tx.tx = wtx + tx.ctx = ctx return nil } @@ -240,19 +242,19 @@ func (tx *Tx) Rollback() error { func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { query = tx.db.data.translate(query) - r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...) + r, err := tx.tx.ExecContext(tx.ctx, query, tx.db.data.translateArgs(args)...) return r, err } func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { query = tx.db.data.translate(query) - r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...) + r, err := tx.tx.QueryContext(tx.ctx, query, tx.db.data.translateArgs(args)...) return r, err } func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { query = tx.db.data.translate(query) - return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...) + return tx.tx.QueryRowContext(tx.ctx, query, tx.db.data.translateArgs(args)...) } func (tx *Tx) CurTime() (time.Time, error) { diff --git a/internal/services/configstore/action/org.go b/internal/services/configstore/action/org.go index d5db64e..c467356 100644 --- a/internal/services/configstore/action/org.go +++ b/internal/services/configstore/action/org.go @@ -44,7 +44,7 @@ func orgMemberResponse(orgUser *readdb.OrgUser) *OrgMemberResponse { func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) { var orgUsers []*readdb.OrgUser - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error org, err := h.readDB.GetOrg(tx, orgRef) if err != nil { @@ -85,7 +85,7 @@ func (h *ActionHandler) CreateOrg(ctx context.Context, org *types.Organization) cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -182,7 +182,7 @@ func (h *ActionHandler) DeleteOrg(ctx context.Context, orgRef string) error { var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check org existance org, err = h.readDB.GetOrgByName(tx, orgRef) @@ -232,7 +232,7 @@ func (h *ActionHandler) AddOrgMember(ctx context.Context, orgRef, userRef string var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check existing org org, err = h.readDB.GetOrg(tx, orgRef) @@ -308,7 +308,7 @@ func (h *ActionHandler) RemoveOrgMember(ctx context.Context, orgRef, userRef str var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check existing org org, err = h.readDB.GetOrg(tx, orgRef) diff --git a/internal/services/configstore/action/project.go b/internal/services/configstore/action/project.go index 1847d05..5afa611 100644 --- a/internal/services/configstore/action/project.go +++ b/internal/services/configstore/action/project.go @@ -72,7 +72,7 @@ func (h *ActionHandler) CreateProject(ctx context.Context, project *types.Projec var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID) if err != nil { @@ -167,7 +167,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project exists p, err := h.readDB.GetProject(tx, req.ProjectRef) @@ -210,7 +210,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq return err } if ap != nil { - return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name,pp)) + return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name, pp)) } } @@ -269,7 +269,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project existance diff --git a/internal/services/configstore/action/projectgroup.go b/internal/services/configstore/action/projectgroup.go index c807270..e46cb93 100644 --- a/internal/services/configstore/action/projectgroup.go +++ b/internal/services/configstore/action/projectgroup.go @@ -30,7 +30,7 @@ import ( func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) { var projectGroups []*types.ProjectGroup - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) if err != nil { @@ -53,7 +53,7 @@ func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGro func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) { var projects []*types.Project - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) if err != nil { @@ -112,7 +112,7 @@ func (h *ActionHandler) CreateProjectGroup(ctx context.Context, projectGroup *ty var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID) if err != nil { return err @@ -184,7 +184,7 @@ func (h *ActionHandler) UpdateProjectGroup(ctx context.Context, req *UpdateProje var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project exists pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef) @@ -280,7 +280,7 @@ func (h *ActionHandler) DeleteProjectGroup(ctx context.Context, projectGroupRef var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project group existance diff --git a/internal/services/configstore/action/remotesource.go b/internal/services/configstore/action/remotesource.go index 8976e60..4dea7b3 100644 --- a/internal/services/configstore/action/remotesource.go +++ b/internal/services/configstore/action/remotesource.go @@ -74,7 +74,7 @@ func (h *ActionHandler) CreateRemoteSource(ctx context.Context, remoteSource *ty cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -129,7 +129,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check remotesource exists @@ -190,7 +190,7 @@ func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check remoteSource existance diff --git a/internal/services/configstore/action/secret.go b/internal/services/configstore/action/secret.go index 9b6d370..f2410ee 100644 --- a/internal/services/configstore/action/secret.go +++ b/internal/services/configstore/action/secret.go @@ -29,7 +29,7 @@ import ( func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) { var secret *types.Secret - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error secret, err = h.readDB.GetSecretByID(tx, secretID) return err @@ -47,7 +47,7 @@ func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types. func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) { var secrets []*types.Secret - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { return err @@ -105,7 +105,7 @@ func (h *ActionHandler) CreateSecret(ctx context.Context, secret *types.Secret) cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -168,7 +168,7 @@ func (h *ActionHandler) UpdateSecret(ctx context.Context, req *UpdateSecretReque // changegroup is the secret name // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID) @@ -238,7 +238,7 @@ func (h *ActionHandler) DeleteSecret(ctx context.Context, parentType types.Confi var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { diff --git a/internal/services/configstore/action/user.go b/internal/services/configstore/action/user.go index 0d005d3..e2dea83 100644 --- a/internal/services/configstore/action/user.go +++ b/internal/services/configstore/action/user.go @@ -50,7 +50,7 @@ func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) var rs *types.RemoteSource // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -156,7 +156,7 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error { var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check user existance @@ -206,7 +206,7 @@ func (h *ActionHandler) UpdateUser(ctx context.Context, req *UpdateUserRequest) var user *types.User // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -289,7 +289,7 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -374,7 +374,7 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -444,7 +444,7 @@ func (h *ActionHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLAReque var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -518,7 +518,7 @@ func (h *ActionHandler) CreateUserToken(ctx context.Context, userRef, tokenName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -583,7 +583,7 @@ func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -644,7 +644,7 @@ func userOrgsResponse(userOrg *readdb.UserOrg) *UserOrgsResponse { func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) { var userOrgs []*readdb.UserOrg - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err := h.readDB.GetUser(tx, userRef) if err != nil { diff --git a/internal/services/configstore/action/variable.go b/internal/services/configstore/action/variable.go index 0dfa59c..d7a5887 100644 --- a/internal/services/configstore/action/variable.go +++ b/internal/services/configstore/action/variable.go @@ -29,7 +29,7 @@ import ( func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) { var variables []*types.Variable - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { return err @@ -81,7 +81,7 @@ func (h *ActionHandler) CreateVariable(ctx context.Context, variable *types.Vari cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -144,7 +144,7 @@ func (h *ActionHandler) UpdateVariable(ctx context.Context, req *UpdateVariableR // changegroup is the variable name // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID) @@ -214,7 +214,7 @@ func (h *ActionHandler) DeleteVariable(ctx context.Context, parentType types.Con var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { diff --git a/internal/services/configstore/api/org.go b/internal/services/configstore/api/org.go index ca5d750..a0d57d8 100644 --- a/internal/services/configstore/api/org.go +++ b/internal/services/configstore/api/org.go @@ -40,11 +40,12 @@ func NewOrgHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgHandler { } func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) orgRef := vars["orgref"] var org *types.Organization - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error org, err = h.readDB.GetOrg(tx, orgRef) return err @@ -105,7 +106,6 @@ func NewDeleteOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *DeleteOr } func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.log.Infof("deleteorghandler") ctx := r.Context() vars := mux.Vars(r) @@ -136,6 +136,7 @@ func NewOrgsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgsHandler { } func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -163,7 +164,7 @@ func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { start := query.Get("start") var orgs []*types.Organization - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error orgs, err = h.readDB.GetOrgs(tx, start, limit, asc) return err diff --git a/internal/services/configstore/api/project.go b/internal/services/configstore/api/project.go index fe6e29c..0510792 100644 --- a/internal/services/configstore/api/project.go +++ b/internal/services/configstore/api/project.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/json" "net/http" "net/url" @@ -43,18 +44,18 @@ type Project struct { GlobalVisibility types.Visibility } -func projectResponse(readDB *readdb.ReadDB, project *types.Project) (*Project, error) { - r, err := projectsResponse(readDB, []*types.Project{project}) +func projectResponse(ctx context.Context, readDB *readdb.ReadDB, project *types.Project) (*Project, error) { + r, err := projectsResponse(ctx, readDB, []*types.Project{project}) if err != nil { return nil, err } return r[0], nil } -func projectsResponse(readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { +func projectsResponse(ctx context.Context, readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { resProjects := make([]*Project, len(projects)) - err := readDB.Do(func(tx *db.Tx) error { + err := readDB.Do(ctx, func(tx *db.Tx) error { for i, project := range projects { pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID) if err != nil { @@ -132,6 +133,7 @@ func NewProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectHandle } func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) projectRef, err := url.PathUnescape(vars["projectref"]) if err != nil { @@ -140,7 +142,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } var project *types.Project - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { var err error project, err = h.readDB.GetProject(tx, projectRef) return err @@ -156,7 +158,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -193,7 +195,7 @@ func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -241,7 +243,7 @@ func (h *UpdateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return diff --git a/internal/services/configstore/api/projectgroup.go b/internal/services/configstore/api/projectgroup.go index cd148ca..4b025ad 100644 --- a/internal/services/configstore/api/projectgroup.go +++ b/internal/services/configstore/api/projectgroup.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/json" "net/http" "net/url" @@ -43,18 +44,18 @@ type ProjectGroup struct { GlobalVisibility types.Visibility } -func projectGroupResponse(readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { - r, err := projectGroupsResponse(readDB, []*types.ProjectGroup{projectGroup}) +func projectGroupResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { + r, err := projectGroupsResponse(ctx, readDB, []*types.ProjectGroup{projectGroup}) if err != nil { return nil, err } return r[0], nil } -func projectGroupsResponse(readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { +func projectGroupsResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { resProjectGroups := make([]*ProjectGroup, len(projectGroups)) - err := readDB.Do(func(tx *db.Tx) error { + err := readDB.Do(ctx, func(tx *db.Tx) error { for i, projectGroup := range projectGroups { pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID) if err != nil { @@ -100,7 +101,9 @@ func NewProjectGroupHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectG } func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) + projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) if err != nil { httpError(w, util.NewErrBadRequest(err)) @@ -108,7 +111,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) } var projectGroup *types.ProjectGroup - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef) return err @@ -124,7 +127,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -148,6 +151,7 @@ func NewProjectGroupProjectsHandler(logger *zap.Logger, ah *action.ActionHandler func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() vars := mux.Vars(r) + projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) if err != nil { httpError(w, util.NewErrBadRequest(err)) @@ -160,7 +164,7 @@ func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.R return } - resProjects, err := projectsResponse(h.readDB, projects) + resProjects, err := projectsResponse(ctx, h.readDB, projects) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -196,7 +200,7 @@ func (h *ProjectGroupSubgroupsHandler) ServeHTTP(w http.ResponseWriter, r *http. return } - resProjectGroups, err := projectGroupsResponse(h.readDB, projectGroups) + resProjectGroups, err := projectGroupsResponse(ctx, h.readDB, projectGroups) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -233,7 +237,7 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -281,7 +285,7 @@ func (h *UpdateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return diff --git a/internal/services/configstore/api/remotesource.go b/internal/services/configstore/api/remotesource.go index 3c5ab5c..a3bc0dc 100644 --- a/internal/services/configstore/api/remotesource.go +++ b/internal/services/configstore/api/remotesource.go @@ -40,11 +40,12 @@ func NewRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSo } func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) rsRef := vars["remotesourceref"] var remoteSource *types.RemoteSource - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef) return err @@ -171,6 +172,7 @@ func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteS } func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -197,7 +199,7 @@ func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) start := query.Get("start") - remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc) + remoteSources, err := h.readDB.GetRemoteSources(ctx, start, limit, asc) if err != nil { h.log.Errorf("err: %+v", err) httpError(w, err) diff --git a/internal/services/configstore/api/secret.go b/internal/services/configstore/api/secret.go index 88172e9..38f525d 100644 --- a/internal/services/configstore/api/secret.go +++ b/internal/services/configstore/api/secret.go @@ -94,7 +94,7 @@ func (h *SecretsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resSecrets[i] = &Secret{Secret: s} } - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { // populate parent path for _, s := range resSecrets { pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID) diff --git a/internal/services/configstore/api/user.go b/internal/services/configstore/api/user.go index 078294f..55d8b81 100644 --- a/internal/services/configstore/api/user.go +++ b/internal/services/configstore/api/user.go @@ -41,11 +41,12 @@ func NewUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UserHandler { } func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) userRef := vars["userref"] var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) return err @@ -198,6 +199,7 @@ func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler { } func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -232,7 +234,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "bytoken": token := query.Get("token") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByTokenValue(tx, token) return err @@ -250,7 +252,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "bylinkedaccount": linkedAccountID := query.Get("linkedaccountid") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID) return err @@ -269,7 +271,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { remoteUserID := query.Get("remoteuserid") remoteSourceID := query.Get("remotesourceid") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID) return err @@ -286,7 +288,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { users = []*types.User{user} default: // default query - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error users, err = h.readDB.GetUsers(tx, start, limit, asc) return err diff --git a/internal/services/configstore/api/variable.go b/internal/services/configstore/api/variable.go index f3076c6..4920300 100644 --- a/internal/services/configstore/api/variable.go +++ b/internal/services/configstore/api/variable.go @@ -67,7 +67,7 @@ func (h *VariablesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for i, v := range variables { resVariables[i] = &Variable{Variable: v} } - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { // populate parent path for _, v := range resVariables { pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID) diff --git a/internal/services/configstore/configstore_test.go b/internal/services/configstore/configstore_test.go index 80fa885..6a867d6 100644 --- a/internal/services/configstore/configstore_test.go +++ b/internal/services/configstore/configstore_test.go @@ -99,9 +99,9 @@ func setupConfigstore(t *testing.T, ctx context.Context, dir string) (*Configsto return cs, tetcd } -func getProjects(cs *Configstore) ([]*types.Project, error) { +func getProjects(ctx context.Context, cs *Configstore) ([]*types.Project, error) { var projects []*types.Project - err := cs.readDB.Do(func(tx *db.Tx) error { + err := cs.readDB.Do(ctx, func(tx *db.Tx) error { var err error projects, err = cs.readDB.GetAllProjects(tx) return err @@ -109,9 +109,9 @@ func getProjects(cs *Configstore) ([]*types.Project, error) { return projects, err } -func getUsers(cs *Configstore) ([]*types.User, error) { +func getUsers(ctx context.Context, cs *Configstore) ([]*types.User, error) { var users []*types.User - err := cs.readDB.Do(func(tx *db.Tx) error { + err := cs.readDB.Do(ctx, func(tx *db.Tx) error { var err error users, err = cs.readDB.GetUsers(tx, "", 0, true) return err @@ -242,12 +242,12 @@ func TestResync(t *testing.T) { time.Sleep(5 * time.Second) - users1, err := getUsers(cs1) + users1, err := getUsers(ctx, cs1) if err != nil { t.Fatalf("err: %v", err) } - users2, err := getUsers(cs2) + users2, err := getUsers(ctx, cs2) if err != nil { t.Fatalf("err: %v", err) } @@ -275,12 +275,12 @@ func TestResync(t *testing.T) { time.Sleep(5 * time.Second) - users1, err = getUsers(cs1) + users1, err = getUsers(ctx, cs1) if err != nil { t.Fatalf("err: %v", err) } - users3, err := getUsers(cs3) + users3, err := getUsers(ctx, cs3) if err != nil { t.Fatalf("err: %v", err) } @@ -349,7 +349,7 @@ func TestUser(t *testing.T) { } }) t.Run("concurrent user with same name creation", func(t *testing.T) { - prevUsers, err := getUsers(cs) + prevUsers, err := getUsers(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -364,7 +364,7 @@ func TestUser(t *testing.T) { time.Sleep(5 * time.Second) - users, err := getUsers(cs) + users, err := getUsers(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -494,7 +494,7 @@ func TestProjectGroupsAndProjects(t *testing.T) { }) t.Run("concurrent project with same name creation", func(t *testing.T) { - prevProjects, err := getProjects(cs) + prevProjects, err := getProjects(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -511,7 +511,7 @@ func TestProjectGroupsAndProjects(t *testing.T) { time.Sleep(1 * time.Second) - projects, err := getProjects(cs) + projects, err := getProjects(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } diff --git a/internal/services/configstore/readdb/readdb.go b/internal/services/configstore/readdb/readdb.go index ad29925..3522de0 100644 --- a/internal/services/configstore/readdb/readdb.go +++ b/internal/services/configstore/readdb/readdb.go @@ -74,7 +74,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd. } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return nil, err } @@ -112,7 +112,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error { return nil } -func (r *ReadDB) ResetDB() error { +func (r *ReadDB) ResetDB(ctx context.Context) error { // TODO(sgotti) this needs to be protected by a mutex r.rdb.Close() @@ -127,7 +127,7 @@ func (r *ReadDB) ResetDB() error { } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return err } @@ -136,7 +136,7 @@ func (r *ReadDB) ResetDB() error { return nil } -func (r *ReadDB) SyncFromDump() (string, error) { +func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) { dumpIndex, err := r.dm.GetLastDataStatus() if err != nil && err != ostypes.ErrNotExist { return "", err @@ -168,7 +168,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } dumpf.Close() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, de := range dumpEntries { action := &datamanager.Action{ ActionType: datamanager.ActionTypePut, @@ -188,7 +188,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil { return err } @@ -201,9 +201,9 @@ func (r *ReadDB) SyncFromDump() (string, error) { return dumpIndex.WalSequence, nil } -func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { +func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) { insertfunc := func(walFiles []*datamanager.WalFile) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { for _, walFile := range walFiles { walFilef, err := r.dm.ReadWal(walFile.WalSequence) if err != nil { @@ -263,7 +263,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncRDB(ctx context.Context) error { // get the last committed storage wal sequence saved in the rdb curWalSeq := "" - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error curWalSeq, err = r.GetCommittedWalSequence(tx) if err != nil { @@ -317,12 +317,12 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { if doFullSync { r.log.Infof("doing a full sync from dump") - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return err } var err error - curWalSeq, err = r.SyncFromDump() + curWalSeq, err = r.SyncFromDump(ctx) if err != nil { return err } @@ -335,7 +335,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { // etcd since wals are first committed to objectstorage and then in etcd we // would like to avoid to store in rdb something that is not yet marked as // committedstorage in etcd - curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) + curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal) if err != nil { return errors.Errorf("failed to sync from wals: %w", err) } @@ -358,7 +358,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } r.log.Infof("syncing from wals") - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertRevision(tx, revision); err != nil { return err } @@ -406,7 +406,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } func (r *ReadDB) Run(ctx context.Context) error { - revision, err := r.GetRevision() + revision, err := r.GetRevision(ctx) if err != nil { return err } @@ -475,7 +475,7 @@ func (r *ReadDB) Run(ctx context.Context) error { // lost/reset func (r *ReadDB) handleEvents(ctx context.Context) error { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) if err != nil { if err == sql.ErrNoRows { @@ -508,7 +508,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { // if theres a wal seq epoch change something happened to etcd, usually (if // the user hasn't messed up with etcd keys) this means etcd has been reset @@ -704,8 +704,8 @@ func (r *ReadDB) applyAction(tx *db.Tx, action *datamanager.Action) error { return nil } -func (r *ReadDB) Do(f func(tx *db.Tx) error) error { - return r.rdb.Do(f) +func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error { + return r.rdb.Do(ctx, f) } func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { @@ -724,10 +724,10 @@ func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { return nil } -func (r *ReadDB) GetRevision() (int64, error) { +func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) return err diff --git a/internal/services/configstore/readdb/remotesource.go b/internal/services/configstore/readdb/remotesource.go index 1b8b2e0..e210a43 100644 --- a/internal/services/configstore/readdb/remotesource.go +++ b/internal/services/configstore/readdb/remotesource.go @@ -15,6 +15,7 @@ package readdb import ( + "context" "database/sql" "encoding/json" @@ -139,7 +140,7 @@ func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc return s } -func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { +func (r *ReadDB) GetRemoteSources(ctx context.Context, startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { var remoteSources []*types.RemoteSource s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc) @@ -149,7 +150,7 @@ func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc b return nil, errors.Errorf("failed to build query: %w", err) } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { rows, err := tx.Query(q, args...) if err != nil { return err diff --git a/internal/services/gateway/gateway.go b/internal/services/gateway/gateway.go index dbe6371..ddee92a 100644 --- a/internal/services/gateway/gateway.go +++ b/internal/services/gateway/gateway.go @@ -142,12 +142,12 @@ func (g *Gateway) Run(ctx context.Context) error { return h } -if len(g.c.Web.AllowedOrigins) > 0 { - corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) - corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) - corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins) - corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) -} + if len(g.c.Web.AllowedOrigins) > 0 { + corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) + corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) + corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins) + corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) + } webhooksHandler := api.NewWebhooksHandler(logger, g.ah, g.configstoreClient, g.runserviceClient, g.c.APIExposedURL) diff --git a/internal/services/runservice/action/action.go b/internal/services/runservice/action/action.go index 2388d4b..afba39b 100644 --- a/internal/services/runservice/action/action.go +++ b/internal/services/runservice/action/action.go @@ -381,7 +381,7 @@ func (h *ActionHandler) saveRun(ctx context.Context, rb *types.RunBundle, runcgt run := rb.Run rc := rb.Rc - c, cgt, err := h.getRunCounter(run.Group) + c, cgt, err := h.getRunCounter(ctx, run.Group) h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt)) if err != nil { return err @@ -570,7 +570,7 @@ func (h *ActionHandler) DeleteExecutor(ctx context.Context, executorID string) e return nil } -func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { +func (h *ActionHandler) getRunCounter(ctx context.Context, group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { // use the first group dir after the root pl := util.PathList(group) if len(pl) < 2 { @@ -579,7 +579,7 @@ func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.Change var c uint64 var cgt *datamanager.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error c, err = h.readDB.GetRunCounterOST(tx, pl[1]) if err != nil { diff --git a/internal/services/runservice/api/api.go b/internal/services/runservice/api/api.go index 75667ea..beffd66 100644 --- a/internal/services/runservice/api/api.go +++ b/internal/services/runservice/api/api.go @@ -323,12 +323,13 @@ func NewChangeGroupsUpdateTokensHandler(logger *zap.Logger, readDB *readdb.ReadD } func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() groups := query["changegroup"] var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups) return err @@ -372,6 +373,7 @@ func NewRunHandler(logger *zap.Logger, e *etcd.Store, dm *datamanager.DataManage } func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) runID := vars["runid"] @@ -381,7 +383,7 @@ func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var run *types.Run var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error run, err = h.readDB.GetRun(tx, runID) if err != nil { @@ -447,6 +449,7 @@ func NewRunsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RunsHandler { } func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() phaseFilter := types.RunPhaseFromStringSlice(query["phase"]) resultFilter := types.RunResultFromStringSlice(query["result"]) @@ -482,7 +485,7 @@ func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var runs []*types.Run var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder) if err != nil { diff --git a/internal/services/runservice/readdb/readdb.go b/internal/services/runservice/readdb/readdb.go index 405410c..e82ef63 100644 --- a/internal/services/runservice/readdb/readdb.go +++ b/internal/services/runservice/readdb/readdb.go @@ -119,7 +119,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd. } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return nil, err } @@ -150,7 +150,7 @@ func (r *ReadDB) IsInitialized() bool { // Initialize populates the readdb with the current etcd data and save the // revision to then feed it with the etcd events func (r *ReadDB) Initialize(ctx context.Context) error { - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return errors.Errorf("failed to reset db: %w", err) } if err := r.SyncObjectStorage(ctx); err != nil { @@ -162,7 +162,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error { return nil } -func (r *ReadDB) ResetDB() error { +func (r *ReadDB) ResetDB(ctx context.Context) error { // TODO(sgotti) this needs to be protected by a mutex r.rdb.Close() @@ -177,7 +177,7 @@ func (r *ReadDB) ResetDB() error { } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return err } @@ -187,7 +187,7 @@ func (r *ReadDB) ResetDB() error { } func (r *ReadDB) SyncRDB(ctx context.Context) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { // Do pagination to limit the number of keys per request var revision int64 key := common.EtcdRunsDir @@ -256,7 +256,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } func (r *ReadDB) Run(ctx context.Context) error { - revision, err := r.GetRevision() + revision, err := r.GetRevision(ctx) if err != nil { return err } @@ -333,7 +333,7 @@ func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error { var revision int64 var lastRuns []*RunData - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) if err != nil { @@ -390,7 +390,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). r.dbWriteLock.Lock() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, ev := range wresp.Events { if err := r.handleEvent(tx, ev, &wresp); err != nil { return err @@ -503,7 +503,7 @@ func (r *ReadDB) handleChangeGroupEvent(tx *db.Tx, ev *etcdclientv3.Event, wresp func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { // get the last committed storage wal sequence saved in the rdb curWalSeq := "" - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error curWalSeq, err = r.GetCommittedWalSequenceOST(tx) if err != nil { @@ -557,12 +557,12 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { if doFullSync { r.log.Infof("doing a full sync from dump") - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return err } var err error - curWalSeq, err = r.SyncFromDump() + curWalSeq, err = r.SyncFromDump(ctx) if err != nil { return err } @@ -575,7 +575,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { // since wals are first committed to objectstorage and then in etcd we would like to // avoid to store in rdb something that is not yet marked as committedstorage // in etcd - curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) + curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal) if err != nil { return errors.Errorf("failed to sync from wals: %w", err) } @@ -598,7 +598,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { } r.log.Infof("syncing from wals") - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := insertRevisionOST(tx, revision); err != nil { return err } @@ -645,7 +645,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { return err } -func (r *ReadDB) SyncFromDump() (string, error) { +func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) { dumpIndex, err := r.dm.GetLastDataStatus() if err != nil && err != ostypes.ErrNotExist { return "", err @@ -677,7 +677,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } dumpf.Close() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, de := range dumpEntries { action := &datamanager.Action{ ActionType: datamanager.ActionTypePut, @@ -697,7 +697,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil { return err } @@ -710,9 +710,9 @@ func (r *ReadDB) SyncFromDump() (string, error) { return dumpIndex.WalSequence, nil } -func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { +func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) { insertfunc := func(walFiles []*datamanager.WalFile) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { for _, walFile := range walFiles { walFilef, err := r.dm.ReadWal(walFile.WalSequence) if err != nil { @@ -771,7 +771,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) handleEventsOST(ctx context.Context) error { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) if err != nil { if err == sql.ErrNoRows { @@ -805,7 +805,7 @@ func (r *ReadDB) handleEventsOST(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). r.dbWriteLock.Lock() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { // if theres a wal seq epoch change something happened to etcd, usually (if // the user hasn't messed up with etcd keys) this means etcd has been reset @@ -949,11 +949,11 @@ func (r *ReadDB) handleWalEvent(tx *db.Tx, we *datamanager.WatchElement) error { return nil } -func (r *ReadDB) Do(f func(tx *db.Tx) error) error { +func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error { if !r.IsInitialized() { return errors.Errorf("db not initialized") } - return r.rdb.Do(f) + return r.rdb.Do(ctx, f) } func insertRevision(tx *db.Tx, revision int64) error { @@ -1073,10 +1073,10 @@ func insertChangeGroupRevision(tx *db.Tx, changegroupID string, revision int64) return nil } -func (r *ReadDB) GetRevision() (int64, error) { +func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) return err