nat/lib/store/sqlike/sqlike.go

329 lines
8.2 KiB
Go

package sqlike
import (
"context"
"crypto/sha1"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"html"
"strconv"
"time"
"git.tuxpa.in/a/nat/lib/idgen"
"git.tuxpa.in/a/nat/lib/store"
"github.com/dchest/uniuri"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
// SqlikeConfig describes which SQL database to connect to and which tables
// hold pastes and accounts. It can be decoded from JSON using the tags below.
type SqlikeConfig struct {
	DBHost          string    `json:"dbhost"`          // Name of your database host
	DBName          string    `json:"dbname"`          // Name of your database (the database file for sqlite3)
	DBPassword      string    `json:"dbpassword"`      // The password for the database user
	DBPlaceHolder   [7]string `json:"-"`               // Bind placeholders ("?" or "$1".."$7"); filled in by connect, not configurable
	DBPort          string    `json:"dbport"`          // Port of the database
	DBTable         string    `json:"dbtable"`         // Name of the paste table in the database
	DBAccountsTable string    `json:"dbaccountstable"` // Name of the user accounts table in the database
	DBType          string    `json:"dbtype"`          // Type of database: "sqlite3", "postgres" or "mysql"
	DBUser          string    `json:"dbuser"`          // The database user
	DisplayName     string    `json:"displayname"`     // Name of your pastebin
}
// The assignment below stops compiling if *Sqlike ever ceases to satisfy
// store.Store.
var _ store.Store = (*Sqlike)(nil)

// Sqlike is a store.Store built on database/sql. connect accepts the
// "sqlite3", "postgres" and "mysql" drivers.
type Sqlike struct {
	config SqlikeConfig // settings from New; placeholders are filled in by connect
	handle *sql.DB      // database handle opened by connect
}
// MustNew behaves like New but panics with New's error instead of returning
// it.
func MustNew(config ...SqlikeConfig) *Sqlike {
	s, err := New(config...)
	if err == nil {
		return s
	}
	panic(err)
}
// New creates a Sqlike store and connects it to its database.
//
// Only the first element of config is used. Without a config the store opens
// the sqlite3 database "db.sqlite" with tables "pastebin" and "accounts".
// With a config, its DBType and connection fields are used as given (an empty
// DBType is rejected by connect). An empty DBTable or DBAccountsTable still
// falls back to "pastebin" / "accounts".
func New(config ...SqlikeConfig) (*Sqlike, error) {
	const (
		defaultTable         = "pastebin"
		defaultAccountsTable = "accounts"
	)
	s := &Sqlike{config: SqlikeConfig{
		DBType:          "sqlite3",
		DBName:          "db.sqlite",
		DBTable:         defaultTable,
		DBAccountsTable: defaultAccountsTable,
	}}
	if len(config) > 0 {
		s.config = config[0]
		// An empty table name would produce invalid SQL such as
		// "select id from  where ...", so keep the defaults for unset tables.
		if s.config.DBTable == "" {
			s.config.DBTable = defaultTable
		}
		if s.config.DBAccountsTable == "" {
			s.config.DBAccountsTable = defaultAccountsTable
		}
	}
	if err := s.connect(); err != nil {
		return nil, err
	}
	return s, nil
}
// connect fills in the driver-specific bind placeholders and opens the
// database described by s.config. It then runs a probe query against the
// paste table. On success the handle is stored in s.handle.
//
// DBType must be "sqlite3", "postgres" or "mysql"; an empty or unknown
// DBType returns an error without opening anything.
func (s *Sqlike) connect() error {
	// sqlite3 and mysql use "?" for every parameter; postgres replaces these
	// with numbered "$1".."$7" below.
	for i := range s.config.DBPlaceHolder {
		s.config.DBPlaceHolder[i] = "?"
	}

	var dbinfo string
	switch s.config.DBType {
	case "sqlite3":
		dbinfo = s.config.DBName
	case "postgres":
		dbinfo = fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
			s.config.DBHost,
			s.config.DBPort,
			s.config.DBUser,
			s.config.DBPassword,
			s.config.DBName)
		for i := range s.config.DBPlaceHolder {
			s.config.DBPlaceHolder[i] = "$" + strconv.Itoa(i+1)
		}
	case "mysql":
		dbinfo = s.config.DBUser + ":" + s.config.DBPassword + "@tcp(" + s.config.DBHost + ":" + s.config.DBPort + ")/" + s.config.DBName
	case "":
		return errors.New(" Database error : dbtype not specified in sqlike config")
	default:
		return errors.New(" Database error : Specified dbtype (" +
			s.config.DBType + ") not supported.")
	}

	db, err := sql.Open(s.config.DBType, dbinfo)
	if err != nil {
		return err
	}

	// Probe the paste table. sql.Open may not actually connect, so bad
	// credentials or a missing table would otherwise only surface on first use.
	// No matching row is fine.
	var dummy string
	err = db.QueryRow("select id from " + s.config.DBTable + " where id='dummyid'").Scan(&dummy)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		// Release the handle; nothing else holds a reference to it.
		_ = db.Close()
		return err
	}
	s.handle = db
	return nil
}
// shaPaste returns the SHA-1 digest of paste, encoded with URL-safe base64
// (padding included).
func shaPaste(paste string) string {
	digest := sha1.Sum([]byte(paste))
	return base64.URLEncoding.EncodeToString(digest[:])
}
// SavePaste stores a paste and returns a Response describing it.
//
// data, title and userKey are HTML-escaped before they are hashed or stored.
// If a paste with the same escaped content (same SHA-1 hash) already exists,
// that paste's id, title and hash are returned and nothing is inserted. The
// existing paste's delete key is not included in that response.
//
// expiry is the paste lifetime measured from now. It is stored as Unix
// seconds, the form GetPaste scans back. A non-positive expiry stores 0,
// which GetPaste treats as "never expires". An empty title is replaced with
// the generated id.
func (s *Sqlike) SavePaste(ctx context.Context, title string, data string, expiry time.Duration, userKey string) (*store.Response, error) {
	// Escape user input.
	data = html.EscapeString(data)
	title = html.EscapeString(title)
	userKey = html.EscapeString(userKey)

	// Hash the paste data and check whether identical content already exists.
	sha := shaPaste(data)
	var id, hash, delkey string
	err := s.handle.QueryRowContext(ctx, "select id, title, hash, data, delkey from "+
		s.config.DBTable+" where hash="+
		s.config.DBPlaceHolder[0], sha).Scan(&id, &title, &hash, &data, &delkey)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// New content; insert it below.
	case err != nil:
		return nil, err
	default:
		return &store.Response{
			Status: "Paste data already exists ...",
			Id:     id,
			Title:  title,
			Sha1:   hash,
			Size:   len(data)}, nil
	}

	id = idgen.MustGenerate()
	if title == "" {
		title = id
	}
	var expiresAt int64 // 0 means the paste never expires
	if expiry > 0 {
		expiresAt = time.Now().Add(expiry).Unix()
	}
	delKey := uniuri.NewLen(40)

	// Placeholder syntax differs between drivers ("?" vs "$n").
	values := s.config.DBPlaceHolder[0]
	for _, p := range s.config.DBPlaceHolder[1:7] {
		values += "," + p
	}
	_, err = s.handle.ExecContext(ctx, "INSERT INTO "+s.config.DBTable+
		" (id,title,hash,data,delkey,expiry,userid)values("+values+")",
		id, title, sha, data, delKey, expiresAt, userKey)
	if err != nil {
		return nil, err
	}
	return &store.Response{
		Status: "Successfully saved paste.",
		Id:     id,
		Title:  title,
		Sha1:   sha,
		Size:   len(data),
		DelKey: delKey}, nil
}
// GetUserPastes lists the pastes saved under userKey. Each entry carries the
// paste id, its title as stored (HTML-escaped), its delete key and the length
// of its stored data. A user with no pastes yields an empty, non-nil Pastes.
func (s *Sqlike) GetUserPastes(ctx context.Context, userKey string) (*store.Pastes, error) {
	// SavePaste stores the HTML-escaped user key, so escape it the same way
	// here or keys containing &, <, >, " or ' would never match.
	userKey = html.EscapeString(userKey)

	rows, err := s.handle.QueryContext(ctx, "select id, title, delkey, data from "+
		s.config.DBTable+" where userid="+
		s.config.DBPlaceHolder[0], userKey)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	pst := &store.Pastes{}
	for rows.Next() {
		var id, title, delKey, data string
		if err := rows.Scan(&id, &title, &delKey, &data); err != nil {
			return nil, err
		}
		pst.Response = append(pst.Response, store.Response{
			Id:     id,
			Title:  title,
			Size:   len(data),
			DelKey: delKey})
	}
	// Next returns false on both end-of-rows and error; distinguish them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return pst, nil
}
// GetUserKey returns the user key stored for the account with the given
// email. When no account matches, Scan's error (sql.ErrNoRows) is returned
// unchanged.
func (s *Sqlike) GetUserKey(ctx context.Context, email string) (string, error) {
	// KEY is a reserved word in MySQL, so the column must be backtick-quoted
	// there. The other supported drivers accept it unquoted.
	keyCol := "key"
	if s.config.DBType == "mysql" {
		keyCol = "`key`"
	}
	var userKey string
	err := s.handle.QueryRowContext(ctx, "select "+keyCol+" from "+s.config.DBAccountsTable+
		" where email="+s.config.DBPlaceHolder[0], email).
		Scan(&userKey)
	if err != nil {
		return "", err
	}
	return userKey, nil
}
// GetPaste returns the paste with the given id. Its title and contents are
// HTML-unescaped, undoing the escaping SavePaste applies.
//
// A paste that does not exist, or whose expiry time has passed, is reported
// as a Response with Status "Requested paste doesn't exist." and a nil error.
// The expiry column is read as Unix seconds. 0 or NULL means the paste never
// expires, and Expiry is then reported as "Never".
func (s *Sqlike) GetPaste(ctx context.Context, pasteId string) (*store.Response, error) {
	var title, paste string
	var expiry sql.NullInt64
	err := s.handle.QueryRowContext(ctx, "select title, data, expiry from "+
		s.config.DBTable+" where id="+s.config.DBPlaceHolder[0],
		pasteId).Scan(&title, &paste, &expiry)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return &store.Response{Status: "Requested paste doesn't exist."}, nil
	case err != nil:
		return nil, err
	}

	expiryS := "Never"
	if expiry.Valid && expiry.Int64 != 0 {
		expiretime := time.Unix(expiry.Int64, 0)
		// An expired paste is treated exactly like a missing one.
		if time.Now().After(expiretime) {
			return &store.Response{Status: "Requested paste doesn't exist."}, nil
		}
		expiryS = expiretime.Format("2006-01-02 15:04:05")
	}

	// Unescape the saved data.
	paste = html.UnescapeString(paste)
	title = html.UnescapeString(title)
	return &store.Response{
		Status: "Success",
		Id:     pasteId,
		Title:  title,
		Paste:  paste,
		Size:   len(paste),
		Expiry: expiryS}, nil
}
// ForceDelPaste deletes the paste with the given id from the configured paste
// table without checking its delete key. Deleting an id that matches no row
// is not reported as an error.
func (s *Sqlike) ForceDelPaste(ctx context.Context, pasteId string) error {
	// Use the configured table rather than a hard-coded name so this agrees
	// with SavePaste/GetPaste.
	_, err := s.handle.ExecContext(ctx, "delete from "+s.config.DBTable+" where id="+
		s.config.DBPlaceHolder[0], pasteId)
	return err
}
func (s *Sqlike) RegisterUser(ctx context.Context, email string, hashpass []byte) error {
var dbQuery string
for i := 0; i < 3; i++ {
dbQuery += s.config.DBPlaceHolder[i] + ","
}
dbQuery = dbQuery[:len(dbQuery)-1]
stmt, err := s.handle.PrepareContext(ctx, "INSERT into "+s.config.DBAccountsTable+"(email, password, key) values("+dbQuery+")")
defer stmt.Close()
if err != nil {
return err
}
key := idgen.MustGenerate()
_, err = stmt.ExecContext(ctx, email, hashpass, key)
if err != nil {
return err
}
return nil
}
// DelPaste deletes the paste identified by pasteId from the configured paste
// table when delKey matches its stored delete key. A mismatched id/key pair
// deletes nothing and is not reported as an error.
func (s *Sqlike) DelPaste(ctx context.Context, pasteId, delKey string) error {
	res, err := s.handle.ExecContext(ctx, "delete from "+s.config.DBTable+" where delkey="+
		s.config.DBPlaceHolder[0]+" and id="+
		s.config.DBPlaceHolder[1], delKey, pasteId)
	if err != nil {
		// res is nil on error; touching it would panic.
		return err
	}
	// The row count itself is unused, but a driver error reading it is
	// still surfaced to the caller.
	if _, err := res.RowsAffected(); err != nil {
		return err
	}
	return nil
}
// HasAccount returns the stored password column for the account with the
// given email. It returns (nil, nil) when no such account exists, and
// (nil, err) on any other database error.
func (s *Sqlike) HasAccount(ctx context.Context, email string) ([]byte, error) {
	query := "select password from " + s.config.DBAccountsTable +
		" where email=" + s.config.DBPlaceHolder[0]

	var stored []byte
	switch err := s.handle.QueryRowContext(ctx, query, email).Scan(&stored); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return stored, nil
}