diff --git a/internal/db/create.go b/internal/db/create.go index 9687947..62a1bbf 100644 --- a/internal/db/create.go +++ b/internal/db/create.go @@ -15,6 +15,7 @@ package db import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" @@ -27,10 +28,10 @@ const dbVersionTableDDLTmpl = ` const dbVersion = 1 -func (db *DB) Create(stmts []string) error { +func (db *DB) Create(ctx context.Context, stmts []string) error { sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) - err := db.Do(func(tx *Tx) error { + err := db.Do(ctx, func(tx *Tx) error { if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil { return errors.Errorf("failed to create dbversion table: %w", err) } @@ -40,7 +41,7 @@ func (db *DB) Create(stmts []string) error { return err } - err = db.Do(func(tx *Tx) error { + err = db.Do(ctx, func(tx *Tx) error { var version sql.NullInt64 q, args, err := sb.Select("max(version)").From("dbversion").ToSql() if err != nil { diff --git a/internal/db/db.go b/internal/db/db.go index d3402dc..6ac695c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -145,16 +145,17 @@ func NewDB(dbType Type, dbConnString string) (*DB, error) { // sql driver doesn't support concurrent statements on the same // connection/transaction) type Tx struct { - db *DB - tx *sql.Tx + db *DB + tx *sql.Tx + ctx context.Context } func (db *DB) Close() error { return db.db.Close() } -func (db *DB) Conn() (*sql.Conn, error) { - return db.db.Conn(context.TODO()) +func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) { + return db.db.Conn(ctx) } func (db *DB) NewUnstartedTx() *Tx { @@ -163,19 +164,19 @@ func (db *DB) NewUnstartedTx() *Tx { } } -func (db *DB) NewTx() (*Tx, error) { +func (db *DB) NewTx(ctx context.Context) (*Tx, error) { tx := db.NewUnstartedTx() - if err := tx.Start(); err != nil { + if err := tx.Start(ctx); err != nil { return nil, err } return tx, nil } -func (db *DB) Do(f func(tx *Tx) error) error { +func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error { retries := 0 for { - err := db.do(f) + err := db.do(ctx, f) if err != nil { var sqerr sqlite3.Error if errors.As(err, &sqerr) { @@ -191,8 +192,8 @@ func (db *DB) Do(f func(tx *Tx) error) error { } } -func (db *DB) do(f func(tx *Tx) error) error { - tx, err := db.NewTx() +func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error { + tx, err := db.NewTx(ctx) if err != nil { return err } @@ -209,7 +210,7 @@ func (db *DB) do(f func(tx *Tx) error) error { return tx.Commit() } -func (tx *Tx) Start() error { +func (tx *Tx) Start(ctx context.Context) error { wtx, err := tx.db.db.Begin() if err != nil { return err @@ -221,6 +222,7 @@ func (tx *Tx) Start() error { } } tx.tx = wtx + tx.ctx = ctx return nil } @@ -240,19 +242,19 @@ func (tx *Tx) Rollback() error { func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { query = tx.db.data.translate(query) - r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...) + r, err := tx.tx.ExecContext(tx.ctx, query, tx.db.data.translateArgs(args)...) return r, err } func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { query = tx.db.data.translate(query) - r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...) + r, err := tx.tx.QueryContext(tx.ctx, query, tx.db.data.translateArgs(args)...) return r, err } func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { query = tx.db.data.translate(query) - return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...) + return tx.tx.QueryRowContext(tx.ctx, query, tx.db.data.translateArgs(args)...) } func (tx *Tx) CurTime() (time.Time, error) { diff --git a/internal/services/configstore/action/org.go b/internal/services/configstore/action/org.go index d5db64e..c467356 100644 --- a/internal/services/configstore/action/org.go +++ b/internal/services/configstore/action/org.go @@ -44,7 +44,7 @@ func orgMemberResponse(orgUser *readdb.OrgUser) *OrgMemberResponse { func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) { var orgUsers []*readdb.OrgUser - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error org, err := h.readDB.GetOrg(tx, orgRef) if err != nil { @@ -85,7 +85,7 @@ func (h *ActionHandler) CreateOrg(ctx context.Context, org *types.Organization) cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -182,7 +182,7 @@ func (h *ActionHandler) DeleteOrg(ctx context.Context, orgRef string) error { var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check org existance org, err = h.readDB.GetOrgByName(tx, orgRef) @@ -232,7 +232,7 @@ func (h *ActionHandler) AddOrgMember(ctx context.Context, orgRef, userRef string var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check existing org org, err = h.readDB.GetOrg(tx, orgRef) @@ -308,7 +308,7 @@ func (h *ActionHandler) RemoveOrgMember(ctx context.Context, orgRef, userRef str var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check existing org org, err = h.readDB.GetOrg(tx, orgRef) diff --git a/internal/services/configstore/action/project.go b/internal/services/configstore/action/project.go index 1847d05..5afa611 100644 --- a/internal/services/configstore/action/project.go +++ b/internal/services/configstore/action/project.go @@ -72,7 +72,7 @@ func (h *ActionHandler) CreateProject(ctx context.Context, project *types.Projec var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID) if err != nil { @@ -167,7 +167,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project exists p, err := h.readDB.GetProject(tx, req.ProjectRef) @@ -210,7 +210,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq return err } if ap != nil { - return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name,pp)) + return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name, pp)) } } @@ -269,7 +269,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project existance diff --git a/internal/services/configstore/action/projectgroup.go b/internal/services/configstore/action/projectgroup.go index c807270..e46cb93 100644 --- a/internal/services/configstore/action/projectgroup.go +++ b/internal/services/configstore/action/projectgroup.go @@ -30,7 +30,7 @@ import ( func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) { var projectGroups []*types.ProjectGroup - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) if err != nil { @@ -53,7 +53,7 @@ func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGro func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) { var projects []*types.Project - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) if err != nil { @@ -112,7 +112,7 @@ func (h *ActionHandler) CreateProjectGroup(ctx context.Context, projectGroup *ty var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID) if err != nil { return err @@ -184,7 +184,7 @@ func (h *ActionHandler) UpdateProjectGroup(ctx context.Context, req *UpdateProje var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project exists pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef) @@ -280,7 +280,7 @@ func (h *ActionHandler) DeleteProjectGroup(ctx context.Context, projectGroupRef var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check project group existance diff --git a/internal/services/configstore/action/remotesource.go b/internal/services/configstore/action/remotesource.go index 8976e60..4dea7b3 100644 --- a/internal/services/configstore/action/remotesource.go +++ b/internal/services/configstore/action/remotesource.go @@ -74,7 +74,7 @@ func (h *ActionHandler) CreateRemoteSource(ctx context.Context, remoteSource *ty cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -129,7 +129,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check remotesource exists @@ -190,7 +190,7 @@ func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check remoteSource existance diff --git a/internal/services/configstore/action/secret.go b/internal/services/configstore/action/secret.go index 9b6d370..f2410ee 100644 --- a/internal/services/configstore/action/secret.go +++ b/internal/services/configstore/action/secret.go @@ -29,7 +29,7 @@ import ( func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) { var secret *types.Secret - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error secret, err = h.readDB.GetSecretByID(tx, secretID) return err @@ -47,7 +47,7 @@ func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types. func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) { var secrets []*types.Secret - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { return err @@ -105,7 +105,7 @@ func (h *ActionHandler) CreateSecret(ctx context.Context, secret *types.Secret) cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -168,7 +168,7 @@ func (h *ActionHandler) UpdateSecret(ctx context.Context, req *UpdateSecretReque // changegroup is the secret name // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID) @@ -238,7 +238,7 @@ func (h *ActionHandler) DeleteSecret(ctx context.Context, parentType types.Confi var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { diff --git a/internal/services/configstore/action/user.go b/internal/services/configstore/action/user.go index 0d005d3..e2dea83 100644 --- a/internal/services/configstore/action/user.go +++ b/internal/services/configstore/action/user.go @@ -50,7 +50,7 @@ func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest) var rs *types.RemoteSource // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -156,7 +156,7 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error { var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error // check user existance @@ -206,7 +206,7 @@ func (h *ActionHandler) UpdateUser(ctx context.Context, req *UpdateUserRequest) var user *types.User // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -289,7 +289,7 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -374,7 +374,7 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string) var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -444,7 +444,7 @@ func (h *ActionHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLAReque var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, req.UserRef) if err != nil { @@ -518,7 +518,7 @@ func (h *ActionHandler) CreateUserToken(ctx context.Context, userRef, tokenName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -583,7 +583,7 @@ func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) if err != nil { @@ -644,7 +644,7 @@ func userOrgsResponse(userOrg *readdb.UserOrg) *UserOrgsResponse { func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) { var userOrgs []*readdb.UserOrg - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err := h.readDB.GetUser(tx, userRef) if err != nil { diff --git a/internal/services/configstore/action/variable.go b/internal/services/configstore/action/variable.go index 0dfa59c..d7a5887 100644 --- a/internal/services/configstore/action/variable.go +++ b/internal/services/configstore/action/variable.go @@ -29,7 +29,7 @@ import ( func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) { var variables []*types.Variable - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { return err @@ -81,7 +81,7 @@ func (h *ActionHandler) CreateVariable(ctx context.Context, variable *types.Vari cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)} // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) if err != nil { @@ -144,7 +144,7 @@ func (h *ActionHandler) UpdateVariable(ctx context.Context, req *UpdateVariableR // changegroup is the variable name // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID) @@ -214,7 +214,7 @@ func (h *ActionHandler) DeleteVariable(ctx context.Context, parentType types.Con var cgt *datamanager.ChangeGroupsUpdateToken // must do all the checks in a single transaction to avoid concurrent changes - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) if err != nil { diff --git a/internal/services/configstore/api/org.go b/internal/services/configstore/api/org.go index ca5d750..a0d57d8 100644 --- a/internal/services/configstore/api/org.go +++ b/internal/services/configstore/api/org.go @@ -40,11 +40,12 @@ func NewOrgHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgHandler { } func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) orgRef := vars["orgref"] var org *types.Organization - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error org, err = h.readDB.GetOrg(tx, orgRef) return err @@ -105,7 +106,6 @@ func NewDeleteOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *DeleteOr } func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.log.Infof("deleteorghandler") ctx := r.Context() vars := mux.Vars(r) @@ -136,6 +136,7 @@ func NewOrgsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgsHandler { } func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -163,7 +164,7 @@ func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { start := query.Get("start") var orgs []*types.Organization - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error orgs, err = h.readDB.GetOrgs(tx, start, limit, asc) return err diff --git a/internal/services/configstore/api/project.go b/internal/services/configstore/api/project.go index fe6e29c..0510792 100644 --- a/internal/services/configstore/api/project.go +++ b/internal/services/configstore/api/project.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/json" "net/http" "net/url" @@ -43,18 +44,18 @@ type Project struct { GlobalVisibility types.Visibility } -func projectResponse(readDB *readdb.ReadDB, project *types.Project) (*Project, error) { - r, err := projectsResponse(readDB, []*types.Project{project}) +func projectResponse(ctx context.Context, readDB *readdb.ReadDB, project *types.Project) (*Project, error) { + r, err := projectsResponse(ctx, readDB, []*types.Project{project}) if err != nil { return nil, err } return r[0], nil } -func projectsResponse(readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { +func projectsResponse(ctx context.Context, readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { resProjects := make([]*Project, len(projects)) - err := readDB.Do(func(tx *db.Tx) error { + err := readDB.Do(ctx, func(tx *db.Tx) error { for i, project := range projects { pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID) if err != nil { @@ -132,6 +133,7 @@ func NewProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectHandle } func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) projectRef, err := url.PathUnescape(vars["projectref"]) if err != nil { @@ -140,7 +142,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } var project *types.Project - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { var err error project, err = h.readDB.GetProject(tx, projectRef) return err @@ -156,7 +158,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -193,7 +195,7 @@ func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -241,7 +243,7 @@ func (h *UpdateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProject, err := projectResponse(h.readDB, project) + resProject, err := projectResponse(ctx, h.readDB, project) if httpError(w, err) { h.log.Errorf("err: %+v", err) return diff --git a/internal/services/configstore/api/projectgroup.go b/internal/services/configstore/api/projectgroup.go index cd148ca..4b025ad 100644 --- a/internal/services/configstore/api/projectgroup.go +++ b/internal/services/configstore/api/projectgroup.go @@ -15,6 +15,7 @@ package api import ( + "context" "encoding/json" "net/http" "net/url" @@ -43,18 +44,18 @@ type ProjectGroup struct { GlobalVisibility types.Visibility } -func projectGroupResponse(readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { - r, err := projectGroupsResponse(readDB, []*types.ProjectGroup{projectGroup}) +func projectGroupResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { + r, err := projectGroupsResponse(ctx, readDB, []*types.ProjectGroup{projectGroup}) if err != nil { return nil, err } return r[0], nil } -func projectGroupsResponse(readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { +func projectGroupsResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { resProjectGroups := make([]*ProjectGroup, len(projectGroups)) - err := readDB.Do(func(tx *db.Tx) error { + err := readDB.Do(ctx, func(tx *db.Tx) error { for i, projectGroup := range projectGroups { pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID) if err != nil { @@ -100,7 +101,9 @@ func NewProjectGroupHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectG } func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) + projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) if err != nil { httpError(w, util.NewErrBadRequest(err)) @@ -108,7 +111,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) } var projectGroup *types.ProjectGroup - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { var err error projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef) return err @@ -124,7 +127,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -148,6 +151,7 @@ func NewProjectGroupProjectsHandler(logger *zap.Logger, ah *action.ActionHandler func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() vars := mux.Vars(r) + projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) if err != nil { httpError(w, util.NewErrBadRequest(err)) @@ -160,7 +164,7 @@ func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.R return } - resProjects, err := projectsResponse(h.readDB, projects) + resProjects, err := projectsResponse(ctx, h.readDB, projects) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -196,7 +200,7 @@ func (h *ProjectGroupSubgroupsHandler) ServeHTTP(w http.ResponseWriter, r *http. return } - resProjectGroups, err := projectGroupsResponse(h.readDB, projectGroups) + resProjectGroups, err := projectGroupsResponse(ctx, h.readDB, projectGroups) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -233,7 +237,7 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return @@ -281,7 +285,7 @@ func (h *UpdateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req return } - resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) + resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup) if httpError(w, err) { h.log.Errorf("err: %+v", err) return diff --git a/internal/services/configstore/api/remotesource.go b/internal/services/configstore/api/remotesource.go index 3c5ab5c..a3bc0dc 100644 --- a/internal/services/configstore/api/remotesource.go +++ b/internal/services/configstore/api/remotesource.go @@ -40,11 +40,12 @@ func NewRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSo } func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) rsRef := vars["remotesourceref"] var remoteSource *types.RemoteSource - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef) return err @@ -171,6 +172,7 @@ func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteS } func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -197,7 +199,7 @@ func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) start := query.Get("start") - remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc) + remoteSources, err := h.readDB.GetRemoteSources(ctx, start, limit, asc) if err != nil { h.log.Errorf("err: %+v", err) httpError(w, err) diff --git a/internal/services/configstore/api/secret.go b/internal/services/configstore/api/secret.go index 88172e9..38f525d 100644 --- a/internal/services/configstore/api/secret.go +++ b/internal/services/configstore/api/secret.go @@ -94,7 +94,7 @@ func (h *SecretsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resSecrets[i] = &Secret{Secret: s} } - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { // populate parent path for _, s := range resSecrets { pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID) diff --git a/internal/services/configstore/api/user.go b/internal/services/configstore/api/user.go index 078294f..55d8b81 100644 --- a/internal/services/configstore/api/user.go +++ b/internal/services/configstore/api/user.go @@ -41,11 +41,12 @@ func NewUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UserHandler { } func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) userRef := vars["userref"] var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUser(tx, userRef) return err @@ -198,6 +199,7 @@ func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler { } func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() limitS := query.Get("limit") @@ -232,7 +234,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "bytoken": token := query.Get("token") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByTokenValue(tx, token) return err @@ -250,7 +252,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "bylinkedaccount": linkedAccountID := query.Get("linkedaccountid") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID) return err @@ -269,7 +271,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { remoteUserID := query.Get("remoteuserid") remoteSourceID := query.Get("remotesourceid") var user *types.User - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID) return err @@ -286,7 +288,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { users = []*types.User{user} default: // default query - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error users, err = h.readDB.GetUsers(tx, start, limit, asc) return err diff --git a/internal/services/configstore/api/variable.go b/internal/services/configstore/api/variable.go index f3076c6..4920300 100644 --- a/internal/services/configstore/api/variable.go +++ b/internal/services/configstore/api/variable.go @@ -67,7 +67,7 @@ func (h *VariablesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for i, v := range variables { resVariables[i] = &Variable{Variable: v} } - err = h.readDB.Do(func(tx *db.Tx) error { + err = h.readDB.Do(ctx, func(tx *db.Tx) error { // populate parent path for _, v := range resVariables { pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID) diff --git a/internal/services/configstore/configstore_test.go b/internal/services/configstore/configstore_test.go index 80fa885..6a867d6 100644 --- a/internal/services/configstore/configstore_test.go +++ b/internal/services/configstore/configstore_test.go @@ -99,9 +99,9 @@ func setupConfigstore(t *testing.T, ctx context.Context, dir string) (*Configsto return cs, tetcd } -func getProjects(cs *Configstore) ([]*types.Project, error) { +func getProjects(ctx context.Context, cs *Configstore) ([]*types.Project, error) { var projects []*types.Project - err := cs.readDB.Do(func(tx *db.Tx) error { + err := cs.readDB.Do(ctx, func(tx *db.Tx) error { var err error projects, err = cs.readDB.GetAllProjects(tx) return err @@ -109,9 +109,9 @@ func getProjects(cs *Configstore) ([]*types.Project, error) { return projects, err } -func getUsers(cs *Configstore) ([]*types.User, error) { +func getUsers(ctx context.Context, cs *Configstore) ([]*types.User, error) { var users []*types.User - err := cs.readDB.Do(func(tx *db.Tx) error { + err := cs.readDB.Do(ctx, func(tx *db.Tx) error { var err error users, err = cs.readDB.GetUsers(tx, "", 0, true) return err @@ -242,12 +242,12 @@ func TestResync(t *testing.T) { time.Sleep(5 * time.Second) - users1, err := getUsers(cs1) + users1, err := getUsers(ctx, cs1) if err != nil { t.Fatalf("err: %v", err) } - users2, err := getUsers(cs2) + users2, err := getUsers(ctx, cs2) if err != nil { t.Fatalf("err: %v", err) } @@ -275,12 +275,12 @@ func TestResync(t *testing.T) { time.Sleep(5 * time.Second) - users1, err = getUsers(cs1) + users1, err = getUsers(ctx, cs1) if err != nil { t.Fatalf("err: %v", err) } - users3, err := getUsers(cs3) + users3, err := getUsers(ctx, cs3) if err != nil { t.Fatalf("err: %v", err) } @@ -349,7 +349,7 @@ func TestUser(t *testing.T) { } }) t.Run("concurrent user with same name creation", func(t *testing.T) { - prevUsers, err := getUsers(cs) + prevUsers, err := getUsers(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -364,7 +364,7 @@ func TestUser(t *testing.T) { time.Sleep(5 * time.Second) - users, err := getUsers(cs) + users, err := getUsers(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -494,7 +494,7 @@ func TestProjectGroupsAndProjects(t *testing.T) { }) t.Run("concurrent project with same name creation", func(t *testing.T) { - prevProjects, err := getProjects(cs) + prevProjects, err := getProjects(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -511,7 +511,7 @@ func TestProjectGroupsAndProjects(t *testing.T) { time.Sleep(1 * time.Second) - projects, err := getProjects(cs) + projects, err := getProjects(ctx, cs) if err != nil { t.Fatalf("unexpected err: %v", err) } diff --git a/internal/services/configstore/readdb/readdb.go b/internal/services/configstore/readdb/readdb.go index ad29925..3522de0 100644 --- a/internal/services/configstore/readdb/readdb.go +++ b/internal/services/configstore/readdb/readdb.go @@ -74,7 +74,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd. } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return nil, err } @@ -112,7 +112,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error { return nil } -func (r *ReadDB) ResetDB() error { +func (r *ReadDB) ResetDB(ctx context.Context) error { // TODO(sgotti) this needs to be protected by a mutex r.rdb.Close() @@ -127,7 +127,7 @@ func (r *ReadDB) ResetDB() error { } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return err } @@ -136,7 +136,7 @@ func (r *ReadDB) ResetDB() error { return nil } -func (r *ReadDB) SyncFromDump() (string, error) { +func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) { dumpIndex, err := r.dm.GetLastDataStatus() if err != nil && err != ostypes.ErrNotExist { return "", err @@ -168,7 +168,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } dumpf.Close() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, de := range dumpEntries { action := &datamanager.Action{ ActionType: datamanager.ActionTypePut, @@ -188,7 +188,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil { return err } @@ -201,9 +201,9 @@ func (r *ReadDB) SyncFromDump() (string, error) { return dumpIndex.WalSequence, nil } -func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { +func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) { insertfunc := func(walFiles []*datamanager.WalFile) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { for _, walFile := range walFiles { walFilef, err := r.dm.ReadWal(walFile.WalSequence) if err != nil { @@ -263,7 +263,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncRDB(ctx context.Context) error { // get the last committed storage wal sequence saved in the rdb curWalSeq := "" - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error curWalSeq, err = r.GetCommittedWalSequence(tx) if err != nil { @@ -317,12 +317,12 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { if doFullSync { r.log.Infof("doing a full sync from dump") - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return err } var err error - curWalSeq, err = r.SyncFromDump() + curWalSeq, err = r.SyncFromDump(ctx) if err != nil { return err } @@ -335,7 +335,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { // etcd since wals are first committed to objectstorage and then in etcd we // would like to avoid to store in rdb something that is not yet marked as // committedstorage in etcd - curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) + curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal) if err != nil { return errors.Errorf("failed to sync from wals: %w", err) } @@ -358,7 +358,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } r.log.Infof("syncing from wals") - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertRevision(tx, revision); err != nil { return err } @@ -406,7 +406,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } func (r *ReadDB) Run(ctx context.Context) error { - revision, err := r.GetRevision() + revision, err := r.GetRevision(ctx) if err != nil { return err } @@ -475,7 +475,7 @@ func (r *ReadDB) Run(ctx context.Context) error { // lost/reset func (r *ReadDB) handleEvents(ctx context.Context) error { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) if err != nil { if err == sql.ErrNoRows { @@ -508,7 +508,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { // if theres a wal seq epoch change something happened to etcd, usually (if // the user hasn't messed up with etcd keys) this means etcd has been reset @@ -704,8 +704,8 @@ func (r *ReadDB) applyAction(tx *db.Tx, action *datamanager.Action) error { return nil } -func (r *ReadDB) Do(f func(tx *db.Tx) error) error { - return r.rdb.Do(f) +func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error { + return r.rdb.Do(ctx, f) } func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { @@ -724,10 +724,10 @@ func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { return nil } -func (r *ReadDB) GetRevision() (int64, error) { +func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) return err diff --git a/internal/services/configstore/readdb/remotesource.go b/internal/services/configstore/readdb/remotesource.go index 1b8b2e0..e210a43 100644 --- a/internal/services/configstore/readdb/remotesource.go +++ b/internal/services/configstore/readdb/remotesource.go @@ -15,6 +15,7 @@ package readdb import ( + "context" "database/sql" "encoding/json" @@ -139,7 +140,7 @@ func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc return s } -func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { +func (r *ReadDB) GetRemoteSources(ctx context.Context, startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { var remoteSources []*types.RemoteSource s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc) @@ -149,7 +150,7 @@ func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc b return nil, errors.Errorf("failed to build query: %w", err) } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { rows, err := tx.Query(q, args...) if err != nil { return err diff --git a/internal/services/gateway/gateway.go b/internal/services/gateway/gateway.go index dbe6371..ddee92a 100644 --- a/internal/services/gateway/gateway.go +++ b/internal/services/gateway/gateway.go @@ -142,12 +142,12 @@ func (g *Gateway) Run(ctx context.Context) error { return h } -if len(g.c.Web.AllowedOrigins) > 0 { - corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) - corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) - corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins) - corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) -} + if len(g.c.Web.AllowedOrigins) > 0 { + corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) + corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) + corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins) + corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) + } webhooksHandler := api.NewWebhooksHandler(logger, g.ah, g.configstoreClient, g.runserviceClient, g.c.APIExposedURL) diff --git a/internal/services/runservice/action/action.go b/internal/services/runservice/action/action.go index 2388d4b..afba39b 100644 --- a/internal/services/runservice/action/action.go +++ b/internal/services/runservice/action/action.go @@ -381,7 +381,7 @@ func (h *ActionHandler) saveRun(ctx context.Context, rb *types.RunBundle, runcgt run := rb.Run rc := rb.Rc - c, cgt, err := h.getRunCounter(run.Group) + c, cgt, err := h.getRunCounter(ctx, run.Group) h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt)) if err != nil { return err @@ -570,7 +570,7 @@ func (h *ActionHandler) DeleteExecutor(ctx context.Context, executorID string) e return nil } -func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { +func (h *ActionHandler) getRunCounter(ctx context.Context, group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { // use the first group dir after the root pl := util.PathList(group) if len(pl) < 2 { @@ -579,7 +579,7 @@ func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.Change var c uint64 var cgt *datamanager.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error c, err = h.readDB.GetRunCounterOST(tx, pl[1]) if err != nil { diff --git a/internal/services/runservice/api/api.go b/internal/services/runservice/api/api.go index 75667ea..beffd66 100644 --- a/internal/services/runservice/api/api.go +++ b/internal/services/runservice/api/api.go @@ -323,12 +323,13 @@ func NewChangeGroupsUpdateTokensHandler(logger *zap.Logger, readDB *readdb.ReadD } func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() groups := query["changegroup"] var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups) return err @@ -372,6 +373,7 @@ func NewRunHandler(logger *zap.Logger, e *etcd.Store, dm *datamanager.DataManage } func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() vars := mux.Vars(r) runID := vars["runid"] @@ -381,7 +383,7 @@ func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var run *types.Run var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error run, err = h.readDB.GetRun(tx, runID) if err != nil { @@ -447,6 +449,7 @@ func NewRunsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RunsHandler { } func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() query := r.URL.Query() phaseFilter := types.RunPhaseFromStringSlice(query["phase"]) resultFilter := types.RunResultFromStringSlice(query["result"]) @@ -482,7 +485,7 @@ func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var runs []*types.Run var cgt *types.ChangeGroupsUpdateToken - err := h.readDB.Do(func(tx *db.Tx) error { + err := h.readDB.Do(ctx, func(tx *db.Tx) error { var err error runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder) if err != nil { diff --git a/internal/services/runservice/readdb/readdb.go b/internal/services/runservice/readdb/readdb.go index 405410c..e82ef63 100644 --- a/internal/services/runservice/readdb/readdb.go +++ b/internal/services/runservice/readdb/readdb.go @@ -119,7 +119,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd. } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return nil, err } @@ -150,7 +150,7 @@ func (r *ReadDB) IsInitialized() bool { // Initialize populates the readdb with the current etcd data and save the // revision to then feed it with the etcd events func (r *ReadDB) Initialize(ctx context.Context) error { - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return errors.Errorf("failed to reset db: %w", err) } if err := r.SyncObjectStorage(ctx); err != nil { @@ -162,7 +162,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error { return nil } -func (r *ReadDB) ResetDB() error { +func (r *ReadDB) ResetDB(ctx context.Context) error { // TODO(sgotti) this needs to be protected by a mutex r.rdb.Close() @@ -177,7 +177,7 @@ func (r *ReadDB) ResetDB() error { } // populate readdb - if err := rdb.Create(Stmts); err != nil { + if err := rdb.Create(ctx, Stmts); err != nil { return err } @@ -187,7 +187,7 @@ func (r *ReadDB) ResetDB() error { } func (r *ReadDB) SyncRDB(ctx context.Context) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { // Do pagination to limit the number of keys per request var revision int64 key := common.EtcdRunsDir @@ -256,7 +256,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error { } func (r *ReadDB) Run(ctx context.Context) error { - revision, err := r.GetRevision() + revision, err := r.GetRevision(ctx) if err != nil { return err } @@ -333,7 +333,7 @@ func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error { var revision int64 var lastRuns []*RunData - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) if err != nil { @@ -390,7 +390,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). r.dbWriteLock.Lock() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, ev := range wresp.Events { if err := r.handleEvent(tx, ev, &wresp); err != nil { return err @@ -503,7 +503,7 @@ func (r *ReadDB) handleChangeGroupEvent(tx *db.Tx, ev *etcdclientv3.Event, wresp func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { // get the last committed storage wal sequence saved in the rdb curWalSeq := "" - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error curWalSeq, err = r.GetCommittedWalSequenceOST(tx) if err != nil { @@ -557,12 +557,12 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { if doFullSync { r.log.Infof("doing a full sync from dump") - if err := r.ResetDB(); err != nil { + if err := r.ResetDB(ctx); err != nil { return err } var err error - curWalSeq, err = r.SyncFromDump() + curWalSeq, err = r.SyncFromDump(ctx) if err != nil { return err } @@ -575,7 +575,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { // since wals are first committed to objectstorage and then in etcd we would like to // avoid to store in rdb something that is not yet marked as committedstorage // in etcd - curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) + curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal) if err != nil { return errors.Errorf("failed to sync from wals: %w", err) } @@ -598,7 +598,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { } r.log.Infof("syncing from wals") - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := insertRevisionOST(tx, revision); err != nil { return err } @@ -645,7 +645,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { return err } -func (r *ReadDB) SyncFromDump() (string, error) { +func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) { dumpIndex, err := r.dm.GetLastDataStatus() if err != nil && err != ostypes.ErrNotExist { return "", err @@ -677,7 +677,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } dumpf.Close() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { for _, de := range dumpEntries { action := &datamanager.Action{ ActionType: datamanager.ActionTypePut, @@ -697,7 +697,7 @@ func (r *ReadDB) SyncFromDump() (string, error) { } } - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil { return err } @@ -710,9 +710,9 @@ func (r *ReadDB) SyncFromDump() (string, error) { return dumpIndex.WalSequence, nil } -func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { +func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) { insertfunc := func(walFiles []*datamanager.WalFile) error { - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { for _, walFile := range walFiles { walFilef, err := r.dm.ReadWal(walFile.WalSequence) if err != nil { @@ -771,7 +771,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) handleEventsOST(ctx context.Context) error { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) if err != nil { if err == sql.ErrNoRows { @@ -805,7 +805,7 @@ func (r *ReadDB) handleEventsOST(ctx context.Context) error { // a single transaction for every response (every response contains all the // events happened in an etcd revision). r.dbWriteLock.Lock() - err = r.rdb.Do(func(tx *db.Tx) error { + err = r.rdb.Do(ctx, func(tx *db.Tx) error { // if theres a wal seq epoch change something happened to etcd, usually (if // the user hasn't messed up with etcd keys) this means etcd has been reset @@ -949,11 +949,11 @@ func (r *ReadDB) handleWalEvent(tx *db.Tx, we *datamanager.WatchElement) error { return nil } -func (r *ReadDB) Do(f func(tx *db.Tx) error) error { +func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error { if !r.IsInitialized() { return errors.Errorf("db not initialized") } - return r.rdb.Do(f) + return r.rdb.Do(ctx, f) } func insertRevision(tx *db.Tx, revision int64) error { @@ -1073,10 +1073,10 @@ func insertChangeGroupRevision(tx *db.Tx, changegroupID string, revision int64) return nil } -func (r *ReadDB) GetRevision() (int64, error) { +func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) { var revision int64 - err := r.rdb.Do(func(tx *db.Tx) error { + err := r.rdb.Do(ctx, func(tx *db.Tx) error { var err error revision, err = r.getRevision(tx) return err