diff --git a/db/create.go b/db/create.go index ed0b0fd4..0db2bd2f 100644 --- a/db/create.go +++ b/db/create.go @@ -15,6 +15,7 @@ package db import ( + "fmt" "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/item" @@ -22,30 +23,41 @@ import ( "github.com/abcum/surreal/util/uuid" ) -func executeCreateStatement(txn kvs.TX, ast *sql.CreateStatement) (out []interface{}, err error) { +func (e *executor) executeCreateStatement(txn kvs.TX, ast *sql.CreateStatement) (out []interface{}, err error) { + + for k, w := range ast.What { + if what, ok := w.(*sql.Param); ok { + ast.What[k] = e.ctx.Get(what.ID).Data() + } + } for _, w := range ast.What { - if what, ok := w.(*sql.Thing); ok { + switch what := w.(type) { + + default: + return out, fmt.Errorf("Can not execute CREATE query using type '%T'", what) + + case *sql.Thing: key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: what.ID} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := create(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } - } - if what, ok := w.(*sql.Table); ok { + case *sql.Table: key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: uuid.NewV5(uuid.NewV4().UUID, ast.KV).String()} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := create(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } + } } diff --git a/db/db.go b/db/db.go index e845ca6c..16da9a53 100644 --- a/db/db.go +++ b/db/db.go @@ -25,12 +25,35 @@ import ( "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/log" "github.com/abcum/surreal/sql" + "github.com/abcum/surreal/util/data" _ "github.com/abcum/surreal/kvs/boltdb" _ "github.com/abcum/surreal/kvs/mysql" _ "github.com/abcum/surreal/kvs/pgsql" ) +type executor struct { + txn kvs.TX + ctx *data.Doc + ast *sql.Query +} + +func newExecutor(ast *sql.Query, vars map[string]interface{}) *executor { + return &executor{ast: ast, ctx: data.Consume(vars)} +} + +func (e *executor) Txn() kvs.TX { + return e.txn +} + +func (e *executor) Set(key string, val interface{}) { + e.ctx.Set(val, key) +} + +func (e *executor) Get(key string) (val interface{}) { + return e.ctx.Get(key).Data() +} + type Response struct { Time string `codec:"time,omitempty"` Status string `codec:"status,omitempty"` @@ -63,16 +86,40 @@ func Exit() { // Execute parses the query and executes it against the data layer func Execute(ctx *fibre.Context, txt interface{}, vars map[string]interface{}) (out []*Response, err error) { + // Parse the received SQL batch query strings + // into SQL ASTs, using any immutable preset + // variables if set. + ast, err := sql.Parse(ctx, txt, vars) if err != nil { return } + // If no preset variables have been defined + // then ensure that the variables is + // instantiated for future use. + + if vars == nil { + vars = make(map[string]interface{}) + } + + // Create 2 channels, one for force quitting + // the query processor, and the other for + // receiving and buffering any query results. + quit := make(chan bool, 1) recv := make(chan *Response) + // Ensure that the force quit channel is auto + // closed when the end of the request has been + // reached, and we are not an http connection. + defer close(quit) + // If the current connection is a normal http + // connection then force quit any running + // queries if a socket close event occurs. + if _, ok := ctx.Response().ResponseWriter.(http.CloseNotifier); ok { exit := ctx.Response().CloseNotify() @@ -89,7 +136,15 @@ func Execute(ctx *fibre.Context, txt interface{}, vars map[string]interface{}) ( } - go execute(ctx, ast, quit, recv) + // Create a new query executor with the query + // details, and the current runtime variables + // and execute the queries within. + + go newExecutor(ast, vars).execute(quit, recv) + + // Wait for all of the processed queries to + // return results, buffer the output, and + // return the output when finished. for res := range recv { out = append(out, res) @@ -99,6 +154,242 @@ func Execute(ctx *fibre.Context, txt interface{}, vars map[string]interface{}) ( } +func (e *executor) execute(quit <-chan bool, send chan<- *Response) { + + var err error + var txn kvs.TX + var rsp *Response + var buf []*Response + var res []interface{} + + // Ensure that the query responses channel is + // closed when the full query has been processed + // and dealt with. + + defer close(send) + + // If we are making use of a global transaction + // which is not committed at the end of the + // query set, then cancel the transaction. + + defer func() { + if txn != nil { + txn.Cancel() + } + }() + + // If we have paniced during query execution + // then ensure that we recover from the error + // and print the error to the log. + + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + fmt.Println(err) + } + } + }() + + stms := make(chan sql.Statement) + + // Loop over the defined query statements and + // pass them to the statement processing + // channel for execution. + + go func() { + for _, stm := range e.ast.Statements { + stms <- stm + } + close(stms) + }() + + // Listen for any new statements to process and + // at the same time listen for the quit signal + // notifying us whether the client has gone away. + + for { + + select { + + case <-quit: + return + + case stm, open := <-stms: + + // If we have reached the end of the statement + // processing channel then return out of this + // for loop and exit. + + if !open { + return + } + + // If we are not inside a global transaction + // then reset the error to nil so that the + // next statement is not ignored. + + if txn == nil { + err = nil + } + + // Check to see if the current statement is + // a TRANSACTION statement, and if it is + // then deal with it and move on to the next. + + switch stm.(type) { + case *sql.BeginStatement: + txn, err = begin(txn) + continue + case *sql.CancelStatement: + txn, err, buf = cancel(txn, buf, err, send) + continue + case *sql.CommitStatement: + txn, err, buf = commit(txn, buf, err, send) + continue + } + + // This is not a TRANSACTION statement and + // therefore we must time the execution speed + // and process the statement response. + + now := time.Now() + + // If an error has occured and we are inside + // a global transaction, then ignore all + // subsequent statements in the transaction. + + if err == nil { + res, err = e.operate(txn, stm) + } else { + res, err = []interface{}{}, fmt.Errorf("Query not executed") + } + + rsp = &Response{ + Time: time.Since(now).String(), + Status: status(err), + Detail: detail(err), + Result: append([]interface{}{}, res...), + } + + // If we are not inside a global transaction + // then we can output the statement response + // immediately to the channel. + + if txn == nil { + send <- rsp + continue + } + + // If we are inside a global transaction we + // must buffer the responses for output at + // the end of the transaction. + + if txn != nil { + buf = append(buf, rsp) + continue + } + + } + + } + +} + +func (e *executor) operate(txn kvs.TX, ast sql.Statement) (res []interface{}, err error) { + + var loc bool + + // If we are not inside a global transaction + // then grab a new transaction, ensuring that + // it is closed at the end. + + if txn == nil { + + loc = true + + switch ast.(type) { + case *sql.InfoStatement: + txn, err = readable() + default: + txn, err = writable() + } + + if err != nil { + return + } + + defer txn.Close() + + } + + // Execute the defined statement, receiving the + // result set, and any errors which occured + // while processing the query. + + switch stm := ast.(type) { + + case *sql.InfoStatement: + res, err = e.executeInfoStatement(txn, stm) + + case *sql.LetStatement: + res, err = e.executeLetStatement(txn, stm) + + case *sql.SelectStatement: + res, err = e.executeSelectStatement(txn, stm) + case *sql.CreateStatement: + res, err = e.executeCreateStatement(txn, stm) + case *sql.UpdateStatement: + res, err = e.executeUpdateStatement(txn, stm) + case *sql.ModifyStatement: + res, err = e.executeModifyStatement(txn, stm) + case *sql.DeleteStatement: + res, err = e.executeDeleteStatement(txn, stm) + case *sql.RelateStatement: + res, err = e.executeRelateStatement(txn, stm) + + case *sql.DefineScopeStatement: + res, err = e.executeDefineScopeStatement(txn, stm) + case *sql.RemoveScopeStatement: + res, err = e.executeRemoveScopeStatement(txn, stm) + + case *sql.DefineTableStatement: + res, err = e.executeDefineTableStatement(txn, stm) + case *sql.RemoveTableStatement: + res, err = e.executeRemoveTableStatement(txn, stm) + + case *sql.DefineRulesStatement: + res, err = e.executeDefineRulesStatement(txn, stm) + case *sql.RemoveRulesStatement: + res, err = e.executeRemoveRulesStatement(txn, stm) + + case *sql.DefineFieldStatement: + res, err = e.executeDefineFieldStatement(txn, stm) + case *sql.RemoveFieldStatement: + res, err = e.executeRemoveFieldStatement(txn, stm) + + case *sql.DefineIndexStatement: + res, err = e.executeDefineIndexStatement(txn, stm) + case *sql.RemoveIndexStatement: + res, err = e.executeRemoveIndexStatement(txn, stm) + + } + + // If this is a local transaction for only the + // current statement, then commit or cancel + // depending on the result error. + + if loc { + if err != nil { + txn.Cancel() + } else { + txn.Commit() + } + } + + return + +} + func status(e error) (s string) { switch e.(type) { default: @@ -125,6 +416,14 @@ func detail(e error) (s string) { } } +func clear(buf []*Response, rsp *Response) []*Response { + for i := len(buf) - 1; i >= 0; i-- { + buf[len(buf)-1] = nil + buf = buf[:len(buf)-1] + } + return append(buf, rsp) +} + func begin(txn kvs.TX) (tmp kvs.TX, err error) { if txn == nil { txn, err = writable() @@ -189,236 +488,3 @@ func writable() (txn kvs.TX, err error) { func readable() (txn kvs.TX, err error) { return db.Txn(false) } - -func execute(ctx *fibre.Context, ast *sql.Query, quit <-chan bool, send chan<- *Response) { - - var err error - var txn kvs.TX - var rsp *Response - var buf []*Response - var res []interface{} - - // Ensure that the query responses channel is - // closed when the full query has been processed - // and dealt with. - - defer close(send) - - // If we are making use of a global transaction - // which is not committed at the end of the - // query set, then cancel the transaction. - - defer func() { - if txn != nil { - txn.Cancel() - } - }() - - // If we have paniced during query execution - // then ensure that we recover from the error - // and print the error to the log. - - defer func() { - if r := recover(); r != nil { - if err, ok := r.(error); ok { - fmt.Println(err) - } - } - }() - - stms := make(chan sql.Statement) - - // Loop over the defined query statements and - // pass them to the statement processing - // channel for execution. - - go func() { - for _, stm := range ast.Statements { - stms <- stm - } - close(stms) - }() - - // Listen for any new statements to process and - // at the same time listen for the quit signal - // notifying us whether the client has gone away. - - for { - - select { - - case <-quit: - return - - case stm, open := <-stms: - - // If we have reached the end of the statement - // processing channel then return out of this - // for loop and exit. - - if !open { - return - } - - // If we are not inside a global transaction - // then reset the error to nil so that the - // next statement is not ignored. - - if txn == nil { - err = nil - } - - // Check to see if the current statement is - // a TRANSACTION statement, and if it is - // then deal with it and move on to the next. - - switch stm.(type) { - case *sql.BeginStatement: - txn, err = begin(txn) - continue - case *sql.CancelStatement: - txn, err, buf = cancel(txn, buf, err, send) - continue - case *sql.CommitStatement: - txn, err, buf = commit(txn, buf, err, send) - continue - } - - // This is not a TRANSACTION statement and - // therefore we must time the execution speed - // and process the statement response. - - now := time.Now() - - // If an error has occured and we are inside - // a global transaction, then ignore all - // subsequent statements in the transaction. - - if err == nil { - res, err = operate(txn, stm) - } else { - res, err = []interface{}{}, fmt.Errorf("Query not executed") - } - - rsp = &Response{ - Time: time.Since(now).String(), - Status: status(err), - Detail: detail(err), - Result: append([]interface{}{}, res...), - } - - // If we are not inside a global transaction - // then we can output the statement response - // immediately to the channel. - - if txn == nil { - send <- rsp - continue - } - - // If we are inside a global transaction we - // must buffer the responses for output at - // the end of the transaction. - - if txn != nil { - buf = append(buf, rsp) - continue - } - - } - - } - -} - -func operate(txn kvs.TX, ast sql.Statement) (res []interface{}, err error) { - - var loc bool - - // If we are not inside a global transaction - // then grab a new transaction, ensuring that - // it is closed at the end. - - if txn == nil { - - loc = true - - switch ast.(type) { - case *sql.InfoStatement: - txn, err = readable() - default: - txn, err = writable() - } - - if err != nil { - return - } - - defer txn.Close() - - } - - // Execute the defined statement, receiving the - // result set, and any errors which occured - // while processing the query. - - switch stm := ast.(type) { - - case *sql.InfoStatement: - res, err = executeInfoStatement(txn, stm) - - case *sql.SelectStatement: - res, err = executeSelectStatement(txn, stm) - case *sql.CreateStatement: - res, err = executeCreateStatement(txn, stm) - case *sql.UpdateStatement: - res, err = executeUpdateStatement(txn, stm) - case *sql.ModifyStatement: - res, err = executeModifyStatement(txn, stm) - case *sql.DeleteStatement: - res, err = executeDeleteStatement(txn, stm) - case *sql.RelateStatement: - res, err = executeRelateStatement(txn, stm) - - case *sql.DefineScopeStatement: - res, err = executeDefineScopeStatement(txn, stm) - case *sql.RemoveScopeStatement: - res, err = executeRemoveScopeStatement(txn, stm) - - case *sql.DefineTableStatement: - res, err = executeDefineTableStatement(txn, stm) - case *sql.RemoveTableStatement: - res, err = executeRemoveTableStatement(txn, stm) - - case *sql.DefineRulesStatement: - res, err = executeDefineRulesStatement(txn, stm) - case *sql.RemoveRulesStatement: - res, err = executeRemoveRulesStatement(txn, stm) - - case *sql.DefineFieldStatement: - res, err = executeDefineFieldStatement(txn, stm) - case *sql.RemoveFieldStatement: - res, err = executeRemoveFieldStatement(txn, stm) - - case *sql.DefineIndexStatement: - res, err = executeDefineIndexStatement(txn, stm) - case *sql.RemoveIndexStatement: - res, err = executeRemoveIndexStatement(txn, stm) - - } - - // If this is a local transaction for only the - // current statement, then commit or cancel - // depending on the result error. - - if loc { - if err != nil { - txn.Cancel() - } else { - txn.Commit() - } - } - - return - -} diff --git a/db/define.go b/db/define.go index 3aaaacae..398ff59d 100644 --- a/db/define.go +++ b/db/define.go @@ -22,7 +22,7 @@ import ( "github.com/abcum/surreal/util/pack" ) -func executeDefineScopeStatement(txn kvs.TX, ast *sql.DefineScopeStatement) (out []interface{}, err error) { +func (e *executor) executeDefineScopeStatement(txn kvs.TX, ast *sql.DefineScopeStatement) (out []interface{}, err error) { // Set the namespace definition nkey := &keys.NS{KV: ast.KV, NS: ast.NS} @@ -46,7 +46,7 @@ func executeDefineScopeStatement(txn kvs.TX, ast *sql.DefineScopeStatement) (out } -func executeDefineTableStatement(txn kvs.TX, ast *sql.DefineTableStatement) (out []interface{}, err error) { +func (e *executor) executeDefineTableStatement(txn kvs.TX, ast *sql.DefineTableStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -74,7 +74,7 @@ func executeDefineTableStatement(txn kvs.TX, ast *sql.DefineTableStatement) (out } -func executeDefineRulesStatement(txn kvs.TX, ast *sql.DefineRulesStatement) (out []interface{}, err error) { +func (e *executor) executeDefineRulesStatement(txn kvs.TX, ast *sql.DefineRulesStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -112,7 +112,7 @@ func executeDefineRulesStatement(txn kvs.TX, ast *sql.DefineRulesStatement) (out } -func executeDefineFieldStatement(txn kvs.TX, ast *sql.DefineFieldStatement) (out []interface{}, err error) { +func (e *executor) executeDefineFieldStatement(txn kvs.TX, ast *sql.DefineFieldStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -146,7 +146,7 @@ func executeDefineFieldStatement(txn kvs.TX, ast *sql.DefineFieldStatement) (out } -func executeDefineIndexStatement(txn kvs.TX, ast *sql.DefineIndexStatement) (out []interface{}, err error) { +func (e *executor) executeDefineIndexStatement(txn kvs.TX, ast *sql.DefineIndexStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -186,7 +186,7 @@ func executeDefineIndexStatement(txn kvs.TX, ast *sql.DefineIndexStatement) (out iend := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: TB, ID: keys.Suffix} kvs, _ := txn.RGet(ibeg.Encode(), iend.Encode(), 0) for _, kv := range kvs { - doc := item.New(kv, txn, nil) + doc := item.New(kv, txn, nil, e.ctx) if err := doc.StoreIndex(); err != nil { return nil, err } diff --git a/db/delete.go b/db/delete.go index a4ce3911..475a8146 100644 --- a/db/delete.go +++ b/db/delete.go @@ -15,39 +15,51 @@ package db import ( + "fmt" "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/item" "github.com/abcum/surreal/util/keys" ) -func executeDeleteStatement(txn kvs.TX, ast *sql.DeleteStatement) (out []interface{}, err error) { +func (e *executor) executeDeleteStatement(txn kvs.TX, ast *sql.DeleteStatement) (out []interface{}, err error) { + + for k, w := range ast.What { + if what, ok := w.(*sql.Param); ok { + ast.What[k] = e.ctx.Get(what.ID).Data() + } + } for _, w := range ast.What { - if what, ok := w.(*sql.Thing); ok { + switch what := w.(type) { + + default: + return out, fmt.Errorf("Can not execute DELETE query using type '%T'", what) + + case *sql.Thing: key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: what.ID} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := delete(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } - } - if what, ok := w.(*sql.Table); ok { + case *sql.Table: beg := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Prefix} end := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Suffix} kvs, _ := txn.RGet(beg.Encode(), end.Encode(), 0) for _, kv := range kvs { - doc := item.New(kv, txn, nil) + doc := item.New(kv, txn, nil, e.ctx) if ret, err := delete(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } } + } } diff --git a/db/info.go b/db/info.go index c6c00397..93fa6ae0 100644 --- a/db/info.go +++ b/db/info.go @@ -23,7 +23,7 @@ import ( "github.com/abcum/surreal/util/pack" ) -func executeInfoStatement(txn kvs.TX, ast *sql.InfoStatement) (out []interface{}, err error) { +func (e *executor) executeInfoStatement(txn kvs.TX, ast *sql.InfoStatement) (out []interface{}, err error) { if ast.What == "" { diff --git a/db/let.go b/db/let.go new file mode 100644 index 00000000..09d71220 --- /dev/null +++ b/db/let.go @@ -0,0 +1,33 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "github.com/abcum/surreal/kvs" + "github.com/abcum/surreal/sql" +) + +func (e *executor) executeLetStatement(txn kvs.TX, ast *sql.LetStatement) (out []interface{}, err error) { + + switch expr := ast.Expr.(type) { + default: + e.Set(ast.Name, expr) + case *sql.Param: + e.Set(ast.Name, e.Get(expr.ID)) + } + + return + +} diff --git a/db/modify.go b/db/modify.go index 5280dad5..21ce19f0 100644 --- a/db/modify.go +++ b/db/modify.go @@ -15,39 +15,51 @@ package db import ( + "fmt" "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/item" "github.com/abcum/surreal/util/keys" ) -func executeModifyStatement(txn kvs.TX, ast *sql.ModifyStatement) (out []interface{}, err error) { +func (e *executor) executeModifyStatement(txn kvs.TX, ast *sql.ModifyStatement) (out []interface{}, err error) { + + for k, w := range ast.What { + if what, ok := w.(*sql.Param); ok { + ast.What[k] = e.ctx.Get(what.ID).Data() + } + } for _, w := range ast.What { - if what, ok := w.(*sql.Thing); ok { + switch what := w.(type) { + + default: + return out, fmt.Errorf("Can not execute MODIFY query using type '%T'", what) + + case *sql.Thing: key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: what.ID} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := modify(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } - } - if what, ok := w.(*sql.Table); ok { + case *sql.Table: beg := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Prefix} end := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Suffix} kvs, _ := txn.RGet(beg.Encode(), end.Encode(), 0) for _, kv := range kvs { - doc := item.New(kv, txn, nil) + doc := item.New(kv, txn, nil, e.ctx) if ret, err := modify(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } } + } } diff --git a/db/relate.go b/db/relate.go index eef4c265..42dd9d03 100644 --- a/db/relate.go +++ b/db/relate.go @@ -19,6 +19,6 @@ import ( "github.com/abcum/surreal/sql" ) -func executeRelateStatement(txn kvs.TX, ast *sql.RelateStatement) ([]interface{}, error) { +func (e *executor) executeRelateStatement(txn kvs.TX, ast *sql.RelateStatement) (out []interface{}, err error) { return nil, nil } diff --git a/db/remove.go b/db/remove.go index 877bfe8f..ef4c6416 100644 --- a/db/remove.go +++ b/db/remove.go @@ -20,7 +20,7 @@ import ( "github.com/abcum/surreal/util/keys" ) -func executeRemoveScopeStatement(txn kvs.TX, ast *sql.RemoveScopeStatement) (out []interface{}, err error) { +func (e *executor) executeRemoveScopeStatement(txn kvs.TX, ast *sql.RemoveScopeStatement) (out []interface{}, err error) { // Remove the scope config skey := &keys.SC{KV: ast.KV, NS: ast.NS, DB: ast.DB, SC: ast.Name} @@ -32,7 +32,7 @@ func executeRemoveScopeStatement(txn kvs.TX, ast *sql.RemoveScopeStatement) (out } -func executeRemoveTableStatement(txn kvs.TX, ast *sql.RemoveTableStatement) (out []interface{}, err error) { +func (e *executor) executeRemoveTableStatement(txn kvs.TX, ast *sql.RemoveTableStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -72,7 +72,7 @@ func executeRemoveTableStatement(txn kvs.TX, ast *sql.RemoveTableStatement) (out } -func executeRemoveRulesStatement(txn kvs.TX, ast *sql.RemoveRulesStatement) (out []interface{}, err error) { +func (e *executor) executeRemoveRulesStatement(txn kvs.TX, ast *sql.RemoveRulesStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -92,7 +92,7 @@ func executeRemoveRulesStatement(txn kvs.TX, ast *sql.RemoveRulesStatement) (out } -func executeRemoveFieldStatement(txn kvs.TX, ast *sql.RemoveFieldStatement) (out []interface{}, err error) { +func (e *executor) executeRemoveFieldStatement(txn kvs.TX, ast *sql.RemoveFieldStatement) (out []interface{}, err error) { for _, TB := range ast.What { @@ -108,7 +108,7 @@ func executeRemoveFieldStatement(txn kvs.TX, ast *sql.RemoveFieldStatement) (out } -func executeRemoveIndexStatement(txn kvs.TX, ast *sql.RemoveIndexStatement) (out []interface{}, err error) { +func (e *executor) executeRemoveIndexStatement(txn kvs.TX, ast *sql.RemoveIndexStatement) (out []interface{}, err error) { for _, TB := range ast.What { diff --git a/db/select.go b/db/select.go index 7a678ec5..64ece8e6 100644 --- a/db/select.go +++ b/db/select.go @@ -21,14 +21,20 @@ import ( "github.com/abcum/surreal/util/keys" ) -func executeSelectStatement(txn kvs.TX, ast *sql.SelectStatement) (out []interface{}, err error) { +func (e *executor) executeSelectStatement(txn kvs.TX, ast *sql.SelectStatement) (out []interface{}, err error) { + + for k, w := range ast.What { + if what, ok := w.(*sql.Param); ok { + ast.What[k] = e.ctx.Get(what.ID).Data() + } + } for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: what.ID} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := detect(doc, ast); err != nil { return nil, err } else if ret != nil { @@ -41,7 +47,7 @@ func executeSelectStatement(txn kvs.TX, ast *sql.SelectStatement) (out []interfa end := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Suffix} kvs, _ := txn.RGet(beg.Encode(), end.Encode(), 0) for _, kv := range kvs { - doc := item.New(kv, txn, nil) + doc := item.New(kv, txn, nil, e.ctx) if ret, err := detect(doc, ast); err != nil { return nil, err } else if ret != nil { diff --git a/db/update.go b/db/update.go index e64a534b..a6abf3ec 100644 --- a/db/update.go +++ b/db/update.go @@ -15,39 +15,51 @@ package db import ( + "fmt" "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/item" "github.com/abcum/surreal/util/keys" ) -func executeUpdateStatement(txn kvs.TX, ast *sql.UpdateStatement) (out []interface{}, err error) { +func (e *executor) executeUpdateStatement(txn kvs.TX, ast *sql.UpdateStatement) (out []interface{}, err error) { + + for k, w := range ast.What { + if what, ok := w.(*sql.Param); ok { + ast.What[k] = e.ctx.Get(what.ID).Data() + } + } for _, w := range ast.What { - if what, ok := w.(*sql.Thing); ok { + switch what := w.(type) { + + default: + return out, fmt.Errorf("Can not execute UPDATE query using type '%T'", what) + + case *sql.Thing: key := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: what.ID} kv, _ := txn.Get(key.Encode()) - doc := item.New(kv, txn, key) + doc := item.New(kv, txn, key, e.ctx) if ret, err := update(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } - } - if what, ok := w.(*sql.Table); ok { + case *sql.Table: beg := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Prefix} end := &keys.Thing{KV: ast.KV, NS: ast.NS, DB: ast.DB, TB: what.TB, ID: keys.Suffix} kvs, _ := txn.RGet(beg.Encode(), end.Encode(), 0) for _, kv := range kvs { - doc := item.New(kv, txn, nil) + doc := item.New(kv, txn, nil, e.ctx) if ret, err := update(doc, ast); err != nil { return nil, err } else if ret != nil { out = append(out, ret) } } + } } diff --git a/sql/ast.go b/sql/ast.go index 26f6a351..ccbfd0eb 100644 --- a/sql/ast.go +++ b/sql/ast.go @@ -48,6 +48,11 @@ type CancelStatement struct{} // UseStatement represents a SQL COMMIT TRANSACTION statement. type CommitStatement struct{} +// ReturnStatement represents a SQL RETURN statement. +type ReturnStatement struct { + What []Expr +} + // -------------------------------------------------- // Use // -------------------------------------------------- @@ -70,6 +75,19 @@ type InfoStatement struct { What string `cork:"-" codec:"-"` } +// -------------------------------------------------- +// LET +// -------------------------------------------------- + +// LetStatement represents a SQL LET statement. +type LetStatement struct { + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + Name string `cork:"-" codec:"-"` + Expr Expr `cork:"-" codec:"-"` +} + // -------------------------------------------------- // Normal // -------------------------------------------------- @@ -418,6 +436,27 @@ func NewIdent(ID string) *Ident { // Parts // -------------------------------------------------- +// Param comment +type Param struct { + ID string +} + +func (this Param) String() string { + return this.ID +} + +func (this Param) MarshalText() (data []byte, err error) { + return []byte("ID:" + this.ID), err +} + +func NewParam(ID string) *Param { + return &Param{ID} +} + +// -------------------------------------------------- +// Parts +// -------------------------------------------------- + // Table comment type Table struct { TB string diff --git a/sql/exprs.go b/sql/exprs.go index ee78082b..91dc86da 100644 --- a/sql/exprs.go +++ b/sql/exprs.go @@ -39,10 +39,7 @@ func (p *parser) parseWhat() (mul []Expr, err error) { } if p.is(tok, PARAM) { - one, err := p.declare(PARAM, lit) - if err != nil { - return nil, err - } + one, _ := p.declare(PARAM, lit) mul = append(mul, one) } @@ -261,7 +258,7 @@ func (p *parser) parseExpr() (mul []*Field, err error) { one := &Field{} - tok, lit, err = p.shouldBe(IDENT, ID, NOW, PATH, NULL, ALL, TIME, TRUE, FALSE, STRING, REGION, NUMBER, DOUBLE, JSON, ARRAY) + tok, lit, err = p.shouldBe(IDENT, ID, NOW, PATH, NULL, ALL, TIME, TRUE, FALSE, STRING, REGION, NUMBER, DOUBLE, JSON, ARRAY, PARAM) if err != nil { return nil, &ParseError{Found: lit, Expected: []string{"field name"}} } diff --git a/sql/let.go b/sql/let.go new file mode 100644 index 00000000..b981f205 --- /dev/null +++ b/sql/let.go @@ -0,0 +1,51 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +func (p *parser) parseLetStatement() (stmt *LetStatement, err error) { + + stmt = &LetStatement{} + + stmt.KV = p.c.Get("KV").(string) + stmt.NS = p.c.Get("NS").(string) + stmt.DB = p.c.Get("DB").(string) + + _, stmt.Name, err = p.shouldBe(IDENT) + if err != nil { + return nil, err + } + + _, _, err = p.shouldBe(EQ) + if err != nil { + return nil, err + } + + tok, lit, err := p.shouldBe(NULL, NOW, DATE, TIME, TRUE, FALSE, STRING, NUMBER, DOUBLE, THING, JSON, ARRAY, PARAM) + if err != nil { + return nil, err + } + + stmt.Expr, err = p.declare(tok, lit) + if err != nil { + return nil, err + } + + if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil { + return nil, err + } + + return + +} diff --git a/sql/parser.go b/sql/parser.go index cf9448e4..726c8b6c 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -115,13 +115,16 @@ func (p *parser) parseMulti() (*Query, error) { // parseSingle parses a single SQL SELECT statement. func (p *parser) parseSingle() (Statement, error) { - tok, _, err := p.shouldBe(USE, INFO, LET, BEGIN, CANCEL, COMMIT, ROLLBACK, SELECT, CREATE, UPDATE, INSERT, UPSERT, MODIFY, DELETE, RELATE, DEFINE, REMOVE) + tok, _, err := p.shouldBe(USE, INFO, LET, BEGIN, CANCEL, COMMIT, ROLLBACK, RETURN, SELECT, CREATE, UPDATE, INSERT, UPSERT, MODIFY, DELETE, RELATE, DEFINE, REMOVE) switch tok { case USE: return p.parseUseStatement() + case LET: + return p.parseLetStatement() + case INFO: return p.parseInfoStatement() @@ -131,6 +134,8 @@ func (p *parser) parseSingle() (Statement, error) { return p.parseCancelStatement() case COMMIT: return p.parseCommitStatement() + case RETURN: + return p.parseReturnStatement() case SELECT: return p.parseSelectStatement() diff --git a/sql/util.go b/sql/util.go index 63c7901b..6b311a47 100644 --- a/sql/util.go +++ b/sql/util.go @@ -129,6 +129,12 @@ func (p *parser) declare(tok Token, lit string) (interface{}, error) { case DURATION: return time.ParseDuration(lit) + case PARAM: + if p, ok := p.v[lit]; ok { + return p, nil + } + return &Param{lit}, nil + case ARRAY: var j []interface{} json.Unmarshal([]byte(lit), &j) @@ -145,12 +151,6 @@ func (p *parser) declare(tok Token, lit string) (interface{}, error) { } return j, nil - case PARAM: - if p, ok := p.v[lit]; ok { - return p, nil - } - return nil, fmt.Errorf("Param %s is not defined", lit) - } return lit, nil diff --git a/util/item/blaze.go b/util/item/blaze.go index 3a1c8d98..c899dd7d 100644 --- a/util/item/blaze.go +++ b/util/item/blaze.go @@ -42,6 +42,8 @@ func (this *Doc) Blaze(ast *sql.SelectStatement) (res interface{}) { doc.Set(nil, v.Alias) case *sql.Thing: doc.Set(e.String(), v.Alias) + case *sql.Param: + doc.Set(this.runtime.Get(e.ID).Data(), v.Alias) case *sql.Ident: doc.Set(this.current.Get(e.ID).Data(), v.Alias) case *sql.All: diff --git a/util/item/check.go b/util/item/check.go index 70f1bfe9..de06ba34 100644 --- a/util/item/check.go +++ b/util/item/check.go @@ -42,8 +42,8 @@ func (this *Doc) Check(cond []sql.Expr) (val bool) { func (this *Doc) chkOne(expr *sql.BinaryExpression) (val bool) { op := expr.Op - lhs := getChkItem(this.current, expr.LHS) - rhs := getChkItem(this.current, expr.RHS) + lhs := this.getChk(expr.LHS) + rhs := this.getChk(expr.RHS) switch lhs.(type) { case bool, string, int64, float64, time.Time: @@ -616,7 +616,7 @@ func chkMatch(op sql.Token, a []interface{}, r *regexp.Regexp) (val bool) { } -func getChkItem(doc *data.Doc, expr sql.Expr) interface{} { +func (this *Doc) getChk(expr sql.Expr) interface{} { switch val := expr.(type) { default: @@ -635,8 +635,10 @@ func getChkItem(doc *data.Doc, expr sql.Expr) interface{} { return val case *sql.Empty: return val + case *sql.Param: + return this.runtime.Get(val.ID).Data() case *sql.Ident: - return doc.Get(val.ID).Data() + return this.current.Get(val.ID).Data() } } diff --git a/util/item/item.go b/util/item/item.go index 3074f6b5..db98ab64 100644 --- a/util/item/item.go +++ b/util/item/item.go @@ -32,14 +32,15 @@ type Doc struct { key *keys.Thing initial *data.Doc current *data.Doc + runtime *data.Doc fields []*sql.DefineFieldStatement indexs []*sql.DefineIndexStatement rules map[string]*sql.DefineRulesStatement } -func New(kv kvs.KV, tx kvs.TX, key *keys.Thing) (this *Doc) { +func New(kv kvs.KV, tx kvs.TX, key *keys.Thing, vars *data.Doc) (this *Doc) { - this = &Doc{kv: kv, key: key, tx: tx} + this = &Doc{kv: kv, key: key, tx: tx, runtime: vars} if key == nil { this.key = &keys.Thing{} diff --git a/util/item/merge.go b/util/item/merge.go index 2ed7b149..46c7a8d2 100644 --- a/util/item/merge.go +++ b/util/item/merge.go @@ -150,8 +150,8 @@ func (this *Doc) mrgDpm(expr *sql.DiffExpression) { func (this *Doc) mrgOne(expr *sql.BinaryExpression) { - lhs := getMrgItemLHS(this.current, expr.LHS) - rhs := getMrgItemRHS(this.current, expr.RHS) + lhs := this.getMrgItemLHS(expr.LHS) + rhs := this.getMrgItemRHS(expr.RHS) if expr.Op == sql.EQ { switch expr.RHS.(type) { @@ -172,7 +172,7 @@ func (this *Doc) mrgOne(expr *sql.BinaryExpression) { } -func getMrgItemLHS(doc *data.Doc, expr sql.Expr) string { +func (this *Doc) getMrgItemLHS(expr sql.Expr) string { switch val := expr.(type) { default: @@ -185,7 +185,7 @@ func getMrgItemLHS(doc *data.Doc, expr sql.Expr) string { } -func getMrgItemRHS(doc *data.Doc, expr sql.Expr) interface{} { +func (this *Doc) getMrgItemRHS(expr sql.Expr) interface{} { switch val := expr.(type) { default: @@ -198,8 +198,10 @@ func getMrgItemRHS(doc *data.Doc, expr sql.Expr) interface{} { return val case *sql.Thing: return val + case *sql.Param: + return this.runtime.Get(val.ID).Data() case *sql.Ident: - return doc.Get(val.ID).Data() + return this.current.Get(val.ID).Data() } }