3297244d5d
Since we are using the shared cache with the lock notify we won't receive SQLITE_BUSY errors but we could receive SQLITE_LOCKED errors due to deadlocks or locked tables on concurrent read and write transactions. This patch catches this kind of errors and retries the tx until maxTxRetries.
275 lines
6.0 KiB
Go
275 lines
6.0 KiB
Go
// Copyright 2019 Sorint.lab
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"regexp"
|
|
"time"
|
|
|
|
"github.com/mattn/go-sqlite3"
|
|
errors "golang.org/x/xerrors"
|
|
)
|
|
|
|
type Type string
|
|
|
|
const (
|
|
Sqlite3 Type = "sqlite3"
|
|
Postgres Type = "postgres"
|
|
|
|
maxTxRetries = 20
|
|
)
|
|
|
|
type dbData struct {
|
|
t Type
|
|
queryReplacers []replacer
|
|
supportsTimezones bool
|
|
}
|
|
|
|
type replacer struct {
|
|
re *regexp.Regexp
|
|
with string
|
|
}
|
|
|
|
// match a postgres query bind variable. E.g. "$1", "$12", etc.
|
|
var bindRegexp = regexp.MustCompile(`\$\d+`)
|
|
|
|
func matchLiteral(s string) *regexp.Regexp {
|
|
return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`)
|
|
}
|
|
|
|
var (
|
|
dbDataPostgres = dbData{
|
|
t: Postgres,
|
|
supportsTimezones: true,
|
|
queryReplacers: []replacer{
|
|
// Remove sqlite3 only statements
|
|
{regexp.MustCompile(`--SQLITE3\n.*`), ""},
|
|
},
|
|
}
|
|
|
|
dbDataSQLite3 = dbData{
|
|
t: Sqlite3,
|
|
supportsTimezones: false,
|
|
queryReplacers: []replacer{
|
|
{bindRegexp, "?"},
|
|
{matchLiteral("true"), "1"},
|
|
{matchLiteral("false"), "0"},
|
|
{matchLiteral("boolean"), "integer"},
|
|
{matchLiteral("bytea"), "blob"},
|
|
// timestamp is a declared type suported by the go-sqlite3 driver
|
|
{matchLiteral("timestamptz"), "timestamp"},
|
|
// convert now to the max precision time available with sqlite3
|
|
{regexp.MustCompile(`\bnow\(\)`), "strftime('%Y-%m-%d %H:%M:%f', 'now')"},
|
|
{regexp.MustCompile(`select pg_advisory_xact_lock\(.*`), "select 1"},
|
|
{regexp.MustCompile(`notify\s+.*`), "select 1"},
|
|
// Remove postgres only statements
|
|
{regexp.MustCompile(`--POSTGRES\n.*`), ""},
|
|
},
|
|
}
|
|
)
|
|
|
|
func (t dbData) translate(query string) string {
|
|
for _, r := range t.queryReplacers {
|
|
query = r.re.ReplaceAllString(query, r.with)
|
|
}
|
|
return query
|
|
}
|
|
|
|
// translateArgs translates query parameters that may be unique to
|
|
// a specific SQL flavor. For example, standardizing "time.Time"
|
|
// types to UTC for clients that don't provide timezone support.
|
|
func (t dbData) translateArgs(args []interface{}) []interface{} {
|
|
if t.supportsTimezones {
|
|
return args
|
|
}
|
|
|
|
for i, arg := range args {
|
|
if t, ok := arg.(time.Time); ok {
|
|
args[i] = t.UTC()
|
|
}
|
|
}
|
|
return args
|
|
}
|
|
|
|
// DB wraps a sql.DB to add special behaviors based on the db type
|
|
type DB struct {
|
|
db *sql.DB
|
|
data dbData
|
|
}
|
|
|
|
func NewDB(dbType Type, dbConnString string) (*DB, error) {
|
|
var data dbData
|
|
var driverName string
|
|
switch dbType {
|
|
case Postgres:
|
|
data = dbDataPostgres
|
|
driverName = "postgres"
|
|
case Sqlite3:
|
|
data = dbDataSQLite3
|
|
driverName = "sqlite3"
|
|
dbConnString = "file:" + dbConnString + "?cache=shared&_journal=wal&_foreign_keys=true&_case_sensitive_like=false"
|
|
default:
|
|
return nil, errors.New("unknown db type")
|
|
}
|
|
|
|
sqldb, err := sql.Open(driverName, dbConnString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db := &DB{
|
|
db: sqldb,
|
|
data: data,
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// Tx wraps a sql.Tx to offer:
|
|
// * apply some statement mutations before executing it
|
|
// * locking around concurrent executions of statements (since the underlying
|
|
// sql driver doesn't support concurrent statements on the same
|
|
// connection/transaction)
|
|
type Tx struct {
|
|
db *DB
|
|
tx *sql.Tx
|
|
}
|
|
|
|
func (db *DB) Close() error {
|
|
return db.db.Close()
|
|
}
|
|
|
|
func (db *DB) Conn() (*sql.Conn, error) {
|
|
return db.db.Conn(context.TODO())
|
|
}
|
|
|
|
func (db *DB) NewUnstartedTx() *Tx {
|
|
return &Tx{
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
func (db *DB) NewTx() (*Tx, error) {
|
|
tx := db.NewUnstartedTx()
|
|
if err := tx.Start(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tx, nil
|
|
}
|
|
|
|
func (db *DB) Do(f func(tx *Tx) error) error {
|
|
retries := 0
|
|
for {
|
|
err := db.do(f)
|
|
if err != nil {
|
|
var sqerr sqlite3.Error
|
|
if errors.As(err, &sqerr) {
|
|
if sqerr.Code == sqlite3.ErrLocked {
|
|
retries++
|
|
if retries <= maxTxRetries {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (db *DB) do(f func(tx *Tx) error) error {
|
|
tx, err := db.NewTx()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback()
|
|
panic(p)
|
|
}
|
|
}()
|
|
if err = f(tx); err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (tx *Tx) Start() error {
|
|
wtx, err := tx.db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
switch tx.db.data.t {
|
|
case Postgres:
|
|
if _, err := wtx.Exec("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
tx.tx = wtx
|
|
return nil
|
|
}
|
|
|
|
func (tx *Tx) Commit() error {
|
|
if tx.tx == nil {
|
|
return nil
|
|
}
|
|
return tx.tx.Commit()
|
|
}
|
|
|
|
func (tx *Tx) Rollback() error {
|
|
if tx.tx == nil {
|
|
return nil
|
|
}
|
|
return tx.tx.Rollback()
|
|
}
|
|
|
|
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
query = tx.db.data.translate(query)
|
|
r, err := tx.tx.Exec(query, tx.db.data.translateArgs(args)...)
|
|
return r, err
|
|
}
|
|
|
|
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
query = tx.db.data.translate(query)
|
|
r, err := tx.tx.Query(query, tx.db.data.translateArgs(args)...)
|
|
return r, err
|
|
}
|
|
|
|
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
query = tx.db.data.translate(query)
|
|
return tx.tx.QueryRow(query, tx.db.data.translateArgs(args)...)
|
|
}
|
|
|
|
func (tx *Tx) CurTime() (time.Time, error) {
|
|
switch tx.db.data.t {
|
|
case Sqlite3:
|
|
var timestring string
|
|
if err := tx.QueryRow("select now()").Scan(×tring); err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
return time.ParseInLocation("2006-01-02 15:04:05.999999999", timestring, time.UTC)
|
|
case Postgres:
|
|
var now time.Time
|
|
if err := tx.QueryRow("select now()").Scan(&now); err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
return now, nil
|
|
}
|
|
return time.Time{}, errors.New("unknown db type")
|
|
}
|