gosora/query_gen/utils.go

523 lines
12 KiB
Go

/*
*
* Query Generator Library
* WIP Under Construction
* Copyright Azareal 2017 - 2020
*
*/
package qgen
import (
"os"
"strings"
)
// TODO: Add support for numbers and strings?
func processColumns(colStr string) (columns []DBColumn) {
if colStr == "" {
return columns
}
colStr = strings.Replace(colStr, " as ", " AS ", -1)
for _, segment := range strings.Split(colStr, ",") {
var outCol DBColumn
dotHalves := strings.Split(strings.TrimSpace(segment), ".")
var halves []string
if len(dotHalves) == 2 {
outCol.Table = dotHalves[0]
halves = strings.Split(dotHalves[1], " AS ")
} else {
halves = strings.Split(dotHalves[0], " AS ")
}
halves[0] = strings.TrimSpace(halves[0])
if len(halves) == 2 {
outCol.Alias = strings.TrimSpace(halves[1])
}
if halves[0][len(halves[0])-1] == ')' {
outCol.Type = TokenFunc
} else if halves[0] == "?" {
outCol.Type = TokenSub
} else {
outCol.Type = TokenColumn
}
outCol.Left = halves[0]
columns = append(columns, outCol)
}
return columns
}
// TODO: Allow order by statements without a direction
func processOrderby(orderStr string) (order []DBOrder) {
if orderStr == "" {
return order
}
for _, segment := range strings.Split(orderStr, ",") {
var outOrder DBOrder
halves := strings.Split(strings.TrimSpace(segment), " ")
if len(halves) != 2 {
continue
}
outOrder.Column = halves[0]
outOrder.Order = strings.ToLower(halves[1])
order = append(order, outOrder)
}
return order
}
func processJoiner(joinStr string) (joiner []DBJoiner) {
if joinStr == "" {
return joiner
}
joinStr = strings.Replace(joinStr, " on ", " ON ", -1)
joinStr = strings.Replace(joinStr, " and ", " AND ", -1)
for _, segment := range strings.Split(joinStr, " AND ") {
var outJoin DBJoiner
var parseOffset int
var left, right string
left, parseOffset = getIdentifier(segment, parseOffset)
outJoin.Operator, parseOffset = getOperator(segment, parseOffset+1)
right, parseOffset = getIdentifier(segment, parseOffset+1)
leftColumn := strings.Split(left, ".")
rightColumn := strings.Split(right, ".")
outJoin.LeftTable = strings.TrimSpace(leftColumn[0])
outJoin.RightTable = strings.TrimSpace(rightColumn[0])
outJoin.LeftColumn = strings.TrimSpace(leftColumn[1])
outJoin.RightColumn = strings.TrimSpace(rightColumn[1])
joiner = append(joiner, outJoin)
}
return joiner
}
func (wh *DBWhere) parseNumber(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if '0' <= ch && ch <= '9' {
//buffer += string(ch)
l++
} else {
i--
var str string
if l != 0 {
str = seg[si : si+l]
}
wh.Expr = append(wh.Expr, DBToken{str, TokenNumber})
return i
}
}
return i
}
func (wh *DBWhere) parseColumn(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
switch {
case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '.' || ch == '_':
//buffer += string(ch)
l++
case ch == '(':
var str string
if l != 0 {
str = seg[si : si+l]
}
return wh.parseFunction(seg, str, i)
default:
i--
var str string
if l != 0 {
str = seg[si : si+l]
}
wh.Expr = append(wh.Expr, DBToken{str, TokenColumn})
return i
}
}
return i
}
func (wh *DBWhere) parseFunction(seg string, buffer string, i int) int {
preI := i
i = skipFunctionCall(seg, i-1)
buffer += seg[preI:i] + string(seg[i])
wh.Expr = append(wh.Expr, DBToken{buffer, TokenFunc})
return i
}
func (wh *DBWhere) parseString(seg string, i int) int {
//var buffer string
i++
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if ch != '\'' {
//buffer += string(ch)
l++
} else {
var str string
if l != 0 {
str = seg[si : si+l]
}
wh.Expr = append(wh.Expr, DBToken{str, TokenString})
return i
}
}
return i
}
func (wh *DBWhere) parseOperator(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if isOpByte(ch) {
//buffer += string(ch)
l++
} else {
i--
var str string
if l != 0 {
str = seg[si : si+l]
}
wh.Expr = append(wh.Expr, DBToken{str, TokenOp})
return i
}
}
return i
}
// TODO: Make this case insensitive
func normalizeAnd(in string) string {
in = strings.Replace(in, " and ", " AND ", -1)
return strings.Replace(in, " && ", " AND ", -1)
}
func normalizeOr(in string) string {
in = strings.Replace(in, " or ", " OR ", -1)
return strings.Replace(in, " || ", " OR ", -1)
}
// TODO: Write tests for this
func processWhere(whereStr string) (where []DBWhere) {
if whereStr == "" {
return where
}
whereStr = normalizeAnd(whereStr)
whereStr = normalizeOr(whereStr)
for _, seg := range strings.Split(whereStr, " AND ") {
tmpWhere := &DBWhere{[]DBToken{}}
seg += ")"
for i := 0; i < len(seg); i++ {
ch := seg[i]
switch {
case '0' <= ch && ch <= '9':
i = tmpWhere.parseNumber(seg, i)
// TODO: Sniff the third byte offset from char or it's non-existent to avoid matching uppercase strings which start with OR
case ch == 'O' && (i+1) < len(seg) && seg[i+1] == 'R':
tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"OR", TokenOr})
i += 1
case ch == 'N' && (i+2) < len(seg) && seg[i+1] == 'O' && seg[i+2] == 'T':
tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"NOT", TokenNot})
i += 2
case ch == 'L' && (i+3) < len(seg) && seg[i+1] == 'I' && seg[i+2] == 'K' && seg[i+3] == 'E':
tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"LIKE", TokenLike})
i += 3
case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
i = tmpWhere.parseColumn(seg, i)
case ch == '\'':
i = tmpWhere.parseString(seg, i)
case ch == ')' && i < (len(seg)-1):
tmpWhere.Expr = append(tmpWhere.Expr, DBToken{")", TokenOp})
case isOpByte(ch):
i = tmpWhere.parseOperator(seg, i)
case ch == '?':
tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"?", TokenSub})
}
}
where = append(where, *tmpWhere)
}
return where
}
func (set *DBSetter) parseNumber(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if '0' <= ch && ch <= '9' {
//buffer += string(ch)
l++
} else {
var str string
if l != 0 {
str = seg[si : si+l]
}
set.Expr = append(set.Expr, DBToken{str, TokenNumber})
return i
}
}
return i
}
func (set *DBSetter) parseColumn(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
switch {
case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
//buffer += string(ch)
l++
case ch == '(':
var str string
if l != 0 {
str = seg[si : si+l]
}
return set.parseFunction(seg, str, i)
default:
i--
var str string
if l != 0 {
str = seg[si : si+l]
}
set.Expr = append(set.Expr, DBToken{str, TokenColumn})
return i
}
}
return i
}
func (set *DBSetter) parseFunction(segment string, buffer string, i int) int {
preI := i
i = skipFunctionCall(segment, i-1)
buffer += segment[preI:i] + string(segment[i])
set.Expr = append(set.Expr, DBToken{buffer, TokenFunc})
return i
}
func (set *DBSetter) parseString(seg string, i int) int {
//var buffer string
i++
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if ch != '\'' {
//buffer += string(ch)
l++
} else {
var str string
if l != 0 {
str = seg[si : si+l]
}
set.Expr = append(set.Expr, DBToken{str, TokenString})
return i
}
}
return i
}
func (set *DBSetter) parseOperator(seg string, i int) int {
//var buffer string
si := i
l := 0
for ; i < len(seg); i++ {
ch := seg[i]
if isOpByte(ch) {
//buffer += string(ch)
l++
} else {
i--
var str string
if l != 0 {
str = seg[si : si+l]
}
set.Expr = append(set.Expr, DBToken{str, TokenOp})
return i
}
}
return i
}
func processSet(setstr string) (setter []DBSetter) {
if setstr == "" {
return setter
}
// First pass, splitting the string by commas while ignoring the innards of functions
var setset []string
var buffer string
var lastItem int
setstr += ","
for i := 0; i < len(setstr); i++ {
if setstr[i] == '(' {
i = skipFunctionCall(setstr, i-1)
setset = append(setset, setstr[lastItem:i+1])
buffer = ""
lastItem = i + 2
} else if setstr[i] == ',' && buffer != "" {
setset = append(setset, buffer)
buffer = ""
lastItem = i + 1
} else if (setstr[i] > 32) && setstr[i] != ',' && setstr[i] != ')' {
buffer += string(setstr[i])
}
}
// Second pass. Break this setitem into manageable chunks
for _, setitem := range setset {
halves := strings.Split(setitem, "=")
if len(halves) != 2 {
continue
}
tmpSetter := &DBSetter{Column: strings.TrimSpace(halves[0])}
segment := halves[1] + ")"
for i := 0; i < len(segment); i++ {
ch := segment[i]
switch {
case '0' <= ch && ch <= '9':
i = tmpSetter.parseNumber(segment, i)
case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
i = tmpSetter.parseColumn(segment, i)
case ch == '\'':
i = tmpSetter.parseString(segment, i)
case isOpByte(ch):
i = tmpSetter.parseOperator(segment, i)
case ch == '?':
tmpSetter.Expr = append(tmpSetter.Expr, DBToken{"?", TokenSub})
}
}
setter = append(setter, *tmpSetter)
}
return setter
}
func processLimit(limitStr string) (limit DBLimit) {
halves := strings.Split(limitStr, ",")
if len(halves) == 2 {
limit.Offset = halves[0]
limit.MaxCount = halves[1]
} else {
limit.MaxCount = halves[0]
}
return limit
}
func isOpByte(ch byte) bool {
return ch == '<' || ch == '>' || ch == '=' || ch == '!' || ch == '*' || ch == '%' || ch == '+' || ch == '-' || ch == '/' || ch == '(' || ch == ')'
}
func isOpRune(ch rune) bool {
return ch == '<' || ch == '>' || ch == '=' || ch == '!' || ch == '*' || ch == '%' || ch == '+' || ch == '-' || ch == '/' || ch == '(' || ch == ')'
}
func processFields(fieldStr string) (fields []DBField) {
if fieldStr == "" {
return fields
}
var buffer string
var lastItem int
fieldStr += ","
for i := 0; i < len(fieldStr); i++ {
ch := fieldStr[i]
if ch == '(' {
i = skipFunctionCall(fieldStr, i-1)
fields = append(fields, DBField{Name: fieldStr[lastItem : i+1], Type: getIdentifierType(fieldStr[lastItem : i+1])})
buffer = ""
lastItem = i + 2
} else if ch == ',' && buffer != "" {
fields = append(fields, DBField{Name: buffer, Type: getIdentifierType(buffer)})
buffer = ""
lastItem = i + 1
} else if (ch >= 32) && ch != ',' && ch != ')' {
buffer += string(ch)
}
}
return fields
}
func getIdentifierType(iden string) int {
if ('a' <= iden[0] && iden[0] <= 'z') || ('A' <= iden[0] && iden[0] <= 'Z') {
if iden[len(iden)-1] == ')' {
return IdenFunc
}
return IdenColumn
}
if iden[0] == '\'' || iden[0] == '"' {
return IdenString
}
return IdenLiteral
}
func getIdentifier(seg string, startOffset int) (out string, i int) {
seg = strings.TrimSpace(seg)
seg += " " // Avoid overflow bugs with slicing
for i = startOffset; i < len(seg); i++ {
ch := seg[i]
if ch == '(' {
i = skipFunctionCall(seg, i)
return strings.TrimSpace(seg[startOffset:i]), (i - 1)
}
if (ch == ' ' || isOpByte(ch)) && i != startOffset {
return strings.TrimSpace(seg[startOffset:i]), (i - 1)
}
}
return strings.TrimSpace(seg[startOffset:]), (i - 1)
}
func getOperator(seg string, startOffset int) (out string, i int) {
seg = strings.TrimSpace(seg)
seg += " " // Avoid overflow bugs with slicing
for i = startOffset; i < len(seg); i++ {
if !isOpByte(seg[i]) && i != startOffset {
return strings.TrimSpace(seg[startOffset:i]), (i - 1)
}
}
return strings.TrimSpace(seg[startOffset:]), (i - 1)
}
func skipFunctionCall(data string, index int) int {
var braceCount int
for ; index < len(data); index++ {
char := data[index]
if char == '(' {
braceCount++
} else if char == ')' {
braceCount--
if braceCount == 0 {
return index
}
}
}
return index
}
func writeFile(name, content string) (err error) {
f, err := os.Create(name)
if err != nil {
return err
}
_, err = f.WriteString(content)
if err != nil {
return err
}
err = f.Sync()
if err != nil {
return err
}
return f.Close()
}