// Copyright 2019 Sorint.lab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
// See the License for the specific language governing permissions and
// limitations under the License.

package db

import (
	"context"
	"database/sql"
	"regexp"
	"time"

	_ "github.com/mattn/go-sqlite3"
	"github.com/pkg/errors"
)

// Type identifies the database flavor a DB talks to. The value selects both
// the sql driver and the set of query translations applied by this package.
type Type string

// Database flavors accepted by NewDB. Any other Type value makes NewDB fail
// with an "unknown db type" error.
const (
	Sqlite3  Type = "sqlite3"
	Postgres Type = "postgres"
)

// dbData describes the behaviors that differ between database flavors.
//
// t is the flavor this description belongs to.
// queryReplacers are the rewrite rules that translate applies, in slice
// order, to every query string.
// supportsTimezones reports whether time.Time arguments can be passed as
// they are; when false, translateArgs converts them to UTC.
type dbData struct {
	t                 Type
	queryReplacers    []replacer
	supportsTimezones bool
}

// replacer is a single query rewrite rule: every match of re is replaced
// with the string "with" (passed as the replacement argument of
// Regexp.ReplaceAllString by translate).
type replacer struct {
	re   *regexp.Regexp
	with string
}

// bindRegexp matches a postgres style positional bind variable, i.e. a "$"
// followed by one or more digits (e.g. "$1", "$12"). The sqlite3 replacers
// use it to rewrite every such variable to "?".
var bindRegexp = regexp.MustCompile(`\$\d+`)

// matchLiteral builds a regexp matching s verbatim (regexp metacharacters in
// s are escaped) only where it is delimited by word boundaries on both
// sides. It panics if the resulting pattern does not compile.
func matchLiteral(s string) *regexp.Regexp {
	const wordBoundary = `\b`

	pattern := wordBoundary + regexp.QuoteMeta(s) + wordBoundary
	return regexp.MustCompile(pattern)
}

var (
	// dbDataPostgres is the postgres flavor. Queries are written in the
	// postgres dialect, so the only rewrite needed is dropping the
	// statements marked as sqlite3 specific.
	dbDataPostgres = dbData{
		t:                 Postgres,
		supportsTimezones: true,
		queryReplacers: []replacer{
			// Remove sqlite3 only statements: the "--SQLITE3" marker line
			// together with the line that follows it.
			{re: regexp.MustCompile(`--SQLITE3\n.*`), with: ""},
		},
	}

	// dbDataSQLite3 is the sqlite3 flavor. Its replacers rewrite postgres
	// syntax to something sqlite3 accepts; they are applied in the order
	// listed here.
	dbDataSQLite3 = dbData{
		t:                 Sqlite3,
		supportsTimezones: false,
		queryReplacers: []replacer{
			// Positional bind variables ("$1", "$2", ...) become "?".
			{re: bindRegexp, with: "?"},
			// Boolean literals and type names.
			{re: matchLiteral("true"), with: "1"},
			{re: matchLiteral("false"), with: "0"},
			{re: matchLiteral("boolean"), with: "integer"},
			{re: matchLiteral("bytea"), with: "blob"},
			// timestamp is a declared type supported by the go-sqlite3 driver
			{re: matchLiteral("timestamptz"), with: "timestamp"},
			// Convert now() to the maximum precision time available with
			// sqlite3.
			{re: regexp.MustCompile(`\bnow\(\)`), with: "strftime('%Y-%m-%d %H:%M:%f', 'now')"},
			// Advisory locks and notifications are replaced with a
			// harmless "select 1".
			{re: regexp.MustCompile(`select pg_advisory_xact_lock\(.*`), with: "select 1"},
			{re: regexp.MustCompile(`notify\s+.*`), with: "select 1"},
			// Remove postgres only statements: the "--POSTGRES" marker line
			// together with the line that follows it.
			{re: regexp.MustCompile(`--POSTGRES\n.*`), with: ""},
		},
	}
)

// translate rewrites query for the receiver's database flavor by applying
// every rule of queryReplacers in order, each one working on the output of
// the previous one. The rewritten query is returned.
func (t dbData) translate(query string) string {
	translated := query
	for i := range t.queryReplacers {
		rule := &t.queryReplacers[i]
		translated = rule.re.ReplaceAllString(translated, rule.with)
	}
	return translated
}

