db: use context functions

Use the go sql context functions (ExecContext, QueryContext etc...)

The context is saved inside Tx so the library users should only pass it one time
to the db.Do function.
This commit is contained in:
Simone Gotti 2019-07-25 10:46:02 +02:00
parent 5c87d89c72
commit b8c2b4020a
23 changed files with 170 additions and 152 deletions

View File

@ -15,6 +15,7 @@
package db
import (
"context"
"database/sql"
sq "github.com/Masterminds/squirrel"
@ -27,10 +28,10 @@ const dbVersionTableDDLTmpl = `
const dbVersion = 1
func (db *DB) Create(stmts []string) error {
func (db *DB) Create(ctx context.Context, stmts []string) error {
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
err := db.Do(func(tx *Tx) error {
err := db.Do(ctx, func(tx *Tx) error {
if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil {
return errors.Errorf("failed to create dbversion table: %w", err)
}
@ -40,7 +41,7 @@ func (db *DB) Create(stmts []string) error {
return err
}
err = db.Do(func(tx *Tx) error {
err = db.Do(ctx, func(tx *Tx) error {
var version sql.NullInt64
q, args, err := sb.Select("max(version)").From("dbversion").ToSql()
if err != nil {

View File

@ -147,14 +147,15 @@ func NewDB(dbType Type, dbConnString string) (*DB, error) {
type Tx struct {
db *DB
tx *sql.Tx
ctx context.Context
}
func (db *DB) Close() error {
return db.db.Close()
}
func (db *DB) Conn() (*sql.Conn, error) {
return db.db.Conn(context.TODO())
func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) {
return db.db.Conn(ctx)
}
func (db *DB) NewUnstartedTx() *Tx {
@ -163,19 +164,19 @@ func (db *DB) NewUnstartedTx() *Tx {
}
}
func (db *DB) NewTx() (*Tx, error) {
func (db *DB) NewTx(ctx context.Context) (*Tx, error) {
tx := db.NewUnstartedTx()
if err := tx.Start(); err != nil {
if err := tx.Start(ctx); err != nil {
return nil, err
}
return tx, nil
}
func (db *DB) Do(f func(tx *Tx) error) error {
func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error {
retries := 0
for {
err := db.do(f)
err := db.do(ctx, f)
if err != nil {
var sqerr sqlite3.Error
if errors.As(err, &sqerr) {
@ -191,8 +192,8 @@ func (db *DB) Do(f func(tx *Tx) error) error {
}
}
func (db *DB) do(f func(tx *Tx) error) error {
tx, err := db.NewTx()
func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error {
tx, err := db.NewTx(ctx)
if err != nil {
return err
}
@ -209,7 +210,7 @@ func (db *DB) do(f func(tx *Tx) error) error {
return tx.Commit()
}
func (tx *Tx) Start() error {
func (tx *Tx) Start(ctx context.Context) error {
wtx, err := tx.db.db.Begin()
if err != nil {
return err
@ -221,6 +222,7 @@ func (tx *Tx) Start() error {
}
}
tx.tx = wtx
tx.ctx = ctx
return nil
}
@ -240,19 +242,19 @@ func (tx *Tx) Rollback() error {
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
query = tx.db.data.translate(query)
r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...)
r, err := tx.tx.ExecContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err
}
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = tx.db.data.translate(query)
r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...)
r, err := tx.tx.QueryContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err
}
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
query = tx.db.data.translate(query)
return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...)
return tx.tx.QueryRowContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
}
func (tx *Tx) CurTime() (time.Time, error) {

View File

@ -44,7 +44,7 @@ func orgMemberResponse(orgUser *readdb.OrgUser) *OrgMemberResponse {
func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) {
var orgUsers []*readdb.OrgUser
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
org, err := h.readDB.GetOrg(tx, orgRef)
if err != nil {
@ -85,7 +85,7 @@ func (h *ActionHandler) CreateOrg(ctx context.Context, org *types.Organization)
cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)}
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil {
@ -182,7 +182,7 @@ func (h *ActionHandler) DeleteOrg(ctx context.Context, orgRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check org existance
org, err = h.readDB.GetOrgByName(tx, orgRef)
@ -232,7 +232,7 @@ func (h *ActionHandler) AddOrgMember(ctx context.Context, orgRef, userRef string
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check existing org
org, err = h.readDB.GetOrg(tx, orgRef)
@ -308,7 +308,7 @@ func (h *ActionHandler) RemoveOrgMember(ctx context.Context, orgRef, userRef str
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check existing org
org, err = h.readDB.GetOrg(tx, orgRef)

View File

@ -72,7 +72,7 @@ func (h *ActionHandler) CreateProject(ctx context.Context, project *types.Projec
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID)
if err != nil {
@ -167,7 +167,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check project exists
p, err := h.readDB.GetProject(tx, req.ProjectRef)
@ -269,7 +269,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check project existance

View File

@ -30,7 +30,7 @@ import (
func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) {
var projectGroups []*types.ProjectGroup
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil {
@ -53,7 +53,7 @@ func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGro
func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) {
var projects []*types.Project
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil {
@ -112,7 +112,7 @@ func (h *ActionHandler) CreateProjectGroup(ctx context.Context, projectGroup *ty
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID)
if err != nil {
return err
@ -184,7 +184,7 @@ func (h *ActionHandler) UpdateProjectGroup(ctx context.Context, req *UpdateProje
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check project exists
pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef)
@ -280,7 +280,7 @@ func (h *ActionHandler) DeleteProjectGroup(ctx context.Context, projectGroupRef
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check project group existance

View File

@ -74,7 +74,7 @@ func (h *ActionHandler) CreateRemoteSource(ctx context.Context, remoteSource *ty
cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)}
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil {
@ -129,7 +129,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check remotesource exists
@ -190,7 +190,7 @@ func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check remoteSource existance

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) {
var secret *types.Secret
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
secret, err = h.readDB.GetSecretByID(tx, secretID)
return err
@ -47,7 +47,7 @@ func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.
func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) {
var secrets []*types.Secret
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil {
return err
@ -105,7 +105,7 @@ func (h *ActionHandler) CreateSecret(ctx context.Context, secret *types.Secret)
cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)}
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil {
@ -168,7 +168,7 @@ func (h *ActionHandler) UpdateSecret(ctx context.Context, req *UpdateSecretReque
// changegroup is the secret name
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID)
@ -238,7 +238,7 @@ func (h *ActionHandler) DeleteSecret(ctx context.Context, parentType types.Confi
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil {

View File

@ -50,7 +50,7 @@ func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest)
var rs *types.RemoteSource
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil {
@ -156,7 +156,7 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
// check user existance
@ -206,7 +206,7 @@ func (h *ActionHandler) UpdateUser(ctx context.Context, req *UpdateUserRequest)
var user *types.User
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil {
@ -289,7 +289,7 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil {
@ -374,7 +374,7 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string)
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, userRef)
if err != nil {
@ -444,7 +444,7 @@ func (h *ActionHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil {
@ -518,7 +518,7 @@ func (h *ActionHandler) CreateUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, userRef)
if err != nil {
@ -583,7 +583,7 @@ func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, userRef)
if err != nil {
@ -644,7 +644,7 @@ func userOrgsResponse(userOrg *readdb.UserOrg) *UserOrgsResponse {
func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) {
var userOrgs []*readdb.UserOrg
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err := h.readDB.GetUser(tx, userRef)
if err != nil {

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) {
var variables []*types.Variable
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil {
return err
@ -81,7 +81,7 @@ func (h *ActionHandler) CreateVariable(ctx context.Context, variable *types.Vari
cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)}
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil {
@ -144,7 +144,7 @@ func (h *ActionHandler) UpdateVariable(ctx context.Context, req *UpdateVariableR
// changegroup is the variable name
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID)
@ -214,7 +214,7 @@ func (h *ActionHandler) DeleteVariable(ctx context.Context, parentType types.Con
var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil {

View File

@ -40,11 +40,12 @@ func NewOrgHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgHandler {
}
func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
orgRef := vars["orgref"]
var org *types.Organization
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
org, err = h.readDB.GetOrg(tx, orgRef)
return err
@ -105,7 +106,6 @@ func NewDeleteOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *DeleteOr
}
func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.log.Infof("deleteorghandler")
ctx := r.Context()
vars := mux.Vars(r)
@ -136,6 +136,7 @@ func NewOrgsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgsHandler {
}
func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
limitS := query.Get("limit")
@ -163,7 +164,7 @@ func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := query.Get("start")
var orgs []*types.Organization
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
orgs, err = h.readDB.GetOrgs(tx, start, limit, asc)
return err

View File

@ -15,6 +15,7 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/url"
@ -43,18 +44,18 @@ type Project struct {
GlobalVisibility types.Visibility
}
func projectResponse(readDB *readdb.ReadDB, project *types.Project) (*Project, error) {
r, err := projectsResponse(readDB, []*types.Project{project})
func projectResponse(ctx context.Context, readDB *readdb.ReadDB, project *types.Project) (*Project, error) {
r, err := projectsResponse(ctx, readDB, []*types.Project{project})
if err != nil {
return nil, err
}
return r[0], nil
}
func projectsResponse(readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) {
func projectsResponse(ctx context.Context, readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) {
resProjects := make([]*Project, len(projects))
err := readDB.Do(func(tx *db.Tx) error {
err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, project := range projects {
pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID)
if err != nil {
@ -132,6 +133,7 @@ func NewProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectHandle
}
func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
projectRef, err := url.PathUnescape(vars["projectref"])
if err != nil {
@ -140,7 +142,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
var project *types.Project
err = h.readDB.Do(func(tx *db.Tx) error {
err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
project, err = h.readDB.GetProject(tx, projectRef)
return err
@ -156,7 +158,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
resProject, err := projectResponse(h.readDB, project)
resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -193,7 +195,7 @@ func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
resProject, err := projectResponse(h.readDB, project)
resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -241,7 +243,7 @@ func (h *UpdateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
resProject, err := projectResponse(h.readDB, project)
resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return

View File

@ -15,6 +15,7 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/url"
@ -43,18 +44,18 @@ type ProjectGroup struct {
GlobalVisibility types.Visibility
}
func projectGroupResponse(readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) {
r, err := projectGroupsResponse(readDB, []*types.ProjectGroup{projectGroup})
func projectGroupResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) {
r, err := projectGroupsResponse(ctx, readDB, []*types.ProjectGroup{projectGroup})
if err != nil {
return nil, err
}
return r[0], nil
}
func projectGroupsResponse(readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) {
func projectGroupsResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) {
resProjectGroups := make([]*ProjectGroup, len(projectGroups))
err := readDB.Do(func(tx *db.Tx) error {
err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, projectGroup := range projectGroups {
pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID)
if err != nil {
@ -100,7 +101,9 @@ func NewProjectGroupHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectG
}
func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil {
httpError(w, util.NewErrBadRequest(err))
@ -108,7 +111,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
}
var projectGroup *types.ProjectGroup
err = h.readDB.Do(func(tx *db.Tx) error {
err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef)
return err
@ -124,7 +127,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup)
resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -148,6 +151,7 @@ func NewProjectGroupProjectsHandler(logger *zap.Logger, ah *action.ActionHandler
func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil {
httpError(w, util.NewErrBadRequest(err))
@ -160,7 +164,7 @@ func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.R
return
}
resProjects, err := projectsResponse(h.readDB, projects)
resProjects, err := projectsResponse(ctx, h.readDB, projects)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -196,7 +200,7 @@ func (h *ProjectGroupSubgroupsHandler) ServeHTTP(w http.ResponseWriter, r *http.
return
}
resProjectGroups, err := projectGroupsResponse(h.readDB, projectGroups)
resProjectGroups, err := projectGroupsResponse(ctx, h.readDB, projectGroups)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -233,7 +237,7 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return
}
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup)
resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return
@ -281,7 +285,7 @@ func (h *UpdateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return
}
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup)
resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) {
h.log.Errorf("err: %+v", err)
return

View File

@ -40,11 +40,12 @@ func NewRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSo
}
func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
rsRef := vars["remotesourceref"]
var remoteSource *types.RemoteSource
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef)
return err
@ -171,6 +172,7 @@ func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteS
}
func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
limitS := query.Get("limit")
@ -197,7 +199,7 @@ func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
start := query.Get("start")
remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc)
remoteSources, err := h.readDB.GetRemoteSources(ctx, start, limit, asc)
if err != nil {
h.log.Errorf("err: %+v", err)
httpError(w, err)

View File

@ -94,7 +94,7 @@ func (h *SecretsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resSecrets[i] = &Secret{Secret: s}
}
err = h.readDB.Do(func(tx *db.Tx) error {
err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path
for _, s := range resSecrets {
pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID)

View File

@ -41,11 +41,12 @@ func NewUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UserHandler {
}
func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
userRef := vars["userref"]
var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUser(tx, userRef)
return err
@ -198,6 +199,7 @@ func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler {
}
func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
limitS := query.Get("limit")
@ -232,7 +234,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bytoken":
token := query.Get("token")
var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUserByTokenValue(tx, token)
return err
@ -250,7 +252,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bylinkedaccount":
linkedAccountID := query.Get("linkedaccountid")
var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID)
return err
@ -269,7 +271,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
remoteUserID := query.Get("remoteuserid")
remoteSourceID := query.Get("remotesourceid")
var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID)
return err
@ -286,7 +288,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
users = []*types.User{user}
default:
// default query
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
users, err = h.readDB.GetUsers(tx, start, limit, asc)
return err

View File

@ -67,7 +67,7 @@ func (h *VariablesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for i, v := range variables {
resVariables[i] = &Variable{Variable: v}
}
err = h.readDB.Do(func(tx *db.Tx) error {
err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path
for _, v := range resVariables {
pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID)

View File

@ -99,9 +99,9 @@ func setupConfigstore(t *testing.T, ctx context.Context, dir string) (*Configsto
return cs, tetcd
}
func getProjects(cs *Configstore) ([]*types.Project, error) {
func getProjects(ctx context.Context, cs *Configstore) ([]*types.Project, error) {
var projects []*types.Project
err := cs.readDB.Do(func(tx *db.Tx) error {
err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
projects, err = cs.readDB.GetAllProjects(tx)
return err
@ -109,9 +109,9 @@ func getProjects(cs *Configstore) ([]*types.Project, error) {
return projects, err
}
func getUsers(cs *Configstore) ([]*types.User, error) {
func getUsers(ctx context.Context, cs *Configstore) ([]*types.User, error) {
var users []*types.User
err := cs.readDB.Do(func(tx *db.Tx) error {
err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
users, err = cs.readDB.GetUsers(tx, "", 0, true)
return err
@ -242,12 +242,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second)
users1, err := getUsers(cs1)
users1, err := getUsers(ctx, cs1)
if err != nil {
t.Fatalf("err: %v", err)
}
users2, err := getUsers(cs2)
users2, err := getUsers(ctx, cs2)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -275,12 +275,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second)
users1, err = getUsers(cs1)
users1, err = getUsers(ctx, cs1)
if err != nil {
t.Fatalf("err: %v", err)
}
users3, err := getUsers(cs3)
users3, err := getUsers(ctx, cs3)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -349,7 +349,7 @@ func TestUser(t *testing.T) {
}
})
t.Run("concurrent user with same name creation", func(t *testing.T) {
prevUsers, err := getUsers(cs)
prevUsers, err := getUsers(ctx, cs)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
@ -364,7 +364,7 @@ func TestUser(t *testing.T) {
time.Sleep(5 * time.Second)
users, err := getUsers(cs)
users, err := getUsers(ctx, cs)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
@ -494,7 +494,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
})
t.Run("concurrent project with same name creation", func(t *testing.T) {
prevProjects, err := getProjects(cs)
prevProjects, err := getProjects(ctx, cs)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
@ -511,7 +511,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
time.Sleep(1 * time.Second)
projects, err := getProjects(cs)
projects, err := getProjects(ctx, cs)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}

View File

@ -74,7 +74,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
}
// populate readdb
if err := rdb.Create(Stmts); err != nil {
if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err
}
@ -112,7 +112,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil
}
func (r *ReadDB) ResetDB() error {
func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close()
@ -127,7 +127,7 @@ func (r *ReadDB) ResetDB() error {
}
// populate readdb
if err := rdb.Create(Stmts); err != nil {
if err := rdb.Create(ctx, Stmts); err != nil {
return err
}
@ -136,7 +136,7 @@ func (r *ReadDB) ResetDB() error {
return nil
}
func (r *ReadDB) SyncFromDump() (string, error) {
func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist {
return "", err
@ -168,7 +168,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
}
dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries {
action := &datamanager.Action{
ActionType: datamanager.ActionTypePut,
@ -188,7 +188,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
}
}
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil {
return err
}
@ -201,9 +201,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil
}
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil {
@ -263,7 +263,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) SyncRDB(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb
curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error
curWalSeq, err = r.GetCommittedWalSequence(tx)
if err != nil {
@ -317,12 +317,12 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
if doFullSync {
r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil {
if err := r.ResetDB(ctx); err != nil {
return err
}
var err error
curWalSeq, err = r.SyncFromDump()
curWalSeq, err = r.SyncFromDump(ctx)
if err != nil {
return err
}
@ -335,7 +335,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
// etcd since wals are first committed to objectstorage and then in etcd we
// would like to avoid to store in rdb something that is not yet marked as
// committedstorage in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal)
curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil {
return errors.Errorf("failed to sync from wals: %w", err)
}
@ -358,7 +358,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
}
r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertRevision(tx, revision); err != nil {
return err
}
@ -406,7 +406,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
}
func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision()
revision, err := r.GetRevision(ctx)
if err != nil {
return err
}
@ -475,7 +475,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
// lost/reset
func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil {
if err == sql.ErrNoRows {
@ -508,7 +508,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the
// events happened in an etcd revision).
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset
@ -704,8 +704,8 @@ func (r *ReadDB) applyAction(tx *db.Tx, action *datamanager.Action) error {
return nil
}
func (r *ReadDB) Do(f func(tx *db.Tx) error) error {
return r.rdb.Do(f)
func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
return r.rdb.Do(ctx, f)
}
func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
@ -724,10 +724,10 @@ func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
return nil
}
func (r *ReadDB) GetRevision() (int64, error) {
func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error
revision, err = r.getRevision(tx)
return err

View File

@ -15,6 +15,7 @@
package readdb
import (
"context"
"database/sql"
"encoding/json"
@ -139,7 +140,7 @@ func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc
return s
}
func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) {
func (r *ReadDB) GetRemoteSources(ctx context.Context, startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) {
var remoteSources []*types.RemoteSource
s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc)
@ -149,7 +150,7 @@ func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc b
return nil, errors.Errorf("failed to build query: %w", err)
}
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
rows, err := tx.Query(q, args...)
if err != nil {
return err

View File

@ -381,7 +381,7 @@ func (h *ActionHandler) saveRun(ctx context.Context, rb *types.RunBundle, runcgt
run := rb.Run
rc := rb.Rc
c, cgt, err := h.getRunCounter(run.Group)
c, cgt, err := h.getRunCounter(ctx, run.Group)
h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt))
if err != nil {
return err
@ -570,7 +570,7 @@ func (h *ActionHandler) DeleteExecutor(ctx context.Context, executorID string) e
return nil
}
func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) {
func (h *ActionHandler) getRunCounter(ctx context.Context, group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) {
// use the first group dir after the root
pl := util.PathList(group)
if len(pl) < 2 {
@ -579,7 +579,7 @@ func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.Change
var c uint64
var cgt *datamanager.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
c, err = h.readDB.GetRunCounterOST(tx, pl[1])
if err != nil {

View File

@ -323,12 +323,13 @@ func NewChangeGroupsUpdateTokensHandler(logger *zap.Logger, readDB *readdb.ReadD
}
func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
groups := query["changegroup"]
var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups)
return err
@ -372,6 +373,7 @@ func NewRunHandler(logger *zap.Logger, e *etcd.Store, dm *datamanager.DataManage
}
func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r)
runID := vars["runid"]
@ -381,7 +383,7 @@ func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var run *types.Run
var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
run, err = h.readDB.GetRun(tx, runID)
if err != nil {
@ -447,6 +449,7 @@ func NewRunsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RunsHandler {
}
func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
phaseFilter := types.RunPhaseFromStringSlice(query["phase"])
resultFilter := types.RunResultFromStringSlice(query["result"])
@ -482,7 +485,7 @@ func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var runs []*types.Run
var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error {
err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error
runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder)
if err != nil {

View File

@ -119,7 +119,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
}
// populate readdb
if err := rdb.Create(Stmts); err != nil {
if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err
}
@ -150,7 +150,7 @@ func (r *ReadDB) IsInitialized() bool {
// Initialize populates the readdb with the current etcd data and save the
// revision to then feed it with the etcd events
func (r *ReadDB) Initialize(ctx context.Context) error {
if err := r.ResetDB(); err != nil {
if err := r.ResetDB(ctx); err != nil {
return errors.Errorf("failed to reset db: %w", err)
}
if err := r.SyncObjectStorage(ctx); err != nil {
@ -162,7 +162,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil
}
func (r *ReadDB) ResetDB() error {
func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close()
@ -177,7 +177,7 @@ func (r *ReadDB) ResetDB() error {
}
// populate readdb
if err := rdb.Create(Stmts); err != nil {
if err := rdb.Create(ctx, Stmts); err != nil {
return err
}
@ -187,7 +187,7 @@ func (r *ReadDB) ResetDB() error {
}
func (r *ReadDB) SyncRDB(ctx context.Context) error {
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
// Do pagination to limit the number of keys per request
var revision int64
key := common.EtcdRunsDir
@ -256,7 +256,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
}
func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision()
revision, err := r.GetRevision(ctx)
if err != nil {
return err
}
@ -333,7 +333,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64
var lastRuns []*RunData
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error
revision, err = r.getRevision(tx)
if err != nil {
@ -390,7 +390,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the
// events happened in an etcd revision).
r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, ev := range wresp.Events {
if err := r.handleEvent(tx, ev, &wresp); err != nil {
return err
@ -503,7 +503,7 @@ func (r *ReadDB) handleChangeGroupEvent(tx *db.Tx, ev *etcdclientv3.Event, wresp
func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb
curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error
curWalSeq, err = r.GetCommittedWalSequenceOST(tx)
if err != nil {
@ -557,12 +557,12 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
if doFullSync {
r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil {
if err := r.ResetDB(ctx); err != nil {
return err
}
var err error
curWalSeq, err = r.SyncFromDump()
curWalSeq, err = r.SyncFromDump(ctx)
if err != nil {
return err
}
@ -575,7 +575,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// since wals are first committed to objectstorage and then in etcd we would like to
// avoid to store in rdb something that is not yet marked as committedstorage
// in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal)
curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil {
return errors.Errorf("failed to sync from wals: %w", err)
}
@ -598,7 +598,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
}
r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := insertRevisionOST(tx, revision); err != nil {
return err
}
@ -645,7 +645,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
return err
}
func (r *ReadDB) SyncFromDump() (string, error) {
func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist {
return "", err
@ -677,7 +677,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
}
dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries {
action := &datamanager.Action{
ActionType: datamanager.ActionTypePut,
@ -697,7 +697,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
}
}
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil {
return err
}
@ -710,9 +710,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil
}
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil {
@ -771,7 +771,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) handleEventsOST(ctx context.Context) error {
var revision int64
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil {
if err == sql.ErrNoRows {
@ -805,7 +805,7 @@ func (r *ReadDB) handleEventsOST(ctx context.Context) error {
// a single transaction for every response (every response contains all the
// events happened in an etcd revision).
r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error {
err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset
@ -949,11 +949,11 @@ func (r *ReadDB) handleWalEvent(tx *db.Tx, we *datamanager.WatchElement) error {
return nil
}
func (r *ReadDB) Do(f func(tx *db.Tx) error) error {
func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
if !r.IsInitialized() {
return errors.Errorf("db not initialized")
}
return r.rdb.Do(f)
return r.rdb.Do(ctx, f)
}
func insertRevision(tx *db.Tx, revision int64) error {
@ -1073,10 +1073,10 @@ func insertChangeGroupRevision(tx *db.Tx, changegroupID string, revision int64)
return nil
}
func (r *ReadDB) GetRevision() (int64, error) {
func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64
err := r.rdb.Do(func(tx *db.Tx) error {
err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error
revision, err = r.getRevision(tx)
return err