agola/internal/db/db.go

279 lines
6.5 KiB
Go

// Copyright 2019 Sorint.lab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package db
import (
"context"
"database/sql"
"regexp"
"time"
"agola.io/agola/internal/errors"
"github.com/mattn/go-sqlite3"
)
// Type identifies the SQL database flavor a DB talks to. The supported
// values are the Sqlite3 and Postgres constants.
type Type string
const (
// Sqlite3 selects the sqlite3 flavor.
Sqlite3 Type = "sqlite3"
// Postgres selects the postgres flavor.
Postgres Type = "postgres"
// maxTxRetries is the maximum number of times DB.Do re-runs a transaction
// that failed with a sqlite3 "locked" error before giving up.
maxTxRetries = 20
)
// dbData describes the flavor specific behaviors applied to queries and
// query arguments before they are sent to the database.
type dbData struct {
// t is the database flavor this data belongs to.
t Type
// queryReplacers are applied, in order, to every query by translate.
queryReplacers []replacer
// supportsTimezones reports whether time.Time arguments can be passed as
// is. When false, translateArgs converts them to UTC.
supportsTimezones bool
}
// replacer is a single query rewrite rule: every match of re is replaced
// with the with string (using regexp.Regexp.ReplaceAllString).
type replacer struct {
// re is the pattern to look for in the query.
re *regexp.Regexp
// with is the replacement text.
with string
}
// bindRegexp matches a postgres query bind variable. E.g. "$1", "$12", etc.
// The sqlite3 replacers use it to rewrite such variables to "?".
var bindRegexp = regexp.MustCompile(`\$\d+`)
// matchLiteral returns a regexp matching s verbatim (regexp metacharacters
// in s are quoted) only where it is delimited by word boundaries on both
// sides. It panics if the resulting pattern doesn't compile.
func matchLiteral(s string) *regexp.Regexp {
	pattern := `\b` + regexp.QuoteMeta(s) + `\b`
	return regexp.MustCompile(pattern)
}
// dbDataPostgres holds the postgres flavor settings: time values are passed
// unchanged and the only query rewrite drops sqlite3 only statements.
var dbDataPostgres = dbData{
	t:                 Postgres,
	supportsTimezones: true,
	queryReplacers: []replacer{
		// Remove sqlite3 only statements: the "--SQLITE3" marker line and
		// the single line that follows it.
		{re: regexp.MustCompile(`--SQLITE3\n.*`), with: ""},
	},
}

