// Copyright 2019 Sorint.lab // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied // See the License for the specific language governing permissions and // limitations under the License. package db import ( "database/sql" sq "github.com/Masterminds/squirrel" errors "golang.org/x/xerrors" ) const dbVersionTableDDLTmpl = ` create table if not exists dbversion (version int not null, time timestamptz not null) ` const dbVersion = 1 func (db *DB) Create(stmts []string) error { sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) err := db.Do(func(tx *Tx) error { if _, err := tx.Exec(dbVersionTableDDLTmpl); err != nil { return errors.Errorf("failed to create dbversion table: %w", err) } return nil }) if err != nil { return err } err = db.Do(func(tx *Tx) error { var version sql.NullInt64 q, args, err := sb.Select("max(version)").From("dbversion").ToSql() if err != nil { return err } if err := tx.QueryRow(q, args...).Scan(&version); err != nil { return errors.Errorf("cannot get current db version: %w", err) } if version.Valid { return nil } for _, stmt := range stmts { if _, err := tx.Exec(stmt); err != nil { return errors.Errorf("creation failed: %w", err) } } q, args, err = sb.Insert("dbversion").Columns("version", "time").Values(dbVersion, "now()").ToSql() if err != nil { return err } if _, err := tx.Exec(q, args...); err != nil { return errors.Errorf("failed to update dbversion table: %w", err) } return nil }) return err }