Merge pull request #62 from sgotti/db_use_context

db: use context functions
This commit is contained in:
Simone Gotti 2019-07-25 16:22:10 +02:00 committed by GitHub
commit 9481b8d67d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 170 additions and 152 deletions

View File

@ -15,6 +15,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
@ -27,10 +28,10 @@ const dbVersionTableDDLTmpl = `
const dbVersion = 1 const dbVersion = 1
func (db *DB) Create(stmts []string) error { func (db *DB) Create(ctx context.Context, stmts []string) error {
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
err := db.Do(func(tx *Tx) error { err := db.Do(ctx, func(tx *Tx) error {
if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil { if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil {
return errors.Errorf("failed to create dbversion table: %w", err) return errors.Errorf("failed to create dbversion table: %w", err)
} }
@ -40,7 +41,7 @@ func (db *DB) Create(stmts []string) error {
return err return err
} }
err = db.Do(func(tx *Tx) error { err = db.Do(ctx, func(tx *Tx) error {
var version sql.NullInt64 var version sql.NullInt64
q, args, err := sb.Select("max(version)").From("dbversion").ToSql() q, args, err := sb.Select("max(version)").From("dbversion").ToSql()
if err != nil { if err != nil {

View File

@ -147,14 +147,15 @@ func NewDB(dbType Type, dbConnString string) (*DB, error) {
type Tx struct { type Tx struct {
db *DB db *DB
tx *sql.Tx tx *sql.Tx
ctx context.Context
} }
func (db *DB) Close() error { func (db *DB) Close() error {
return db.db.Close() return db.db.Close()
} }
func (db *DB) Conn() (*sql.Conn, error) { func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) {
return db.db.Conn(context.TODO()) return db.db.Conn(ctx)
} }
func (db *DB) NewUnstartedTx() *Tx { func (db *DB) NewUnstartedTx() *Tx {
@ -163,19 +164,19 @@ func (db *DB) NewUnstartedTx() *Tx {
} }
} }
func (db *DB) NewTx() (*Tx, error) { func (db *DB) NewTx(ctx context.Context) (*Tx, error) {
tx := db.NewUnstartedTx() tx := db.NewUnstartedTx()
if err := tx.Start(); err != nil { if err := tx.Start(ctx); err != nil {
return nil, err return nil, err
} }
return tx, nil return tx, nil
} }
func (db *DB) Do(f func(tx *Tx) error) error { func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error {
retries := 0 retries := 0
for { for {
err := db.do(f) err := db.do(ctx, f)
if err != nil { if err != nil {
var sqerr sqlite3.Error var sqerr sqlite3.Error
if errors.As(err, &sqerr) { if errors.As(err, &sqerr) {
@ -191,8 +192,8 @@ func (db *DB) Do(f func(tx *Tx) error) error {
} }
} }
func (db *DB) do(f func(tx *Tx) error) error { func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error {
tx, err := db.NewTx() tx, err := db.NewTx(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -209,7 +210,7 @@ func (db *DB) do(f func(tx *Tx) error) error {
return tx.Commit() return tx.Commit()
} }
func (tx *Tx) Start() error { func (tx *Tx) Start(ctx context.Context) error {
wtx, err := tx.db.db.Begin() wtx, err := tx.db.db.Begin()
if err != nil { if err != nil {
return err return err
@ -221,6 +222,7 @@ func (tx *Tx) Start() error {
} }
} }
tx.tx = wtx tx.tx = wtx
tx.ctx = ctx
return nil return nil
} }
@ -240,19 +242,19 @@ func (tx *Tx) Rollback() error {
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...) r, err := tx.tx.ExecContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err return r, err
} }
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...) r, err := tx.tx.QueryContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
return r, err return r, err
} }
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
query = tx.db.data.translate(query) query = tx.db.data.translate(query)
return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...) return tx.tx.QueryRowContext(tx.ctx, query, tx.db.data.translateArgs(args)...)
} }
func (tx *Tx) CurTime() (time.Time, error) { func (tx *Tx) CurTime() (time.Time, error) {

View File

@ -44,7 +44,7 @@ func orgMemberResponse(orgUser *readdb.OrgUser) *OrgMemberResponse {
func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) { func (h *ActionHandler) GetOrgMembers(ctx context.Context, orgRef string) ([]*OrgMemberResponse, error) {
var orgUsers []*readdb.OrgUser var orgUsers []*readdb.OrgUser
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
org, err := h.readDB.GetOrg(tx, orgRef) org, err := h.readDB.GetOrg(tx, orgRef)
if err != nil { if err != nil {
@ -85,7 +85,7 @@ func (h *ActionHandler) CreateOrg(ctx context.Context, org *types.Organization)
cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)} cgNames := []string{util.EncodeSha256Hex("orgname-" + org.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func (h *ActionHandler) DeleteOrg(ctx context.Context, orgRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check org existance // check org existance
org, err = h.readDB.GetOrgByName(tx, orgRef) org, err = h.readDB.GetOrgByName(tx, orgRef)
@ -232,7 +232,7 @@ func (h *ActionHandler) AddOrgMember(ctx context.Context, orgRef, userRef string
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check existing org // check existing org
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)
@ -308,7 +308,7 @@ func (h *ActionHandler) RemoveOrgMember(ctx context.Context, orgRef, userRef str
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check existing org // check existing org
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)

View File

@ -72,7 +72,7 @@ func (h *ActionHandler) CreateProject(ctx context.Context, project *types.Projec
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID) group, err := h.readDB.GetProjectGroup(tx, project.Parent.ID)
if err != nil { if err != nil {
@ -167,7 +167,7 @@ func (h *ActionHandler) UpdateProject(ctx context.Context, req *UpdateProjectReq
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project exists // check project exists
p, err := h.readDB.GetProject(tx, req.ProjectRef) p, err := h.readDB.GetProject(tx, req.ProjectRef)
@ -269,7 +269,7 @@ func (h *ActionHandler) DeleteProject(ctx context.Context, projectRef string) er
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project existance // check project existance

View File

@ -30,7 +30,7 @@ import (
func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) { func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGroupRef string) ([]*types.ProjectGroup, error) {
var projectGroups []*types.ProjectGroup var projectGroups []*types.ProjectGroup
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func (h *ActionHandler) GetProjectGroupSubgroups(ctx context.Context, projectGro
func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) { func (h *ActionHandler) GetProjectGroupProjects(ctx context.Context, projectGroupRef string) ([]*types.Project, error) {
var projects []*types.Project var projects []*types.Project
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err := h.readDB.GetProjectGroup(tx, projectGroupRef)
if err != nil { if err != nil {
@ -112,7 +112,7 @@ func (h *ActionHandler) CreateProjectGroup(ctx context.Context, projectGroup *ty
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID) parentProjectGroup, err := h.readDB.GetProjectGroup(tx, projectGroup.Parent.ID)
if err != nil { if err != nil {
return err return err
@ -184,7 +184,7 @@ func (h *ActionHandler) UpdateProjectGroup(ctx context.Context, req *UpdateProje
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project exists // check project exists
pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef) pg, err := h.readDB.GetProjectGroup(tx, req.ProjectGroupRef)
@ -280,7 +280,7 @@ func (h *ActionHandler) DeleteProjectGroup(ctx context.Context, projectGroupRef
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check project group existance // check project group existance

View File

@ -74,7 +74,7 @@ func (h *ActionHandler) CreateRemoteSource(ctx context.Context, remoteSource *ty
cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)} cgNames := []string{util.EncodeSha256Hex("remotesourcename-" + remoteSource.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -129,7 +129,7 @@ func (h *ActionHandler) UpdateRemoteSource(ctx context.Context, req *UpdateRemot
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check remotesource exists // check remotesource exists
@ -190,7 +190,7 @@ func (h *ActionHandler) DeleteRemoteSource(ctx context.Context, remoteSourceName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check remoteSource existance // check remoteSource existance

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) { func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.Secret, error) {
var secret *types.Secret var secret *types.Secret
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
secret, err = h.readDB.GetSecretByID(tx, secretID) secret, err = h.readDB.GetSecretByID(tx, secretID)
return err return err
@ -47,7 +47,7 @@ func (h *ActionHandler) GetSecret(ctx context.Context, secretID string) (*types.
func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) { func (h *ActionHandler) GetSecrets(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Secret, error) {
var secrets []*types.Secret var secrets []*types.Secret
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {
return err return err
@ -105,7 +105,7 @@ func (h *ActionHandler) CreateSecret(ctx context.Context, secret *types.Secret)
cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)} cgNames := []string{util.EncodeSha256Hex("secretname-" + secret.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -168,7 +168,7 @@ func (h *ActionHandler) UpdateSecret(ctx context.Context, req *UpdateSecretReque
// changegroup is the secret name // changegroup is the secret name
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID) parentID, err := h.readDB.ResolveConfigID(tx, req.Secret.Parent.Type, req.Secret.Parent.ID)
@ -238,7 +238,7 @@ func (h *ActionHandler) DeleteSecret(ctx context.Context, parentType types.Confi
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {

View File

@ -50,7 +50,7 @@ func (h *ActionHandler) CreateUser(ctx context.Context, req *CreateUserRequest)
var rs *types.RemoteSource var rs *types.RemoteSource
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -156,7 +156,7 @@ func (h *ActionHandler) DeleteUser(ctx context.Context, userRef string) error {
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
// check user existance // check user existance
@ -206,7 +206,7 @@ func (h *ActionHandler) UpdateUser(ctx context.Context, req *UpdateUserRequest)
var user *types.User var user *types.User
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -289,7 +289,7 @@ func (h *ActionHandler) CreateUserLA(ctx context.Context, req *CreateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -374,7 +374,7 @@ func (h *ActionHandler) DeleteUserLA(ctx context.Context, userRef, laID string)
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -444,7 +444,7 @@ func (h *ActionHandler) UpdateUserLA(ctx context.Context, req *UpdateUserLAReque
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, req.UserRef) user, err = h.readDB.GetUser(tx, req.UserRef)
if err != nil { if err != nil {
@ -518,7 +518,7 @@ func (h *ActionHandler) CreateUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -583,7 +583,7 @@ func (h *ActionHandler) DeleteUserToken(ctx context.Context, userRef, tokenName
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {
@ -644,7 +644,7 @@ func userOrgsResponse(userOrg *readdb.UserOrg) *UserOrgsResponse {
func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) { func (h *ActionHandler) GetUserOrgs(ctx context.Context, userRef string) ([]*UserOrgsResponse, error) {
var userOrgs []*readdb.UserOrg var userOrgs []*readdb.UserOrg
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err := h.readDB.GetUser(tx, userRef) user, err := h.readDB.GetUser(tx, userRef)
if err != nil { if err != nil {

View File

@ -29,7 +29,7 @@ import (
func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) { func (h *ActionHandler) GetVariables(ctx context.Context, parentType types.ConfigType, parentRef string, tree bool) ([]*types.Variable, error) {
var variables []*types.Variable var variables []*types.Variable
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {
return err return err
@ -81,7 +81,7 @@ func (h *ActionHandler) CreateVariable(ctx context.Context, variable *types.Vari
cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)} cgNames := []string{util.EncodeSha256Hex("variablename-" + variable.Name)}
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, cgNames)
if err != nil { if err != nil {
@ -144,7 +144,7 @@ func (h *ActionHandler) UpdateVariable(ctx context.Context, req *UpdateVariableR
// changegroup is the variable name // changegroup is the variable name
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID) parentID, err := h.readDB.ResolveConfigID(tx, req.Variable.Parent.Type, req.Variable.Parent.ID)
@ -214,7 +214,7 @@ func (h *ActionHandler) DeleteVariable(ctx context.Context, parentType types.Con
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
// must do all the checks in a single transaction to avoid concurrent changes // must do all the checks in a single transaction to avoid concurrent changes
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef) parentID, err := h.readDB.ResolveConfigID(tx, parentType, parentRef)
if err != nil { if err != nil {

View File

@ -40,11 +40,12 @@ func NewOrgHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgHandler {
} }
func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *OrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
orgRef := vars["orgref"] orgRef := vars["orgref"]
var org *types.Organization var org *types.Organization
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
org, err = h.readDB.GetOrg(tx, orgRef) org, err = h.readDB.GetOrg(tx, orgRef)
return err return err
@ -105,7 +106,6 @@ func NewDeleteOrgHandler(logger *zap.Logger, ah *action.ActionHandler) *DeleteOr
} }
func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *DeleteOrgHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.log.Infof("deleteorghandler")
ctx := r.Context() ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
@ -136,6 +136,7 @@ func NewOrgsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *OrgsHandler {
} }
func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -163,7 +164,7 @@ func (h *OrgsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := query.Get("start") start := query.Get("start")
var orgs []*types.Organization var orgs []*types.Organization
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
orgs, err = h.readDB.GetOrgs(tx, start, limit, asc) orgs, err = h.readDB.GetOrgs(tx, start, limit, asc)
return err return err

View File

@ -15,6 +15,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@ -43,18 +44,18 @@ type Project struct {
GlobalVisibility types.Visibility GlobalVisibility types.Visibility
} }
func projectResponse(readDB *readdb.ReadDB, project *types.Project) (*Project, error) { func projectResponse(ctx context.Context, readDB *readdb.ReadDB, project *types.Project) (*Project, error) {
r, err := projectsResponse(readDB, []*types.Project{project}) r, err := projectsResponse(ctx, readDB, []*types.Project{project})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r[0], nil return r[0], nil
} }
func projectsResponse(readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) { func projectsResponse(ctx context.Context, readDB *readdb.ReadDB, projects []*types.Project) ([]*Project, error) {
resProjects := make([]*Project, len(projects)) resProjects := make([]*Project, len(projects))
err := readDB.Do(func(tx *db.Tx) error { err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, project := range projects { for i, project := range projects {
pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID) pp, err := readDB.GetPath(tx, project.Parent.Type, project.Parent.ID)
if err != nil { if err != nil {
@ -132,6 +133,7 @@ func NewProjectHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectHandle
} }
func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectRef, err := url.PathUnescape(vars["projectref"]) projectRef, err := url.PathUnescape(vars["projectref"])
if err != nil { if err != nil {
@ -140,7 +142,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
var project *types.Project var project *types.Project
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
project, err = h.readDB.GetProject(tx, projectRef) project, err = h.readDB.GetProject(tx, projectRef)
return err return err
@ -156,7 +158,7 @@ func (h *ProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -193,7 +195,7 @@ func (h *CreateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -241,7 +243,7 @@ func (h *UpdateProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProject, err := projectResponse(h.readDB, project) resProject, err := projectResponse(ctx, h.readDB, project)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return

View File

@ -15,6 +15,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@ -43,18 +44,18 @@ type ProjectGroup struct {
GlobalVisibility types.Visibility GlobalVisibility types.Visibility
} }
func projectGroupResponse(readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) { func projectGroupResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroup *types.ProjectGroup) (*ProjectGroup, error) {
r, err := projectGroupsResponse(readDB, []*types.ProjectGroup{projectGroup}) r, err := projectGroupsResponse(ctx, readDB, []*types.ProjectGroup{projectGroup})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r[0], nil return r[0], nil
} }
func projectGroupsResponse(readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) { func projectGroupsResponse(ctx context.Context, readDB *readdb.ReadDB, projectGroups []*types.ProjectGroup) ([]*ProjectGroup, error) {
resProjectGroups := make([]*ProjectGroup, len(projectGroups)) resProjectGroups := make([]*ProjectGroup, len(projectGroups))
err := readDB.Do(func(tx *db.Tx) error { err := readDB.Do(ctx, func(tx *db.Tx) error {
for i, projectGroup := range projectGroups { for i, projectGroup := range projectGroups {
pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID) pp, err := readDB.GetPath(tx, projectGroup.Parent.Type, projectGroup.Parent.ID)
if err != nil { if err != nil {
@ -100,7 +101,9 @@ func NewProjectGroupHandler(logger *zap.Logger, readDB *readdb.ReadDB) *ProjectG
} }
func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil { if err != nil {
httpError(w, util.NewErrBadRequest(err)) httpError(w, util.NewErrBadRequest(err))
@ -108,7 +111,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
} }
var projectGroup *types.ProjectGroup var projectGroup *types.ProjectGroup
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef) projectGroup, err = h.readDB.GetProjectGroup(tx, projectGroupRef)
return err return err
@ -124,7 +127,7 @@ func (h *ProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -148,6 +151,7 @@ func NewProjectGroupProjectsHandler(logger *zap.Logger, ah *action.ActionHandler
func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
projectGroupRef, err := url.PathUnescape(vars["projectgroupref"]) projectGroupRef, err := url.PathUnescape(vars["projectgroupref"])
if err != nil { if err != nil {
httpError(w, util.NewErrBadRequest(err)) httpError(w, util.NewErrBadRequest(err))
@ -160,7 +164,7 @@ func (h *ProjectGroupProjectsHandler) ServeHTTP(w http.ResponseWriter, r *http.R
return return
} }
resProjects, err := projectsResponse(h.readDB, projects) resProjects, err := projectsResponse(ctx, h.readDB, projects)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -196,7 +200,7 @@ func (h *ProjectGroupSubgroupsHandler) ServeHTTP(w http.ResponseWriter, r *http.
return return
} }
resProjectGroups, err := projectGroupsResponse(h.readDB, projectGroups) resProjectGroups, err := projectGroupsResponse(ctx, h.readDB, projectGroups)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -233,7 +237,7 @@ func (h *CreateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return
@ -281,7 +285,7 @@ func (h *UpdateProjectGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Req
return return
} }
resProjectGroup, err := projectGroupResponse(h.readDB, projectGroup) resProjectGroup, err := projectGroupResponse(ctx, h.readDB, projectGroup)
if httpError(w, err) { if httpError(w, err) {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
return return

View File

@ -40,11 +40,12 @@ func NewRemoteSourceHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteSo
} }
func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RemoteSourceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
rsRef := vars["remotesourceref"] rsRef := vars["remotesourceref"]
var remoteSource *types.RemoteSource var remoteSource *types.RemoteSource
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef) remoteSource, err = h.readDB.GetRemoteSource(tx, rsRef)
return err return err
@ -171,6 +172,7 @@ func NewRemoteSourcesHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RemoteS
} }
func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -197,7 +199,7 @@ func (h *RemoteSourcesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
start := query.Get("start") start := query.Get("start")
remoteSources, err := h.readDB.GetRemoteSources(start, limit, asc) remoteSources, err := h.readDB.GetRemoteSources(ctx, start, limit, asc)
if err != nil { if err != nil {
h.log.Errorf("err: %+v", err) h.log.Errorf("err: %+v", err)
httpError(w, err) httpError(w, err)

View File

@ -94,7 +94,7 @@ func (h *SecretsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resSecrets[i] = &Secret{Secret: s} resSecrets[i] = &Secret{Secret: s}
} }
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path // populate parent path
for _, s := range resSecrets { for _, s := range resSecrets {
pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID) pp, err := h.readDB.GetPath(tx, s.Parent.Type, s.Parent.ID)

View File

@ -41,11 +41,12 @@ func NewUserHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UserHandler {
} }
func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
userRef := vars["userref"] userRef := vars["userref"]
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUser(tx, userRef) user, err = h.readDB.GetUser(tx, userRef)
return err return err
@ -198,6 +199,7 @@ func NewUsersHandler(logger *zap.Logger, readDB *readdb.ReadDB) *UsersHandler {
} }
func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
limitS := query.Get("limit") limitS := query.Get("limit")
@ -232,7 +234,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bytoken": case "bytoken":
token := query.Get("token") token := query.Get("token")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByTokenValue(tx, token) user, err = h.readDB.GetUserByTokenValue(tx, token)
return err return err
@ -250,7 +252,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "bylinkedaccount": case "bylinkedaccount":
linkedAccountID := query.Get("linkedaccountid") linkedAccountID := query.Get("linkedaccountid")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID) user, err = h.readDB.GetUserByLinkedAccount(tx, linkedAccountID)
return err return err
@ -269,7 +271,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
remoteUserID := query.Get("remoteuserid") remoteUserID := query.Get("remoteuserid")
remoteSourceID := query.Get("remotesourceid") remoteSourceID := query.Get("remotesourceid")
var user *types.User var user *types.User
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID) user, err = h.readDB.GetUserByLinkedAccountRemoteUserIDandSource(tx, remoteUserID, remoteSourceID)
return err return err
@ -286,7 +288,7 @@ func (h *UsersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
users = []*types.User{user} users = []*types.User{user}
default: default:
// default query // default query
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
users, err = h.readDB.GetUsers(tx, start, limit, asc) users, err = h.readDB.GetUsers(tx, start, limit, asc)
return err return err

View File

@ -67,7 +67,7 @@ func (h *VariablesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for i, v := range variables { for i, v := range variables {
resVariables[i] = &Variable{Variable: v} resVariables[i] = &Variable{Variable: v}
} }
err = h.readDB.Do(func(tx *db.Tx) error { err = h.readDB.Do(ctx, func(tx *db.Tx) error {
// populate parent path // populate parent path
for _, v := range resVariables { for _, v := range resVariables {
pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID) pp, err := h.readDB.GetPath(tx, v.Parent.Type, v.Parent.ID)

View File

@ -99,9 +99,9 @@ func setupConfigstore(t *testing.T, ctx context.Context, dir string) (*Configsto
return cs, tetcd return cs, tetcd
} }
func getProjects(cs *Configstore) ([]*types.Project, error) { func getProjects(ctx context.Context, cs *Configstore) ([]*types.Project, error) {
var projects []*types.Project var projects []*types.Project
err := cs.readDB.Do(func(tx *db.Tx) error { err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
projects, err = cs.readDB.GetAllProjects(tx) projects, err = cs.readDB.GetAllProjects(tx)
return err return err
@ -109,9 +109,9 @@ func getProjects(cs *Configstore) ([]*types.Project, error) {
return projects, err return projects, err
} }
func getUsers(cs *Configstore) ([]*types.User, error) { func getUsers(ctx context.Context, cs *Configstore) ([]*types.User, error) {
var users []*types.User var users []*types.User
err := cs.readDB.Do(func(tx *db.Tx) error { err := cs.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
users, err = cs.readDB.GetUsers(tx, "", 0, true) users, err = cs.readDB.GetUsers(tx, "", 0, true)
return err return err
@ -242,12 +242,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users1, err := getUsers(cs1) users1, err := getUsers(ctx, cs1)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
users2, err := getUsers(cs2) users2, err := getUsers(ctx, cs2)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -275,12 +275,12 @@ func TestResync(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users1, err = getUsers(cs1) users1, err = getUsers(ctx, cs1)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
users3, err := getUsers(cs3) users3, err := getUsers(ctx, cs3)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -349,7 +349,7 @@ func TestUser(t *testing.T) {
} }
}) })
t.Run("concurrent user with same name creation", func(t *testing.T) { t.Run("concurrent user with same name creation", func(t *testing.T) {
prevUsers, err := getUsers(cs) prevUsers, err := getUsers(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -364,7 +364,7 @@ func TestUser(t *testing.T) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
users, err := getUsers(cs) users, err := getUsers(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -494,7 +494,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
}) })
t.Run("concurrent project with same name creation", func(t *testing.T) { t.Run("concurrent project with same name creation", func(t *testing.T) {
prevProjects, err := getProjects(cs) prevProjects, err := getProjects(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -511,7 +511,7 @@ func TestProjectGroupsAndProjects(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
projects, err := getProjects(cs) projects, err := getProjects(ctx, cs)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }

View File

@ -74,7 +74,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err return nil, err
} }
@ -112,7 +112,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil return nil
} }
func (r *ReadDB) ResetDB() error { func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex // TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close() r.rdb.Close()
@ -127,7 +127,7 @@ func (r *ReadDB) ResetDB() error {
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return err return err
} }
@ -136,7 +136,7 @@ func (r *ReadDB) ResetDB() error {
return nil return nil
} }
func (r *ReadDB) SyncFromDump() (string, error) { func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus() dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist { if err != nil && err != ostypes.ErrNotExist {
return "", err return "", err
@ -168,7 +168,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
dumpf.Close() dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries { for _, de := range dumpEntries {
action := &datamanager.Action{ action := &datamanager.Action{
ActionType: datamanager.ActionTypePut, ActionType: datamanager.ActionTypePut,
@ -188,7 +188,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil { if err := r.insertCommittedWalSequence(tx, dumpIndex.WalSequence); err != nil {
return err return err
} }
@ -201,9 +201,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil return dumpIndex.WalSequence, nil
} }
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error { insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles { for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence) walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil { if err != nil {
@ -263,7 +263,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) SyncRDB(ctx context.Context) error { func (r *ReadDB) SyncRDB(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb // get the last committed storage wal sequence saved in the rdb
curWalSeq := "" curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
curWalSeq, err = r.GetCommittedWalSequence(tx) curWalSeq, err = r.GetCommittedWalSequence(tx)
if err != nil { if err != nil {
@ -317,12 +317,12 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
if doFullSync { if doFullSync {
r.log.Infof("doing a full sync from dump") r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return err return err
} }
var err error var err error
curWalSeq, err = r.SyncFromDump() curWalSeq, err = r.SyncFromDump(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -335,7 +335,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
// etcd since wals are first committed to objectstorage and then in etcd we // etcd since wals are first committed to objectstorage and then in etcd we
// would like to avoid to store in rdb something that is not yet marked as // would like to avoid to store in rdb something that is not yet marked as
// committedstorage in etcd // committedstorage in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil { if err != nil {
return errors.Errorf("failed to sync from wals: %w", err) return errors.Errorf("failed to sync from wals: %w", err)
} }
@ -358,7 +358,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
r.log.Infof("syncing from wals") r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertRevision(tx, revision); err != nil { if err := r.insertRevision(tx, revision); err != nil {
return err return err
} }
@ -406,7 +406,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision() revision, err := r.GetRevision(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -475,7 +475,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
// lost/reset // lost/reset
func (r *ReadDB) handleEvents(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -508,7 +508,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if // if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset // the user hasn't messed up with etcd keys) this means etcd has been reset
@ -704,8 +704,8 @@ func (r *ReadDB) applyAction(tx *db.Tx, action *datamanager.Action) error {
return nil return nil
} }
func (r *ReadDB) Do(f func(tx *db.Tx) error) error { func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
return r.rdb.Do(f) return r.rdb.Do(ctx, f)
} }
func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error { func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
@ -724,10 +724,10 @@ func (r *ReadDB) insertRevision(tx *db.Tx, revision int64) error {
return nil return nil
} }
func (r *ReadDB) GetRevision() (int64, error) { func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
return err return err

View File

@ -15,6 +15,7 @@
package readdb package readdb
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -139,7 +140,7 @@ func getRemoteSourcesFilteredQuery(startRemoteSourceName string, limit int, asc
return s return s
} }
func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) { func (r *ReadDB) GetRemoteSources(ctx context.Context, startRemoteSourceName string, limit int, asc bool) ([]*types.RemoteSource, error) {
var remoteSources []*types.RemoteSource var remoteSources []*types.RemoteSource
s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc) s := getRemoteSourcesFilteredQuery(startRemoteSourceName, limit, asc)
@ -149,7 +150,7 @@ func (r *ReadDB) GetRemoteSources(startRemoteSourceName string, limit int, asc b
return nil, errors.Errorf("failed to build query: %w", err) return nil, errors.Errorf("failed to build query: %w", err)
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
rows, err := tx.Query(q, args...) rows, err := tx.Query(q, args...)
if err != nil { if err != nil {
return err return err

View File

@ -381,7 +381,7 @@ func (h *ActionHandler) saveRun(ctx context.Context, rb *types.RunBundle, runcgt
run := rb.Run run := rb.Run
rc := rb.Rc rc := rb.Rc
c, cgt, err := h.getRunCounter(run.Group) c, cgt, err := h.getRunCounter(ctx, run.Group)
h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt)) h.log.Debugf("c: %d, cgt: %s", c, util.Dump(cgt))
if err != nil { if err != nil {
return err return err
@ -570,7 +570,7 @@ func (h *ActionHandler) DeleteExecutor(ctx context.Context, executorID string) e
return nil return nil
} }
func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) { func (h *ActionHandler) getRunCounter(ctx context.Context, group string) (uint64, *datamanager.ChangeGroupsUpdateToken, error) {
// use the first group dir after the root // use the first group dir after the root
pl := util.PathList(group) pl := util.PathList(group)
if len(pl) < 2 { if len(pl) < 2 {
@ -579,7 +579,7 @@ func (h *ActionHandler) getRunCounter(group string) (uint64, *datamanager.Change
var c uint64 var c uint64
var cgt *datamanager.ChangeGroupsUpdateToken var cgt *datamanager.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
c, err = h.readDB.GetRunCounterOST(tx, pl[1]) c, err = h.readDB.GetRunCounterOST(tx, pl[1])
if err != nil { if err != nil {

View File

@ -323,12 +323,13 @@ func NewChangeGroupsUpdateTokensHandler(logger *zap.Logger, readDB *readdb.ReadD
} }
func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ChangeGroupsUpdateTokensHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
groups := query["changegroup"] groups := query["changegroup"]
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups) cgt, err = h.readDB.GetChangeGroupsUpdateTokens(tx, groups)
return err return err
@ -372,6 +373,7 @@ func NewRunHandler(logger *zap.Logger, e *etcd.Store, dm *datamanager.DataManage
} }
func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
vars := mux.Vars(r) vars := mux.Vars(r)
runID := vars["runid"] runID := vars["runid"]
@ -381,7 +383,7 @@ func (h *RunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var run *types.Run var run *types.Run
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
run, err = h.readDB.GetRun(tx, runID) run, err = h.readDB.GetRun(tx, runID)
if err != nil { if err != nil {
@ -447,6 +449,7 @@ func NewRunsHandler(logger *zap.Logger, readDB *readdb.ReadDB) *RunsHandler {
} }
func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query() query := r.URL.Query()
phaseFilter := types.RunPhaseFromStringSlice(query["phase"]) phaseFilter := types.RunPhaseFromStringSlice(query["phase"])
resultFilter := types.RunResultFromStringSlice(query["result"]) resultFilter := types.RunResultFromStringSlice(query["result"])
@ -482,7 +485,7 @@ func (h *RunsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var runs []*types.Run var runs []*types.Run
var cgt *types.ChangeGroupsUpdateToken var cgt *types.ChangeGroupsUpdateToken
err := h.readDB.Do(func(tx *db.Tx) error { err := h.readDB.Do(ctx, func(tx *db.Tx) error {
var err error var err error
runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder) runs, err = h.readDB.GetRuns(tx, groups, lastRun, phaseFilter, resultFilter, start, limit, sortOrder)
if err != nil { if err != nil {

View File

@ -119,7 +119,7 @@ func NewReadDB(ctx context.Context, logger *zap.Logger, dataDir string, e *etcd.
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return nil, err return nil, err
} }
@ -150,7 +150,7 @@ func (r *ReadDB) IsInitialized() bool {
// Initialize populates the readdb with the current etcd data and save the // Initialize populates the readdb with the current etcd data and save the
// revision to then feed it with the etcd events // revision to then feed it with the etcd events
func (r *ReadDB) Initialize(ctx context.Context) error { func (r *ReadDB) Initialize(ctx context.Context) error {
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return errors.Errorf("failed to reset db: %w", err) return errors.Errorf("failed to reset db: %w", err)
} }
if err := r.SyncObjectStorage(ctx); err != nil { if err := r.SyncObjectStorage(ctx); err != nil {
@ -162,7 +162,7 @@ func (r *ReadDB) Initialize(ctx context.Context) error {
return nil return nil
} }
func (r *ReadDB) ResetDB() error { func (r *ReadDB) ResetDB(ctx context.Context) error {
// TODO(sgotti) this needs to be protected by a mutex // TODO(sgotti) this needs to be protected by a mutex
r.rdb.Close() r.rdb.Close()
@ -177,7 +177,7 @@ func (r *ReadDB) ResetDB() error {
} }
// populate readdb // populate readdb
if err := rdb.Create(Stmts); err != nil { if err := rdb.Create(ctx, Stmts); err != nil {
return err return err
} }
@ -187,7 +187,7 @@ func (r *ReadDB) ResetDB() error {
} }
func (r *ReadDB) SyncRDB(ctx context.Context) error { func (r *ReadDB) SyncRDB(ctx context.Context) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
// Do pagination to limit the number of keys per request // Do pagination to limit the number of keys per request
var revision int64 var revision int64
key := common.EtcdRunsDir key := common.EtcdRunsDir
@ -256,7 +256,7 @@ func (r *ReadDB) SyncRDB(ctx context.Context) error {
} }
func (r *ReadDB) Run(ctx context.Context) error { func (r *ReadDB) Run(ctx context.Context) error {
revision, err := r.GetRevision() revision, err := r.GetRevision(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -333,7 +333,7 @@ func (r *ReadDB) Run(ctx context.Context) error {
func (r *ReadDB) handleEvents(ctx context.Context) error { func (r *ReadDB) handleEvents(ctx context.Context) error {
var revision int64 var revision int64
var lastRuns []*RunData var lastRuns []*RunData
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
if err != nil { if err != nil {
@ -390,7 +390,7 @@ func (r *ReadDB) handleEvents(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
r.dbWriteLock.Lock() r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
if err := r.handleEvent(tx, ev, &wresp); err != nil { if err := r.handleEvent(tx, ev, &wresp); err != nil {
return err return err
@ -503,7 +503,7 @@ func (r *ReadDB) handleChangeGroupEvent(tx *db.Tx, ev *etcdclientv3.Event, wresp
func (r *ReadDB) SyncObjectStorage(ctx context.Context) error { func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// get the last committed storage wal sequence saved in the rdb // get the last committed storage wal sequence saved in the rdb
curWalSeq := "" curWalSeq := ""
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
curWalSeq, err = r.GetCommittedWalSequenceOST(tx) curWalSeq, err = r.GetCommittedWalSequenceOST(tx)
if err != nil { if err != nil {
@ -557,12 +557,12 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
if doFullSync { if doFullSync {
r.log.Infof("doing a full sync from dump") r.log.Infof("doing a full sync from dump")
if err := r.ResetDB(); err != nil { if err := r.ResetDB(ctx); err != nil {
return err return err
} }
var err error var err error
curWalSeq, err = r.SyncFromDump() curWalSeq, err = r.SyncFromDump(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -575,7 +575,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
// since wals are first committed to objectstorage and then in etcd we would like to // since wals are first committed to objectstorage and then in etcd we would like to
// avoid to store in rdb something that is not yet marked as committedstorage // avoid to store in rdb something that is not yet marked as committedstorage
// in etcd // in etcd
curWalSeq, err = r.SyncFromWals(curWalSeq, lastCommittedStorageWal) curWalSeq, err = r.SyncFromWals(ctx, curWalSeq, lastCommittedStorageWal)
if err != nil { if err != nil {
return errors.Errorf("failed to sync from wals: %w", err) return errors.Errorf("failed to sync from wals: %w", err)
} }
@ -598,7 +598,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
} }
r.log.Infof("syncing from wals") r.log.Infof("syncing from wals")
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := insertRevisionOST(tx, revision); err != nil { if err := insertRevisionOST(tx, revision); err != nil {
return err return err
} }
@ -645,7 +645,7 @@ func (r *ReadDB) SyncObjectStorage(ctx context.Context) error {
return err return err
} }
func (r *ReadDB) SyncFromDump() (string, error) { func (r *ReadDB) SyncFromDump(ctx context.Context) (string, error) {
dumpIndex, err := r.dm.GetLastDataStatus() dumpIndex, err := r.dm.GetLastDataStatus()
if err != nil && err != ostypes.ErrNotExist { if err != nil && err != ostypes.ErrNotExist {
return "", err return "", err
@ -677,7 +677,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
dumpf.Close() dumpf.Close()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, de := range dumpEntries { for _, de := range dumpEntries {
action := &datamanager.Action{ action := &datamanager.Action{
ActionType: datamanager.ActionTypePut, ActionType: datamanager.ActionTypePut,
@ -697,7 +697,7 @@ func (r *ReadDB) SyncFromDump() (string, error) {
} }
} }
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil { if err := r.insertCommittedWalSequenceOST(tx, dumpIndex.WalSequence); err != nil {
return err return err
} }
@ -710,9 +710,9 @@ func (r *ReadDB) SyncFromDump() (string, error) {
return dumpIndex.WalSequence, nil return dumpIndex.WalSequence, nil
} }
func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) { func (r *ReadDB) SyncFromWals(ctx context.Context, startWalSeq, endWalSeq string) (string, error) {
insertfunc := func(walFiles []*datamanager.WalFile) error { insertfunc := func(walFiles []*datamanager.WalFile) error {
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
for _, walFile := range walFiles { for _, walFile := range walFiles {
walFilef, err := r.dm.ReadWal(walFile.WalSequence) walFilef, err := r.dm.ReadWal(walFile.WalSequence)
if err != nil { if err != nil {
@ -771,7 +771,7 @@ func (r *ReadDB) SyncFromWals(startWalSeq, endWalSeq string) (string, error) {
func (r *ReadDB) handleEventsOST(ctx context.Context) error { func (r *ReadDB) handleEventsOST(ctx context.Context) error {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision) err := tx.QueryRow("select revision from revision order by revision desc limit 1").Scan(&revision)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -805,7 +805,7 @@ func (r *ReadDB) handleEventsOST(ctx context.Context) error {
// a single transaction for every response (every response contains all the // a single transaction for every response (every response contains all the
// events happened in an etcd revision). // events happened in an etcd revision).
r.dbWriteLock.Lock() r.dbWriteLock.Lock()
err = r.rdb.Do(func(tx *db.Tx) error { err = r.rdb.Do(ctx, func(tx *db.Tx) error {
// if theres a wal seq epoch change something happened to etcd, usually (if // if theres a wal seq epoch change something happened to etcd, usually (if
// the user hasn't messed up with etcd keys) this means etcd has been reset // the user hasn't messed up with etcd keys) this means etcd has been reset
@ -949,11 +949,11 @@ func (r *ReadDB) handleWalEvent(tx *db.Tx, we *datamanager.WatchElement) error {
return nil return nil
} }
func (r *ReadDB) Do(f func(tx *db.Tx) error) error { func (r *ReadDB) Do(ctx context.Context, f func(tx *db.Tx) error) error {
if !r.IsInitialized() { if !r.IsInitialized() {
return errors.Errorf("db not initialized") return errors.Errorf("db not initialized")
} }
return r.rdb.Do(f) return r.rdb.Do(ctx, f)
} }
func insertRevision(tx *db.Tx, revision int64) error { func insertRevision(tx *db.Tx, revision int64) error {
@ -1073,10 +1073,10 @@ func insertChangeGroupRevision(tx *db.Tx, changegroupID string, revision int64)
return nil return nil
} }
func (r *ReadDB) GetRevision() (int64, error) { func (r *ReadDB) GetRevision(ctx context.Context) (int64, error) {
var revision int64 var revision int64
err := r.rdb.Do(func(tx *db.Tx) error { err := r.rdb.Do(ctx, func(tx *db.Tx) error {
var err error var err error
revision, err = r.getRevision(tx) revision, err = r.getRevision(tx)
return err return err