db: use context functions

Use the go sql context functions (ExecContext, QueryContext etc...)

The context is saved inside Tx so the library users should only pass it one time
to the db.Do function.
This commit is contained in:
Simone Gotti 2019-07-25 10:46:02 +02:00
parent 5c87d89c72
commit b8c2b4020a
23 changed files with 170 additions and 152 deletions

View File

@ -15,6 +15,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
@ -27,10 +28,10 @@ const dbVersionTableDDLTmpl = `
const dbVersion = 1 const dbVersion = 1
func (db *DB) Create(stmts []string) error { func (db *DB) Create(ctx context.Context, stmts []string) error {
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
err := db.Do(func(tx *Tx) error { err := db.Do(ctx, func(tx *Tx) error {
if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil { if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil {
return errors.Errorf("failed to create dbversion table: %w", err) return errors.Errorf("failed to create dbversion table: %w", err)
} }
@ -40,7 +41,7 @@ func (db *DB) Create(stmts []string) error {
return err return err
} }
err = db.Do(func(tx *Tx) error { err = db.Do(ctx, func(tx *Tx) error {
var version sql.NullInt64 var version sql.NullInt64
q, args, err := sb.Select("max(version)").From("dbversion").ToSql() q, args, err := sb.Select("max(version)").From("dbversion").ToSql()
if err != nil { if err != nil {

View File

@ -145,16 +145,17 @@ func NewDB(dbType Type, dbConnString string) (*DB, error) {
// sql driver doesn't support concurrent statements on the same // sql driver doesn't support concurrent statements on the same
// connection/transaction) // connection/transaction)
type Tx struct { type Tx struct {
db *DB db *DB
tx *sql.Tx tx *sql.Tx
ctx context.Context
} }
func (db *DB) Close() error { func (db *DB) Close() error {
return db.db.Close() return db.db.Close()
} }
func (db *DB) Conn() (*sql.Conn, error) { func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) {
return db.db.Conn(context.TODO()) return db.db.Conn(ctx)
} }
func (db *DB) NewUnstartedTx() *Tx { func (db *DB) NewUnstartedTx() *Tx {
@ -163,19 +164,19 @@ func (db *DB) NewUnstartedTx() *Tx {
} }
} }
func (db *DB) NewTx() (*Tx, error) { func (db *DB) NewTx(ctx context.Context) (*Tx, error) {
tx := db.NewUnstartedTx() tx := db.NewUnstartedTx()
if err := tx.Start(); err != nil { if err := tx.Start(ctx); err != nil {
return nil, err return nil, err
} }
return tx, nil return tx, nil
} }
func (db *DB) Do(f func(tx *Tx) error) error { func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error {
retries := 0 retries := 0
for { for {
err := db.do(f) err := db.do(ctx, f)
if err != nil { if err != nil {
var sqerr sqlite3.Error var sqerr sqlite3.Error
if errors.As(err, &sqerr) { if errors.As(err, &sqerr) {
@ -191,8 +192,8 @@ func (db *DB) Do(f func(tx *Tx) error) error {
} }
} }
func (db *DB) do(f func(tx *Tx) error) error { func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error {
tx, err := db.NewTx() tx, err := db.NewTx(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -209,7 +210,7 @@ func (db *DB) do(f func(tx *Tx) error) error {
return tx.Commit() return tx.Commit()
} }
func (tx *Tx) Start() error { func (tx *Tx) Start(ctx context.Context) error {
wtx, err := tx.db.db.Begin() wtx, err := tx.db.db.Begin()
if err != nil { if err != nil {
return err return err
@ -221,6 +222,7 @@ func (tx *Tx) Start() error {
} }
} }
tx.tx = wtx tx.tx = wtx
tx.ctx = ctx
return nil return nil
} }
@ -240,19 +242,19 @@ func (tx *Tx) Rollback() error {
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...) r, err := tx.tx.ExecContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err return r, err
} }
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...) r, err := tx.tx.QueryContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err return r, err
} }
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...) return tx.tx.QueryRowContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
} }
func (tx *Tx) CurTime() (time.Time, error) { func (tx *Tx) CurTime() (time.Time, error) {

View File

@ -44,7 +44,7 @@ func orgMemberResponse(orgUser *readdb.OrgUser) *OrgMemberResponse {
func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) { func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) {
var orgUsers []*readdb.OrgUser var orgUsers []*readdb.OrgUser
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
org, err := h.readDB.GetOrg(tx, orgRef) org, err := h.readDB.GetOrg(tx, orgRef)
if err != nil { if err != nil {
@ -85,7 +85,7 @@ func (h *ActionHandler) CreateOrg(ctx context.Context, org *types.Organization)
cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)} cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func (h *ActionHandler) DeleteOrg(ctx context.Context, orgRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check org existance // check org existance
org, err = h.readDB.GetOrgByName(tx, orgRef) org, err = h.readDB.GetOrgByName(tx, orgRef)
@ -232,7 +232,7 @@ func (h *ActionHandler) AddOrgMember(ctx context.Context, orgRef, userRef string
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check existing org // check existing org
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)
@ -308,7 +308,7 @@ func (h *ActionHandler) RemoveOrgMember(ctx context.Context, orgRef, userRef str
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check existing org // check existing org
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)

View File

@ -72,7 +72,7 @@ func (h *ActionHandler) CreateProject(ctx context.Context, project *types.Projec
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID) group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID)
if err != nil { if err != nil {
@ -167,7 +167,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project exists // check project exists
p, err := h.readDB.GetProject(tx, req.ProjectRef) p, err := h.readDB.GetProject(tx, req.ProjectRef)
@ -210,7 +210,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq
return err return err
} }
if ap != nil { if ap != nil {
return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name,pp)) return util.NewErrBadRequest(errors.Errorf("project with name %q, path %q already exists", req.Project.Name, pp))
} }
} }
@ -269,7 +269,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project existance // check project existance

View File

@ -30,7 +30,7 @@ import (
func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) { func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) {
var projectGroups []*types.ProjectGroup var projectGroups []*types.ProjectGroup
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGro
func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) { func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) {
var projects []*types.Project var projects []*types.Project
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil { if err != nil {
@ -112,7 +112,7 @@ func (h *ActionHandler) CreateProjectGroup(ctx context.Context, projectGroup *ty
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID) parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID)
if err != nil { if err != nil {
return err return err
@ -184,7 +184,7 @@ func (h *ActionHandler) UpdateProjectGroup(ctx context.Context, req *UpdateProje
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project exists // check project exists
pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef) pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef)
@ -280,7 +280,7 @@ func (h *ActionHandler) DeleteProjectGroup(ctx context.Context, projectGroupRef
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project group existance // check project group existance

View File

@ -74,7 +74,7 @@ func (h *ActionHandler) CreateRemoteSource(ctx context.Context, remoteSource *ty
cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)} cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -129,7 +129,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check remotesource exists // check remotesource exists
@ -190,7 +190,7 @@ func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check remoteSource existance // check remoteSource existance

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) { func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) {
var secret *types.Secret var secret *types.Secret
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
secret, err = h.readDB.GetSecretByID(tx, secretID) secret, err = h.readDB.GetSecretByID(tx, secretID)
return err return err
@ -47,7 +47,7 @@ func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.
func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) { func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) {
var secrets []*types.Secret var secrets []*types.Secret
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {
return err return err
@ -105,7 +105,7 @@ func (h *ActionHandler) CreateSecret(ctx context.Context, secret *types.Secret)
cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)} cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -168,7 +168,7 @@ func (h *ActionHandler) UpdateSecret(ctx context.Context, req *UpdateSecretReque
// changegroup is the secret name // changegroup is the secret name
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID) parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID)
@ -238,7 +238,7 @@ func (h *ActionHandler) DeleteSecret(ctx context.Context, parentType types.Confi
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {

View File

@ -50,7 +50,7 @@ func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest)
var rs *types.RemoteSource var rs *types.RemoteSource
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -156,7 +156,7 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check user existance // check user existance
@ -206,7 +206,7 @@ func (h *ActionHandler) UpdateUser(ctx context.Context, req *UpdateUserRequest)
var user *types.User var user *types.User
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -289,7 +289,7 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -374,7 +374,7 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string)
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -444,7 +444,7 @@ func (h *ActionHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -518,7 +518,7 @@ func (h *ActionHandler) CreateUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -583,7 +583,7 @@ func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -644,7 +644,7 @@ func userOrgsResponse(userOrg *readdb.UserOrg) *UserOrgsResponse {
func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) { func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) {
var userOrgs []*readdb.UserOrg var userOrgs []*readdb.UserOrg
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err := h.readDB.GetUser(tx, userRef) user, err := h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) { func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) {
var variables []*types.Variable var variables []*types.Variable
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {
return err return err
@ -81,7 +81,7 @@ func (h *ActionHandler) CreateVariable(ctx context.Context, variable *types.Vari
cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)} cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -144,7 +144,7 @@ func (h *ActionHandler) UpdateVariable(ctx context.Context, req *UpdateVariableR
// changegroup is the variable name // changegroup is the variable name
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID) parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID)
@ -214,7 +214,7 @@ func (h *ActionHandler) DeleteVariable(ctx context.Context, parentType types.Con
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {

View File

@ -40,11 +40,12 @@ func NewOrgHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgHandler {
} }
func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
orgRef := vars["orgref"] orgRef := vars["orgref"]
var org *types.Organization var org *types.Organization
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)
return err return err
@ -105,7 +106,6 @@ func NewDeleteOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *DeleteOr
} }
func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.log.Infof("deleteorghandler")
ctx := r.Context() ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
@ -136,6 +136,7 @@ func NewOrgsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgsHandler {
} }
func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -163,7 +164,7 @@ func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := query.Get("start") start := query.Get("start")
var orgs []*types.Organization var orgs []*types.Organization
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
orgs, err = h.readDB.GetOrgs(tx, start, limit, asc) orgs, err = h.readDB.GetOrgs(tx, start, limit, asc)
return err return err

View File

@ -15,6 +15,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@ -43,18 +44,18 @@ type Project struct {
GlobalVisibility types.Visibility GlobalVisibility types.Visibility
} }
func projectResponse(readDB *readdb.ReadDB, project *types.Project) (*Project, error) { func projectResponse(ctx context.Context, readDB *readdb.ReadDB, project *types.Project) (*Project, error) {
r, err := projectsResponse(readDB, []*types.Project{project}) r, err := projectsResponse(ctx, readDB, []*types.Project{project})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r[0], nil return r[0], nil
} }
func projectsResponse(readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { func projectsResponse(ctx context.Context, readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) {
resProjects := make([]*Project, len(projects)) resProjects := make([]*Project, len(projects))
err := readDB.Do(func(tx *db.Tx) error { err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, project := range projects { for i, project := range projects {
pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID) pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID)
if err != nil { if err != nil {
@ -132,6 +133,7 @@ func NewProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectHandle
} }
func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectRef, err := url.PathUnescape(vars["projectref"]) projectRef, err := url.PathUnescape(vars["projectref"])
if err != nil { if err != nil {
@ -140,7 +142,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
var project *types.Project var project *types.Project
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
project, err = h.readDB.GetProject(tx, projectRef) project, err = h.readDB.GetProject(tx, projectRef)
return err return err
@ -156,7 +158,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -193,7 +195,7 @@ func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -241,7 +243,7 @@ func (h *UpdateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return

View File

@ -15,6 +15,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@ -43,18 +44,18 @@ type ProjectGroup struct {
GlobalVisibility types.Visibility GlobalVisibility types.Visibility
} }
func projectGroupResponse(readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { func projectGroupResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) {
r, err := projectGroupsResponse(readDB, []*types.ProjectGroup{projectGroup}) r, err := projectGroupsResponse(ctx, readDB, []*types.ProjectGroup{projectGroup})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r[0], nil return r[0], nil
} }
func projectGroupsResponse(readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { func projectGroupsResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) {
resProjectGroups := make([]*ProjectGroup, len(projectGroups)) resProjectGroups := make([]*ProjectGroup, len(projectGroups))
err := readDB.Do(func(tx *db.Tx) error { err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, projectGroup := range projectGroups { for i, projectGroup := range projectGroups {
pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID) pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID)
if err != nil { if err != nil {
@ -100,7 +101,9 @@ func NewProjectGroupHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectG
} }
func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil { if err != nil {
httpError(w, util.NewErrBadRequest(err)) httpError(w, util.NewErrBadRequest(err))
@ -108,7 +111,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
} }
var projectGroup *types.ProjectGroup var projectGroup *types.ProjectGroup
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef)
return err return err
@ -124,7 +127,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -148,6 +151,7 @@ func NewProjectGroupProjectsHandler(logger *zap.Logger, ah *action.ActionHandler
func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil { if err != nil {
httpError(w, util.NewErrBadRequest(err)) httpError(w, util.NewErrBadRequest(err))
@ -160,7 +164,7 @@ func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.R
return return
} }
resProjects, err := projectsResponse(h.readDB, projects) resProjects, err := projectsResponse(ctx, h.readDB, projects)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -196,7 +200,7 @@ func (h *ProjectGroupSubgroupsHandler) ServeHTTP(w http.ResponseWriter, r *http.
return return
} }
resProjectGroups, err := projectGroupsResponse(h.readDB, projectGroups) resProjectGroups, err := projectGroupsResponse(ctx, h.readDB, projectGroups)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -233,7 +237,7 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -281,7 +285,7 @@ func (h *UpdateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return

View File

@ -40,11 +40,12 @@ func NewRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSo
} }
func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
rsRef := vars["remotesourceref"] rsRef := vars["remotesourceref"]
var remoteSource *types.RemoteSource var remoteSource *types.RemoteSource
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef) remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef)
return err return err
@ -171,6 +172,7 @@ func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteS
} }
func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -197,7 +199,7 @@ func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
start := query.Get("start") start := query.Get("start")
remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc) remoteSources, err := h.readDB.GetRemoteSources(ctx, start, limit, asc)
if err != nil { if err != nil {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
httpError(w, err) httpError(w, err)

View File

@ -94,7 +94,7 @@ func (h *SecretsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resSecrets[i] = &Secret{Secret: s} resSecrets[i] = &Secret{Secret: s}
} }
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path // populate parent path
for _, s := range resSecrets { for _, s := range resSecrets {
pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID) pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID)

View File

@ -41,11 +41,12 @@ func NewUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UserHandler {
} }
func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
userRef := vars["userref"] userRef := vars["userref"]
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
return err return err
@ -198,6 +199,7 @@ func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler {
} }
func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -232,7 +234,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bytoken": case "bytoken":
token := query.Get("token") token := query.Get("token")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByTokenValue(tx, token) user, err = h.readDB.GetUserByTokenValue(tx, token)
return err return err
@ -250,7 +252,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bylinkedaccount": case "bylinkedaccount":
linkedAccountID := query.Get("linkedaccountid") linkedAccountID := query.Get("linkedaccountid")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID) user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID)
return err return err
@ -269,7 +271,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
remoteUserID := query.Get("remoteuserid") remoteUserID := query.Get("remoteuserid")
remoteSourceID := query.Get("remotesourceid") remoteSourceID := query.Get("remotesourceid")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID) user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID)
return err return err
@ -286,7 +288,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
users = []*types.User{user} users = []*types.User{user}
default: default:
// default query // default query
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
users, err = h.readDB.GetUsers(tx, start, limit, asc) users, err = h.readDB.GetUsers(tx, start, limit, asc)
return err return err

View File

@ -67,7 +67,7 @@ func (h *VariablesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for i, v := range variables { for i, v := range variables {
resVariables[i] = &Variable{Variable: v} resVariables[i] = &Variable{Variable: v}
} }
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path // populate parent path
for _, v := range resVariables { for _, v := range resVariables {
pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID) pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID)

View File

@ -99,9 +99,9 @@ func setupConfigstore(t *testing.T, ctx context.Context, dir string) (*Configsto
return cs, tetcd return cs, tetcd
} }
func getProjects(cs *Configstore) ([]*types.Project, error) { func getProjects(ctx context.Context, cs *Configstore) ([]*types.Project, error) {
var projects []*types.Project var projects []*types.Project
err := cs.readDB.Do(func(tx *db.Tx) error { err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projects, err = cs.readDB.GetAllProjects(tx) projects, err = cs.readDB.GetAllProjects(tx)
return err return err
@ -109,9 +109,9 @@ func getProjects(cs *Configstore) ([]*types.Project, error) {
return projects, err return projects, err
} }
func getUsers(cs *Configstore) ([]*types.User, error) { func getUsers(ctx context.Context, cs *Configstore) ([]*types.User, error) {
var users []*types.User var users []*types.User
err := cs.readDB.Do(func(tx *db.Tx) error { err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
users, err = cs.readDB.GetUsers(tx, "", 0, true) users, err = cs.readDB.GetUsers(tx, "", 0, true)
return err return err
@ -242,12 +242,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users1, err := getUsers(cs1) users1, err := getUsers(ctx, cs1)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
users2, err := getUsers(cs2) users2, err := getUsers(ctx, cs2)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -275,12 +275,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users1, err = getUsers(cs1) users1, err = getUsers(ctx, cs1)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
users3, err := getUsers(cs3) users3, err := getUsers(ctx, cs3)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -349,7 +349,7 @@ func TestUser(t *testing.T) {
} }
}) })
t.Run("concurrent user with same name creation", func(t *testing.T) { t.Run("concurrent user with same name creation", func(t *testing.T) {
prevUsers, err := getUsers(cs) prevUsers, err := getUsers(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -364,7 +364,7 @@ func TestUser(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users, err := getUsers(cs) users, err := getUsers(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -494,7 +494,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
}) })
t.Run("concurrent project with same name creation", func(t *testing.T) { t.Run("concurrent project with same name creation", func(t *testing.T) {
prevProjects, err := getProjects(cs) prevProjects, err := getProjects(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -511,7 +511,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
projects, err := getProjects(cs) projects, err := getProjects(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }

View File

@ -74,7 +74,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err return nil, err
} }
@ -112,7 +112,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil return nil
} }
func (r *ReadDB) ResetDB() error { func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex // TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close() r.rdb.Close()
@ -127,7 +127,7 @@ func (r *ReadDB) ResetDB() error {
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return err return err
} }
@ -136,7 +136,7 @@ func (r *ReadDB) ResetDB() error {
return nil return nil
} }
func (r *ReadDB) SyncFromDump() (string, error) { func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus() dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist { if err != nil && err != ostypes.ErrNotExist {
return "", err return "", err
@ -168,7 +168,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
dumpf.Close() dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries { for _, de := range dumpEntries {
action := &datamanager.Action{ action := &datamanager.Action{
ActionType: datamanager.ActionTypePut, ActionType: datamanager.ActionTypePut,
@ -188,7 +188,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil { if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil {
return err return err
} }
@ -201,9 +201,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil return dumpIndex.WalSequence, nil
} }
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error { insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles { for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence) walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil { if err != nil {
@ -263,7 +263,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) SyncRDB(ctx context.Context) error { func (r *ReadDB) SyncRDB(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb // get the last committed storage wal sequence saved in the rdb
curWalSeq := "" curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
curWalSeq, err = r.GetCommittedWalSequence(tx) curWalSeq, err = r.GetCommittedWalSequence(tx)
if err != nil { if err != nil {
@ -317,12 +317,12 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
if doFullSync { if doFullSync {
r.log.Infof("doing a full sync from dump") r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return err return err
} }
var err error var err error
curWalSeq, err = r.SyncFromDump() curWalSeq, err = r.SyncFromDump(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -335,7 +335,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
// etcd since wals are first committed to objectstorage and then in etcd we // etcd since wals are first committed to objectstorage and then in etcd we
// would like to avoid to store in rdb something that is not yet marked as // would like to avoid to store in rdb something that is not yet marked as
// committedstorage in etcd // committedstorage in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil { if err != nil {
return errors.Errorf("failed to sync from wals: %w", err) return errors.Errorf("failed to sync from wals: %w", err)
} }
@ -358,7 +358,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
r.log.Infof("syncing from wals") r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertRevision(tx, revision); err != nil { if err := r.insertRevision(tx, revision); err != nil {
return err return err
} }
@ -406,7 +406,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision() revision, err := r.GetRevision(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -475,7 +475,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
// lost/reset // lost/reset
func (r *ReadDB) handleEvents(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -508,7 +508,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if // if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset // the user hasn't messed up with etcd keys) this means etcd has been reset
@ -704,8 +704,8 @@ func (r *ReadDB) applyAction(tx *db.Tx, action *datamanager.Action) error {
return nil return nil
} }
func (r *ReadDB) Do(f func(tx *db.Tx) error) error { func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
return r.rdb.Do(f) return r.rdb.Do(ctx, f)
} }
func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
@ -724,10 +724,10 @@ func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
return nil return nil
} }
func (r *ReadDB) GetRevision() (int64, error) { func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
return err return err

View File

@ -15,6 +15,7 @@
package readdb package readdb
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -139,7 +140,7 @@ func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc
return s return s
} }
func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { func (r *ReadDB) GetRemoteSources(ctx context.Context, startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) {
var remoteSources []*types.RemoteSource var remoteSources []*types.RemoteSource
s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc) s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc)
@ -149,7 +150,7 @@ func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc b
return nil, errors.Errorf("failed to build query: %w", err) return nil, errors.Errorf("failed to build query: %w", err)
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
rows, err := tx.Query(q, args...) rows, err := tx.Query(q, args...)
if err != nil { if err != nil {
return err return err

View File

@ -142,12 +142,12 @@ func (g *Gateway) Run(ctx context.Context) error {
return h return h
} }
if len(g.c.Web.AllowedOrigins) > 0 { if len(g.c.Web.AllowedOrigins) > 0 {
corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}) corsAllowedMethodsOptions := ghandlers.AllowedMethods([]string{"GET", "HEAD", "POST", "PUT", "DELETE"})
corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"}) corsAllowedHeadersOptions := ghandlers.AllowedHeaders([]string{"Accept", "Accept-Encoding", "Authorization", "Content-Length", "Content-Type", "X-CSRF-Token", "Authorization"})
corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins) corsAllowedOriginsOptions := ghandlers.AllowedOrigins(g.c.Web.AllowedOrigins)
corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions) corsHandler = ghandlers.CORS(corsAllowedMethodsOptions, corsAllowedHeadersOptions, corsAllowedOriginsOptions)
} }
webhooksHandler := api.NewWebhooksHandler(logger, g.ah, g.configstoreClient, g.runserviceClient, g.c.APIExposedURL) webhooksHandler := api.NewWebhooksHandler(logger, g.ah, g.configstoreClient, g.runserviceClient, g.c.APIExposedURL)

View File

@ -381,7 +381,7 @@ func (h *ActionHandler) saveRun(ctx context.Context, rb *types.RunBundle, runcgt
run := rb.Run run := rb.Run
rc := rb.Rc rc := rb.Rc
c, cgt, err := h.getRunCounter(run.Group) c, cgt, err := h.getRunCounter(ctx, run.Group)
h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt)) h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt))
if err != nil { if err != nil {
return err return err
@ -570,7 +570,7 @@ func (h *ActionHandler) DeleteExecutor(ctx context.Context, executorID string) e
return nil return nil
} }
func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { func (h *ActionHandler) getRunCounter(ctx context.Context, group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) {
// use the first group dir after the root // use the first group dir after the root
pl := util.PathList(group) pl := util.PathList(group)
if len(pl) < 2 { if len(pl) < 2 {
@ -579,7 +579,7 @@ func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.Change
var c uint64 var c uint64
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
c, err = h.readDB.GetRunCounterOST(tx, pl[1]) c, err = h.readDB.GetRunCounterOST(tx, pl[1])
if err != nil { if err != nil {

View File

@ -323,12 +323,13 @@ func NewChangeGroupsUpdateTokensHandler(logger *zap.Logger, readDB *readdb.ReadD
} }
func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
groups := query["changegroup"] groups := query["changegroup"]
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups)
return err return err
@ -372,6 +373,7 @@ func NewRunHandler(logger *zap.Logger, e *etcd.Store, dm *datamanager.DataManage
} }
func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
runID := vars["runid"] runID := vars["runid"]
@ -381,7 +383,7 @@ func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var run *types.Run var run *types.Run
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
run, err = h.readDB.GetRun(tx, runID) run, err = h.readDB.GetRun(tx, runID)
if err != nil { if err != nil {
@ -447,6 +449,7 @@ func NewRunsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RunsHandler {
} }
func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
phaseFilter := types.RunPhaseFromStringSlice(query["phase"]) phaseFilter := types.RunPhaseFromStringSlice(query["phase"])
resultFilter := types.RunResultFromStringSlice(query["result"]) resultFilter := types.RunResultFromStringSlice(query["result"])
@ -482,7 +485,7 @@ func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var runs []*types.Run var runs []*types.Run
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder) runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder)
if err != nil { if err != nil {

View File

@ -119,7 +119,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err return nil, err
} }
@ -150,7 +150,7 @@ func (r *ReadDB) IsInitialized() bool {
// Initialize populates the readdb with the current etcd data and save the // Initialize populates the readdb with the current etcd data and save the
// revision to then feed it with the etcd events // revision to then feed it with the etcd events
func (r *ReadDB) Initialize(ctx context.Context) error { func (r *ReadDB) Initialize(ctx context.Context) error {
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return errors.Errorf("failed to reset db: %w", err) return errors.Errorf("failed to reset db: %w", err)
} }
if err := r.SyncObjectStorage(ctx); err != nil { if err := r.SyncObjectStorage(ctx); err != nil {
@ -162,7 +162,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil return nil
} }
func (r *ReadDB) ResetDB() error { func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex // TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close() r.rdb.Close()
@ -177,7 +177,7 @@ func (r *ReadDB) ResetDB() error {
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return err return err
} }
@ -187,7 +187,7 @@ func (r *ReadDB) ResetDB() error {
} }
func (r *ReadDB) SyncRDB(ctx context.Context) error { func (r *ReadDB) SyncRDB(ctx context.Context) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
// Do pagination to limit the number of keys per request // Do pagination to limit the number of keys per request
var revision int64 var revision int64
key := common.EtcdRunsDir key := common.EtcdRunsDir
@ -256,7 +256,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision() revision, err := r.GetRevision(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -333,7 +333,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
func (r *ReadDB) handleEvents(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64 var revision int64
var lastRuns []*RunData var lastRuns []*RunData
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
if err != nil { if err != nil {
@ -390,7 +390,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
r.dbWriteLock.Lock() r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
if err := r.handleEvent(tx, ev, &wresp); err != nil { if err := r.handleEvent(tx, ev, &wresp); err != nil {
return err return err
@ -503,7 +503,7 @@ func (r *ReadDB) handleChangeGroupEvent(tx *db.Tx, ev *etcdclientv3.Event, wresp
func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb // get the last committed storage wal sequence saved in the rdb
curWalSeq := "" curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
curWalSeq, err = r.GetCommittedWalSequenceOST(tx) curWalSeq, err = r.GetCommittedWalSequenceOST(tx)
if err != nil { if err != nil {
@ -557,12 +557,12 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
if doFullSync { if doFullSync {
r.log.Infof("doing a full sync from dump") r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return err return err
} }
var err error var err error
curWalSeq, err = r.SyncFromDump() curWalSeq, err = r.SyncFromDump(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -575,7 +575,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// since wals are first committed to objectstorage and then in etcd we would like to // since wals are first committed to objectstorage and then in etcd we would like to
// avoid to store in rdb something that is not yet marked as committedstorage // avoid to store in rdb something that is not yet marked as committedstorage
// in etcd // in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil { if err != nil {
return errors.Errorf("failed to sync from wals: %w", err) return errors.Errorf("failed to sync from wals: %w", err)
} }
@ -598,7 +598,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
} }
r.log.Infof("syncing from wals") r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := insertRevisionOST(tx, revision); err != nil { if err := insertRevisionOST(tx, revision); err != nil {
return err return err
} }
@ -645,7 +645,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
return err return err
} }
func (r *ReadDB) SyncFromDump() (string, error) { func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus() dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist { if err != nil && err != ostypes.ErrNotExist {
return "", err return "", err
@ -677,7 +677,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
dumpf.Close() dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries { for _, de := range dumpEntries {
action := &datamanager.Action{ action := &datamanager.Action{
ActionType: datamanager.ActionTypePut, ActionType: datamanager.ActionTypePut,
@ -697,7 +697,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil { if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil {
return err return err
} }
@ -710,9 +710,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil return dumpIndex.WalSequence, nil
} }
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error { insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles { for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence) walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil { if err != nil {
@ -771,7 +771,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) handleEventsOST(ctx context.Context) error { func (r *ReadDB) handleEventsOST(ctx context.Context) error {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -805,7 +805,7 @@ func (r *ReadDB) handleEventsOST(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
r.dbWriteLock.Lock() r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if // if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset // the user hasn't messed up with etcd keys) this means etcd has been reset
@ -949,11 +949,11 @@ func (r *ReadDB) handleWalEvent(tx *db.Tx, we *datamanager.WatchElement) error {
return nil return nil
} }
func (r *ReadDB) Do(f func(tx *db.Tx) error) error { func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
if !r.IsInitialized() { if !r.IsInitialized() {
return errors.Errorf("db not initialized") return errors.Errorf("db not initialized")
} }
return r.rdb.Do(f) return r.rdb.Do(ctx, f)
} }
func insertRevision(tx *db.Tx, revision int64) error { func insertRevision(tx *db.Tx, revision int64) error {
@ -1073,10 +1073,10 @@ func insertChangeGroupRevision(tx *db.Tx, changegroupID string, revision int64)
return nil return nil
} }
func (r *ReadDB) GetRevision() (int64, error) { func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
return err return err