surrealpatch/db/executor.go
Tobie Morgan Hitchcock 36e7d8ed3a Flush websocket notifications correctly
Websocket notifications were cleared/flushed regardless of whether individual statements were successful or not.

Now notifications are shifted onto the stack, or removed if the statement is unsuccessful. Once the full query has been processed, all pending notifications are flushed to all websockets (ignoring the current connection from which the query originated).
2018-01-31 09:15:29 +00:00

524 lines
11 KiB
Go

// Copyright © 2016 Abcum Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package db
import (
"fmt"
"time"
"context"
"runtime/debug"
"cloud.google.com/go/trace"
"github.com/abcum/surreal/kvs"
"github.com/abcum/surreal/log"
"github.com/abcum/surreal/mem"
"github.com/abcum/surreal/sql"
)
// executor runs the statements of a single parsed query.
type executor struct {
	// dbo holds the transaction (TX) that statements run against; it is
	// begun, committed, cancelled and reset by the executor methods.
	dbo *mem.Cache
	// send receives the *Response values produced while executing; it is
	// closed by execute once the whole query has been processed.
	send chan *Response
}

// newExecutor takes an executor from executorPool and gives it a fresh
// mem.Cache and a fresh unbuffered response channel.
func newExecutor() *executor {
	exe := executorPool.Get().(*executor)
	exe.dbo, exe.send = mem.New(), make(chan *Response)
	return exe
}
// execute runs every statement in ast in order and writes the resulting
// *Response values to e.send. Statements run between BEGIN and
// COMMIT/CANCEL have their responses buffered, and written only when the
// transaction finishes. When execute returns, e.send is closed, any still
// open transaction is cancelled, and the executor is put back into
// executorPool.
//
// Websocket notifications are keyed by the request id stored under
// ctxKeyId. They are cleared when an explicit transaction is cancelled or
// fails, and flushed when it commits successfully.
//
// If ctx is done, execute stops processing further statements, and also
// stops waiting on a response send that no consumer is receiving.
func (e *executor) execute(ctx context.Context, ast *sql.Query) {

	var err error
	var now time.Time
	var rsp *Response
	var buf []*Response
	var res []interface{}

	// Get the fibre context ID so that we can use
	// it to clear or flush websocket notification
	// changes linked to this context.
	id := ctx.Value(ctxKeyId).(string)

	// Ensure that the executor is added back into
	// the executor pool when the executor has
	// finished processing the request.
	defer executorPool.Put(e)

	// Ensure that the query responses channel is
	// closed when the full query has been processed
	// and dealt with.
	defer close(e.send)

	// If we are making use of a global transaction
	// which is not committed at the end of the
	// query set, then cancel the transaction and
	// discard its queued notifications.
	defer func() {
		if e.dbo.TX != nil {
			e.dbo.Cancel()
			clear(id)
		}
	}()

	// If we have panicked during query execution
	// then ensure that we recover from the error
	// and print the error to the log. Deferred calls
	// run in reverse order, so this runs before the
	// cleanup above, which therefore still happens
	// after a panic.
	defer func() {
		if err := recover(); err != nil {
			log.WithPrefix("db").WithFields(map[string]interface{}{
				"id": ctx.Value(ctxKeyId), "stack": string(debug.Stack()),
			}).Errorln(err)
		}
	}()

	// Loop over the defined query statements and
	// process them, while listening for the quit
	// channel to see if the client has gone away.
	for _, stm := range ast.Statements {

		select {

		case <-ctx.Done():
			return

		default:

			trc := trace.FromContext(ctx).NewChild(fmt.Sprint(stm))
			ctx := trace.NewContext(ctx, trc)

			// When in debugging mode, log every sql
			// query, along with the query execution
			// speed, so we can analyse slow queries.
			log := log.WithPrefix("sql").WithFields(map[string]interface{}{
				"id":   ctx.Value(ctxKeyId),
				"kind": ctx.Value(ctxKeyKind),
				"auth": ctx.Value(ctxKeyAuth),
			})

			if stm, ok := stm.(sql.AuthableStatement); ok {
				ns, db := stm.Auth()
				ctx = context.WithValue(ctx, ctxKeyNs, ns)
				ctx = context.WithValue(ctx, ctxKeyDb, db)
				log = log.WithField("ns", ns).WithField("db", db)
			}

			log.Debugln(stm)

			// If we are not inside a global transaction
			// then reset the error to nil so that the
			// next statement is not ignored.
			if e.dbo.TX == nil {
				err, now = nil, time.Now()
			}

			// Check to see if the current statement is
			// a TRANSACTION statement, and if it is
			// then deal with it and move on to the next.
			switch stm.(type) {
			case *sql.BeginStatement:
				err = e.begin(true)
				trc.Finish()
				continue
			case *sql.CancelStatement:
				// A cancelled transaction persists nothing, so
				// its queued notifications are always discarded,
				// whether or not the cancel itself succeeded.
				err, buf = e.cancel(buf, err, e.send)
				clear(id)
				trc.Finish()
				continue
			case *sql.CommitStatement:
				err, buf = e.commit(buf, err, e.send)
				if err != nil {
					clear(id)
				} else {
					flush(id)
				}
				trc.Finish()
				continue
			}

			// If an error has occurred and we are inside
			// a global transaction, then ignore all
			// subsequent statements in the transaction.
			if err == nil {
				res, err = e.operate(ctx, stm)
			} else {
				res, err = []interface{}{}, queryNotExecuted
			}

			rsp = &Response{
				Time:   time.Since(now).String(),
				Status: status(err),
				Detail: detail(err),
				Result: append([]interface{}{}, res...),
			}

			// If we are not inside a global transaction
			// then we can output the statement response
			// immediately to the channel. e.send is
			// unbuffered, so also watch ctx to avoid
			// blocking forever if nobody is receiving.
			if e.dbo.TX == nil {
				select {
				case e.send <- rsp:
				case <-ctx.Done():
					trc.Finish()
					return
				}
			}

			// If we are inside a global transaction we
			// must buffer the responses for output at
			// the end of the transaction.
			if e.dbo.TX != nil {
				switch stm.(type) {
				case *sql.ReturnStatement:
					buf = groupd(buf, rsp)
				default:
					buf = append(buf, rsp)
				}
			}

			trc.Finish()

		}

	}

}
// operate executes a single statement that is not BEGIN/CANCEL/COMMIT,
// and returns its result set and error.
//
// If no explicit transaction is open, operate opens a local one. The local
// transaction is writeable only when the statement is a
// sql.WriteableStatement reporting Writeable(). After the statement runs,
// operate finishes the transaction:
//   - on error, or for a read-only statement, it cancels and clears queued
//     notifications;
//   - on a successful write, it commits and shifts queued notifications.
//
// Inside an explicit transaction, finishing it is left to the caller,
// unless ctx is done.
//
// A sql.KillableStatement with a positive Duration runs under a timeout.
// If its context has ended (deadline reached or parent cancelled) and it
// returned no error of its own, a *TimerError is returned instead.
func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interface{}, err error) {

	var loc bool
	var trw bool
	var canc context.CancelFunc

	// If the statement is a UseStatement then
	// there is no need to create a transaction
	// as the query does not do anything.
	if _, ok := stm.(*sql.UseStatement); ok {
		return
	}

	// If we are not inside a global transaction
	// then grab a new transaction, ensuring that
	// it is closed at the end.
	if e.dbo.TX == nil {
		loc = true
		if ws, ok := stm.(sql.WriteableStatement); ok {
			trw = ws.Writeable()
		}
		err = e.begin(trw)
		if err != nil {
			return
		}
		defer e.dbo.Cancel()
	}

	// Mark the beginning of this statement so we
	// can monitor the running time, and ensure
	// it runs no longer than specified.
	if stm, ok := stm.(sql.KillableStatement); ok {
		if stm.Duration() > 0 {
			ctx, canc = context.WithTimeout(ctx, stm.Duration())
			defer func() {
				if tim := ctx.Err(); err == nil && tim != nil {
					res, err = nil, &TimerError{timer: stm.Duration()}
				}
				canc()
			}()
		}
	}

	// Get the fibre context ID so that we can use
	// it to clear or flush websocket notification
	// changes linked to this context.
	id := ctx.Value(ctxKeyId).(string)

	// Execute the defined statement, receiving the
	// result set, and any errors which occurred
	// while processing the query.
	switch stm := stm.(type) {
	case *sql.IfStatement:
		res, err = e.executeIf(ctx, stm)
	case *sql.InfoStatement:
		res, err = e.executeInfo(ctx, stm)
	case *sql.LetStatement:
		res, err = e.executeLet(ctx, stm)
	case *sql.ReturnStatement:
		res, err = e.executeReturn(ctx, stm)
	case *sql.LiveStatement:
		res, err = e.executeLive(ctx, stm)
	case *sql.KillStatement:
		res, err = e.executeKill(ctx, stm)
	case *sql.SelectStatement:
		res, err = e.executeSelect(ctx, stm)
	case *sql.CreateStatement:
		res, err = e.executeCreate(ctx, stm)
	case *sql.UpdateStatement:
		res, err = e.executeUpdate(ctx, stm)
	case *sql.DeleteStatement:
		res, err = e.executeDelete(ctx, stm)
	case *sql.RelateStatement:
		res, err = e.executeRelate(ctx, stm)
	case *sql.InsertStatement:
		res, err = e.executeInsert(ctx, stm)
	case *sql.UpsertStatement:
		res, err = e.executeUpsert(ctx, stm)
	case *sql.DefineNamespaceStatement:
		res, err = e.executeDefineNamespace(ctx, stm)
	case *sql.RemoveNamespaceStatement:
		res, err = e.executeRemoveNamespace(ctx, stm)
	case *sql.DefineDatabaseStatement:
		res, err = e.executeDefineDatabase(ctx, stm)
	case *sql.RemoveDatabaseStatement:
		res, err = e.executeRemoveDatabase(ctx, stm)
	case *sql.DefineLoginStatement:
		res, err = e.executeDefineLogin(ctx, stm)
	case *sql.RemoveLoginStatement:
		res, err = e.executeRemoveLogin(ctx, stm)
	case *sql.DefineTokenStatement:
		res, err = e.executeDefineToken(ctx, stm)
	case *sql.RemoveTokenStatement:
		res, err = e.executeRemoveToken(ctx, stm)
	case *sql.DefineScopeStatement:
		res, err = e.executeDefineScope(ctx, stm)
	case *sql.RemoveScopeStatement:
		res, err = e.executeRemoveScope(ctx, stm)
	case *sql.DefineTableStatement:
		res, err = e.executeDefineTable(ctx, stm)
	case *sql.RemoveTableStatement:
		res, err = e.executeRemoveTable(ctx, stm)
	case *sql.DefineEventStatement:
		res, err = e.executeDefineEvent(ctx, stm)
	case *sql.RemoveEventStatement:
		res, err = e.executeRemoveEvent(ctx, stm)
	case *sql.DefineFieldStatement:
		res, err = e.executeDefineField(ctx, stm)
	case *sql.RemoveFieldStatement:
		res, err = e.executeRemoveField(ctx, stm)
	case *sql.DefineIndexStatement:
		res, err = e.executeDefineIndex(ctx, stm)
	case *sql.RemoveIndexStatement:
		res, err = e.executeRemoveIndex(ctx, stm)
	}

	// If the context is already closed or errored,
	// then ignore this result, clear all queued
	// changes, and reset the transaction (this also
	// applies to an explicit, global transaction).
	select {

	case <-ctx.Done():
		e.dbo.Cancel()
		e.dbo.Reset()
		clear(id)

	default:

		// If this is a local transaction for only the
		// current statement, then commit or cancel
		// depending on the result error.
		if loc && !e.dbo.Closed() {

			// As this is a local transaction then
			// make sure we reset the transaction
			// context.
			defer e.dbo.Reset()

			// If there was an error with the query
			// then clear the queued changes and
			// return immediately.
			if err != nil {
				e.dbo.Cancel()
				clear(id)
				return
			}

			// A read-only statement never persists
			// anything, so cancel and always drop its
			// queued notifications. A write is
			// committed, and its notifications are
			// shifted only if the commit succeeded.
			if !trw {
				err = e.dbo.Cancel()
				clear(id)
			} else if err = e.dbo.Commit(); err != nil {
				clear(id)
			} else {
				shift(id)
			}

		}

	}

	return

}
// begin opens a transaction on e.dbo via db.Begin(rw). rw is true for
// BEGIN and for writeable statements. If a transaction is already open,
// begin does nothing and returns nil.
func (e *executor) begin(rw bool) (err error) {
	if e.dbo.TX != nil {
		return nil
	}
	e.dbo.TX, err = db.Begin(rw)
	return err
}
// cancel rolls back the explicit transaction held in e.dbo. Every buffered
// response is rewritten as a failure ("Transaction cancelled", empty
// result) and sent on chn.
//
// It returns the rollback error together with buf emptied: slots are set
// to nil and the length is zero. The incoming err is not consulted. When no
// transaction is open, it returns (nil, buf) untouched. e.dbo is reset on
// every path.
func (e *executor) cancel(buf []*Response, err error, chn chan<- *Response) (error, []*Response) {
	defer e.dbo.Reset()

	if e.dbo.TX == nil {
		return nil, buf
	}

	err = e.dbo.Cancel()

	for _, r := range buf {
		r.Status = "ERR"
		r.Detail = "Transaction cancelled"
		r.Result = []interface{}{}
		chn <- r
	}

	// Drop the references so the responses can be collected, but keep
	// the backing array for reuse.
	for i := range buf {
		buf[i] = nil
	}

	return err, buf[:0]
}
// commit finishes the explicit transaction held in e.dbo and sends every
// buffered response on chn.
//
// If err is non-nil, a statement inside the transaction failed. The
// transaction is then cancelled instead of committed, and a non-nil error
// is returned so that the caller discards, rather than flushes, the
// notifications queued by the rolled-back statements. Previously err was
// overwritten by the (usually nil) result of Cancel, so commit reported
// success.
//
// If committing or cancelling itself fails, every buffered response is
// rewritten as "Transaction failed: <reason>", and that error is returned.
// Buffered responses otherwise keep their own status.
//
// buf is returned emptied (slots set to nil, length zero). When no
// transaction is open, commit returns (nil, buf) untouched. e.dbo is reset
// on every path.
func (e *executor) commit(buf []*Response, err error, chn chan<- *Response) (error, []*Response) {

	defer e.dbo.Reset()

	if e.dbo.TX == nil {
		return nil, buf
	}

	// txErr is the error from finishing the transaction itself, kept
	// separate so the incoming statement error is not lost.
	var txErr error
	if err != nil {
		txErr = e.dbo.Cancel()
	} else {
		txErr = e.dbo.Commit()
	}

	for _, v := range buf {
		if txErr != nil {
			v.Status = "ERR"
			v.Result = []interface{}{}
			v.Detail = "Transaction failed: " + txErr.Error()
		}
		chn <- v
	}

	for i := range buf {
		buf[i] = nil
	}

	if txErr != nil {
		err = txErr
	}

	return err, buf[:0]

}
// status maps an error returned by statement execution to the short status
// code reported in a Response. A nil error yields "OK". Error types without
// a dedicated code yield "ERR".
func status(e error) (s string) {
	switch e.(type) {
	case nil:
		return "OK"
	case *kvs.DBError:
		return "ERR_DB"
	case *kvs.KVError:
		return "ERR_KV"
	case *PermsError:
		return "ERR_PE"
	case *ExistError:
		// Previously reported as "ERR_KV", which made an ExistError
		// indistinguishable from a *kvs.KVError.
		return "ERR_EX"
	case *FieldError:
		return "ERR_FD"
	case *IndexError:
		return "ERR_IX"
	default:
		return "ERR"
	}
}
// detail returns the message of e for a Response's Detail field, or the
// empty string when e is nil.
func detail(e error) (s string) {
	if e != nil {
		s = e.Error()
	}
	return
}
// groupd discards every response buffered so far and returns a buffer
// holding only rsp. It is used for a RETURN statement inside a transaction,
// so that only the RETURN's response is reported when the transaction ends.
// The existing backing array is reused; old slots are set to nil first.
func groupd(buf []*Response, rsp *Response) []*Response {
	for i := range buf {
		buf[i] = nil
	}
	return append(buf[:0], rsp)
}