// dbDataSQLite3 holds the sqlite3 flavor settings: time values are converted
// to UTC and postgres style constructs found in the query are rewritten to
// their sqlite3 counterparts. The replacers are applied in the listed order.
var dbDataSQLite3 = dbData{
	t:                 Sqlite3,
	supportsTimezones: false,
	queryReplacers: []replacer{
		// "$N" bind variables become "?"
		{re: bindRegexp, with: "?"},
		// boolean literals and column types
		{re: matchLiteral("true"), with: "1"},
		{re: matchLiteral("false"), with: "0"},
		{re: matchLiteral("boolean"), with: "integer"},
		{re: matchLiteral("bytea"), with: "blob"},
		// timestamp is a declared type supported by the go-sqlite3 driver
		{re: matchLiteral("timestamptz"), with: "timestamp"},
		// convert now to the max precision time available with sqlite3
		{re: regexp.MustCompile(`\bnow\(\)`), with: "strftime('%Y-%m-%d %H:%M:%f', 'now')"},
		// postgres only statements are replaced by a no-op select
		{re: regexp.MustCompile(`select pg_advisory_xact_lock\(.*`), with: "select 1"},
		{re: regexp.MustCompile(`notify\s+.*`), with: "select 1"},
		// Remove postgres only statements: the "--POSTGRES" marker line and
		// the single line that follows it.
		{re: regexp.MustCompile(`--POSTGRES\n.*`), with: ""},
	},
}
// translate rewrites query for the db flavor described by t, applying every
// configured replacer in order. Each replacer works on the output of the
// previous one.
func (t dbData) translate(query string) string {
	translated := query
	for i := range t.queryReplacers {
		rule := &t.queryReplacers[i]
		translated = rule.re.ReplaceAllString(translated, rule.with)
	}
	return translated
}
// translateArgs translates query parameters that may be unique to
// a specific SQL flavor. For example, standardizing "time.Time"
// types to UTC for clients that don't provide timezone support.
//
// The provided slice is never modified: when at least one argument has to be
// changed a copy is returned, otherwise args itself is returned.
func (t dbData) translateArgs(args []interface{}) []interface{} {
	if t.supportsTimezones {
		return args
	}

	// translated is allocated lazily, only when something has to change, so
	// the caller's slice isn't overwritten behind its back.
	var translated []interface{}
	for i, arg := range args {
		ts, ok := arg.(time.Time)
		if !ok {
			continue
		}
		if translated == nil {
			translated = make([]interface{}, len(args))
			copy(translated, args)
		}
		translated[i] = ts.UTC()
	}

	if translated == nil {
		return args
	}
	return translated
}
// DB wraps a sql.DB to add special behaviors based on the db type
type DB struct {
// db is the wrapped database handle.
db *sql.DB
// data holds the query and argument translation rules of the db type.
data dbData
}
// NewDB opens a database of the given type and returns the DB wrapping it.
//
// For Postgres dbConnString is handed to the "postgres" driver as is. For
// Sqlite3 it is used as the path part of a "file:" URI to which the shared
// cache, WAL journal, foreign keys and case sensitive like options are
// appended. Any other type results in an error.
func NewDB(dbType Type, dbConnString string) (*DB, error) {
	var (
		data       dbData
		driverName string
	)

	switch dbType {
	case Postgres:
		driverName, data = "postgres", dbDataPostgres
	case Sqlite3:
		driverName, data = "sqlite3", dbDataSQLite3
		dbConnString = "file:" + dbConnString + "?cache=shared&_journal=wal&_foreign_keys=true&_case_sensitive_like=false"
	default:
		return nil, errors.New("unknown db type")
	}

	sqldb, err := sql.Open(driverName, dbConnString)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return &DB{db: sqldb, data: data}, nil
}
// Tx wraps a sql.Tx to offer:
// * apply some statement mutations before executing it
// * locking around concurrent executions of statements (since the underlying
// sql driver doesn't support concurrent statements on the same
// connection/transaction)
//
// NOTE(review): no lock is visible in this struct or in its Exec/Query
// methods; confirm whether the locking described above is still provided
// somewhere or if callers must serialize statements themselves.
type Tx struct {
// db is the DB that created this transaction.
db *DB
// tx is the underlying sql transaction. It's nil until Start succeeds.
tx *sql.Tx
// ctx is the context received by Start; it is passed to every statement
// executed through Exec, Query and QueryRow.
ctx context.Context
}
// Close closes the wrapped sql.DB and returns its error, if any, with a
// stack attached.
func (db *DB) Close() error {
	err := db.db.Close()
	return errors.WithStack(err)
}
// Conn returns a single connection obtained from the wrapped sql.DB. The
// statements executed on it are NOT translated for the db type.
func (db *DB) Conn(ctx context.Context) (*sql.Conn, error) {
	conn, connErr := db.db.Conn(ctx)
	return conn, errors.WithStack(connErr)
}
// NewUnstartedTx returns a Tx bound to db without beginning any sql
// transaction. Start has to be called on it before executing statements.
func (db *DB) NewUnstartedTx() *Tx {
	unstarted := Tx{db: db}
	return &unstarted
}
// NewTx creates a new Tx and starts it using ctx. On failure no Tx is
// returned.
func (db *DB) NewTx(ctx context.Context) (*Tx, error) {
	tx := db.NewUnstartedTx()

	err := tx.Start(ctx)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return tx, nil
}
// Do executes f inside a transaction: the transaction is committed when f
// returns nil and rolled back otherwise.
//
// When the whole attempt fails with a sqlite3 "locked" error it is run again
// from scratch, for at most maxTxRetries retries. This means that f can be
// called more than once. Any other error, or a locked error once the retries
// are exhausted, is returned to the caller.
func (db *DB) Do(ctx context.Context, f func(tx *Tx) error) error {
	for attempt := 0; ; attempt++ {
		err := db.do(ctx, f)

		var sqerr sqlite3.Error
		locked := err != nil && errors.As(err, &sqerr) && sqerr.Code == sqlite3.ErrLocked
		if locked && attempt < maxTxRetries {
			continue
		}

		return errors.WithStack(err)
	}
}
// do performs a single attempt of Do: it starts a transaction, calls f with
// it and then commits. The transaction is rolled back when f returns an
// error (which is returned to the caller) or panics (the panic is
// propagated after the rollback).
func (db *DB) do(ctx context.Context, f func(tx *Tx) error) error {
	tx, err := db.NewTx(ctx)
	if err != nil {
		return errors.WithStack(err)
	}

	// Don't leave the transaction open if f panics.
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		// The rollback error is ignored on purpose: the panic is what matters.
		_ = tx.Rollback()
		panic(p)
	}()

	if ferr := f(tx); ferr != nil {
		// The rollback error is ignored on purpose: the error returned by f
		// is the one reported to the caller.
		_ = tx.Rollback()
		return errors.WithStack(ferr)
	}

	return tx.Commit()
}
// Start begins the underlying sql transaction and records ctx as the context
// used by all the statements later executed through tx.
//
// On postgres the transaction isolation level is set to REPEATABLE READ. If
// that statement fails the just started transaction is rolled back, so its
// connection is released, and tx is left unstarted.
func (tx *Tx) Start(ctx context.Context) error {
	wtx, err := tx.db.db.Begin()
	if err != nil {
		return errors.WithStack(err)
	}

	if tx.db.data.t == Postgres {
		if _, err := wtx.ExecContext(ctx, "SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); err != nil {
			// A sql.Tx must always end with Commit or Rollback, otherwise its
			// connection is never given back to the pool. The rollback error
			// is ignored since the original one is more meaningful.
			_ = wtx.Rollback()
			return errors.WithStack(err)
		}
	}

	tx.tx = wtx
	tx.ctx = ctx

	return nil
}
// Commit commits the underlying sql transaction. It does nothing, and
// returns nil, when the transaction was never started.
func (tx *Tx) Commit() error {
	wtx := tx.tx
	if wtx == nil {
		return nil
	}

	err := wtx.Commit()
	return errors.WithStack(err)
}
// Rollback aborts the underlying sql transaction. It does nothing, and
// returns nil, when the transaction was never started.
func (tx *Tx) Rollback() error {
	wtx := tx.tx
	if wtx == nil {
		return nil
	}

	err := wtx.Rollback()
	return errors.WithStack(err)
}
// Exec executes query inside the transaction, using the context provided to
// Start. Both the query and its arguments are first translated for the db
// type in use. The transaction must have been started.
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
	flavor := tx.db.data
	translatedQuery := flavor.translate(query)
	translatedArgs := flavor.translateArgs(args)

	res, err := tx.tx.ExecContext(tx.ctx, translatedQuery, translatedArgs...)
	return res, errors.WithStack(err)
}
// Query executes a query returning rows inside the transaction, using the
// context provided to Start. Both the query and its arguments are first
// translated for the db type in use. The transaction must have been started.
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
	flavor := tx.db.data
	translatedQuery := flavor.translate(query)
	translatedArgs := flavor.translateArgs(args)

	rows, err := tx.tx.QueryContext(tx.ctx, translatedQuery, translatedArgs...)
	return rows, errors.WithStack(err)
}
// QueryRow executes a query expected to return at most one row inside the
// transaction, using the context provided to Start. Both the query and its
// arguments are first translated for the db type in use. The returned
// *sql.Row is the one produced by the underlying sql.Tx, unchanged.
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
	flavor := tx.db.data
	translatedQuery := flavor.translate(query)
	translatedArgs := flavor.translateArgs(args)

	return tx.tx.QueryRowContext(tx.ctx, translatedQuery, translatedArgs...)
}
// CurTime returns the current time as reported by the database, asking it
// with a "select now()" executed inside the transaction.
//
// On postgres the value is scanned directly into a time.Time. On sqlite3 the
// (translated) query yields a string that is parsed as a UTC time with
// optional fractional seconds. An error is returned for any other db type.
func (tx *Tx) CurTime() (time.Time, error) {
	dbType := tx.db.data.t

	if dbType == Sqlite3 {
		var raw string
		err := tx.QueryRow("select now()").Scan(&raw)
		if err != nil {
			return time.Time{}, errors.WithStack(err)
		}

		parsed, err := time.ParseInLocation("2006-01-02 15:04:05.999999999", raw, time.UTC)
		return parsed, errors.WithStack(err)
	}

	if dbType == Postgres {
		var cur time.Time
		err := tx.QueryRow("select now()").Scan(&cur)
		if err != nil {
			return time.Time{}, errors.WithStack(err)
		}

		return cur, nil
	}

	return time.Time{}, errors.New("unknown db type")
}