// translateArgs translates query parameters that may be unique to
// a specific SQL flavor. For example, standardizing "time.Time"
// types to UTC for clients that don't provide timezone support.
//
// When the flavor supports timezones, or when args holds no time.Time value,
// args is returned as is. Otherwise a copy of args with every time.Time
// converted to UTC is returned; the slice passed by the caller is never
// modified, so a caller spreading its own slice (f(query, vals...)) does not
// see its values change.
func (t dbData) translateArgs(args []interface{}) []interface{} {
	if t.supportsTimezones {
		return args
	}

	// Allocated lazily, only once the first time.Time is found.
	var converted []interface{}
	for i, arg := range args {
		ts, ok := arg.(time.Time)
		if !ok {
			continue
		}
		if converted == nil {
			converted = make([]interface{}, len(args))
			copy(converted, args)
		}
		converted[i] = ts.UTC()
	}

	if converted == nil {
		return args
	}
	return converted
}

// DB wraps a sql.DB to add special behaviors based on the db type.
//
// db is the underlying database handle; data describes the flavor specific
// behaviors (query and argument translation) used by the transactions
// created from this DB.
type DB struct {
	db   *sql.DB
	data dbData
}

// NewDB opens a database of the given type and returns it wrapped in a DB.
//
// For Postgres, dbConnString is handed to the "postgres" driver unchanged.
// For Sqlite3, dbConnString is treated as a file name: it is prefixed with
// "file:" and a fixed set of connection options is appended before it is
// handed to the "sqlite3" driver. Any other dbType yields an error.
//
// Only sql.Open is called here; no statement is executed by this function.
func NewDB(dbType Type, dbConnString string) (*DB, error) {
	var (
		flavor dbData
		driver string
	)
	dsn := dbConnString

	switch dbType {
	case Sqlite3:
		flavor, driver = dbDataSQLite3, "sqlite3"
		dsn = "file:" + dbConnString + "?cache=shared&_journal=wal&_foreign_keys=true&_case_sensitive_like=false"
	case Postgres:
		flavor, driver = dbDataPostgres, "postgres"
	default:
		return nil, errors.New("unknown db type")
	}

	handle, err := sql.Open(driver, dsn)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return &DB{db: handle, data: flavor}, nil
}

// Tx wraps a sql.Tx to apply the flavor specific query and argument
// translations of its DB before a statement is executed.
//
// db is the DB the transaction belongs to. tx is the underlying sql
// transaction; it is nil until Start has succeeded.
//
// NOTE(review): earlier documentation also promised locking around
// concurrent executions of statements (since the underlying sql driver
// doesn't support concurrent statements on the same connection/transaction),
// but this type holds no lock. Presumably callers must not use a Tx from
// multiple goroutines at the same time - confirm.
type Tx struct {
	db *DB
	tx *sql.Tx
}

// Close closes the underlying sql.DB and returns its error unchanged.
func (db *DB) Close() error {
	return db.db.Close()
}

// Conn returns a single connection obtained from the underlying sql.DB.
// The connection is requested with context.TODO(), so this method supplies
// no deadline or cancellation of its own. Statements run on the returned
// connection bypass the query translations done by Tx.
func (db *DB) Conn() (*sql.Conn, error) {
	return db.db.Conn(context.TODO())
}

// NewUnstartedTx returns a Tx bound to db whose underlying sql transaction
// has not been begun yet. Start must be called before executing statements
// on it.
func (db *DB) NewUnstartedTx() *Tx {
	tx := new(Tx)
	tx.db = db
	return tx
}

// NewTx creates a Tx bound to db and starts it. It returns the error
// reported by Start, and no Tx, when the transaction cannot be started.
func (db *DB) NewTx() (*Tx, error) {
	tx := &Tx{db: db}

	err := tx.Start()
	if err != nil {
		return nil, err
	}
	return tx, nil
}

// Do runs f inside a newly started transaction.
//
// If f returns nil the transaction is committed and the commit error, if
// any, is returned. If f returns an error the transaction is rolled back and
// that error is returned. If a panic occurs after the transaction has been
// started, the transaction is rolled back and the panic is raised again.
// Errors reported by the rollback itself are ignored.
func (db *DB) Do(f func(tx *Tx) error) error {
	tx, err := db.NewTx()
	if err != nil {
		return err
	}

	// Do not leave the transaction open when unwinding because of a panic.
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		tx.Rollback()
		panic(p)
	}()

	if ferr := f(tx); ferr != nil {
		tx.Rollback()
		return ferr
	}

	return tx.Commit()
}

// Start begins the underlying sql transaction and stores it in tx.
//
// With the Postgres flavor the transaction isolation level is set to
// REPEATABLE READ by executing a statement as the first thing in the
// transaction. If that statement fails, the just begun transaction is rolled
// back so that it, and the connection it holds, is not left open; tx stays
// unstarted and the statement error is returned.
func (tx *Tx) Start() error {
	wtx, err := tx.db.db.Begin()
	if err != nil {
		return errors.WithStack(err)
	}

	if tx.db.data.t == Postgres {
		if _, err := wtx.Exec("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); err != nil {
			// Best effort: the Exec error is the one worth reporting, a
			// failure to roll back adds nothing the caller can act upon.
			_ = wtx.Rollback()
			return errors.WithStack(err)
		}
	}

	tx.tx = wtx
	return nil
}

// Commit commits the underlying sql transaction and returns its error
// unchanged. On a Tx that was never started it does nothing and returns nil.
func (tx *Tx) Commit() error {
	if started := tx.tx; started != nil {
		return started.Commit()
	}
	return nil
}

// Rollback rolls back the underlying sql transaction and returns its error
// unchanged. On a Tx that was never started it does nothing and returns nil.
func (tx *Tx) Rollback() error {
	if started := tx.tx; started != nil {
		return started.Rollback()
	}
	return nil
}

// Exec executes query inside the transaction after translating the query
// text and its arguments for the database flavor in use. The error reported
// by the underlying transaction, if any, is returned wrapped with a stack
// trace. The Tx must have been started.
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
	flavor := tx.db.data
	res, err := tx.tx.Exec(flavor.translate(query), flavor.translateArgs(args)...)
	return res, errors.WithStack(err)
}

// Query runs query inside the transaction after translating the query text
// and its arguments for the database flavor in use, and returns the
// resulting rows. The error reported by the underlying transaction, if any,
// is returned wrapped with a stack trace. The Tx must have been started.
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
	flavor := tx.db.data
	rows, err := tx.tx.Query(flavor.translate(query), flavor.translateArgs(args)...)
	return rows, errors.WithStack(err)
}

// QueryRow runs query inside the transaction after translating the query
// text and its arguments for the database flavor in use, and returns the
// *sql.Row produced by the underlying transaction. The Tx must have been
// started.
func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
	flavor := tx.db.data
	translated := flavor.translate(query)
	return tx.tx.QueryRow(translated, flavor.translateArgs(args)...)
}

// CurTime returns the current time as reported by the database, obtained
// with "select now()" executed inside the transaction.
//
// With Sqlite3 the query is rewritten by the query translation into a
// strftime call producing a textual timestamp, which is scanned as a string
// and parsed as a UTC time. With Postgres the value is scanned directly into
// a time.Time.
//
// On failure the zero time is returned together with an error carrying a
// stack trace; this holds for scan errors as well as for a timestamp that
// cannot be parsed. An unknown database flavor yields an error too.
func (tx *Tx) CurTime() (time.Time, error) {
	switch tx.db.data.t {
	case Sqlite3:
		var timestring string
		if err := tx.QueryRow("select now()").Scan(&timestring); err != nil {
			return time.Time{}, errors.WithStack(err)
		}
		// The fractional part of the layout accepts the sub second digits
		// emitted by sqlite3.
		now, err := time.ParseInLocation("2006-01-02 15:04:05.999999999", timestring, time.UTC)
		if err != nil {
			return time.Time{}, errors.WithStack(err)
		}
		return now, nil
	case Postgres:
		var now time.Time
		if err := tx.QueryRow("select now()").Scan(&now); err != nil {
			return time.Time{}, errors.WithStack(err)
		}
		return now, nil
	}
	return time.Time{}, errors.New("unknown db type")
}