Flush websocket notifications correctly

Websocket notifications were cleared/flushed regardless of whether individual statements were successful or not.

Now notifications are shifted onto the stack, or removed if the statement is unsuccessful. Once the full query has been processed, all pending notifications are flushed to all websockets (ignoring the current connection frin which th query originated).
This commit is contained in:
Tobie Morgan Hitchcock 2018-01-31 09:15:29 +00:00
parent a9883efc4a
commit 36e7d8ed3a
5 changed files with 105 additions and 29 deletions

View file

@ -28,6 +28,7 @@ import (
"github.com/abcum/surreal/sql" "github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/data" "github.com/abcum/surreal/util/data"
"github.com/abcum/surreal/util/uuid"
"cloud.google.com/go/trace" "cloud.google.com/go/trace"
@ -112,6 +113,14 @@ func Execute(fib *fibre.Context, txt interface{}, vars map[string]interface{}) (
vars = make(map[string]interface{}) vars = make(map[string]interface{})
} }
// Ensure that we have a unique id assigned
// to this fibre connection, as we need it
// to detect unique websocket notifications.
if fib.Get(ctxKeyId) == nil {
fib.Set(ctxKeyId, uuid.New().String())
}
// Ensure that the IP address of the // Ensure that the IP address of the
// user signing in is available so that // user signing in is available so that
// it can be used within signin queries. // it can be used within signin queries.
@ -165,6 +174,14 @@ func Process(fib *fibre.Context, ast *sql.Query, vars map[string]interface{}) (o
vars = make(map[string]interface{}) vars = make(map[string]interface{})
} }
// Ensure that we have a unique id assigned
// to this fibre connection, as we need it
// to detect unique websocket notifications.
if fib.Get(ctxKeyId) == nil {
fib.Set(ctxKeyId, uuid.New().String())
}
// Create a new context so that we can quit // Create a new context so that we can quit
// all goroutine workers if the http client // all goroutine workers if the http client
// itself is closed before finishing. // itself is closed before finishing.
@ -177,11 +194,17 @@ func Process(fib *fibre.Context, ast *sql.Query, vars map[string]interface{}) (o
defer quit() defer quit()
// Get the unique id for this connection
// so that we can assign it to the context
// and detect any websocket notifications.
id := fib.Get(ctxKeyId).(string)
// Assign the fibre request context id to // Assign the fibre request context id to
// the context so that we can log the id // the context so that we can log the id
// together with the request. // together with the request.
ctx = context.WithValue(ctx, ctxKeyId, fib.Get("id")) ctx = context.WithValue(ctx, ctxKeyId, id)
// Assign the authentication data to the // Assign the authentication data to the
// context so that we can log the auth kind // context so that we can log the auth kind
@ -225,6 +248,12 @@ func Process(fib *fibre.Context, ast *sql.Query, vars map[string]interface{}) (o
go executor.execute(ctx, ast) go executor.execute(ctx, ast)
// Ensure that we flush all websocket events
// once the query has been fully processed
// whilst ignoring this connection itself.
defer flush(id)
// Wait for all of the processed queries to // Wait for all of the processed queries to
// return results, buffer the output, and // return results, buffer the output, and
// return the output when finished. // return the output when finished.

View file

@ -55,6 +55,12 @@ func (e *executor) execute(ctx context.Context, ast *sql.Query) {
var buf []*Response var buf []*Response
var res []interface{} var res []interface{}
// Get the fibre context ID so that we can use
// it to clear or flush websocket notification
// changes linked to this context.
id := ctx.Value(ctxKeyId).(string)
// Ensure that the executor is added back into // Ensure that the executor is added back into
// the executor pool when the executor has // the executor pool when the executor has
// finished processing the request. // finished processing the request.
@ -74,7 +80,7 @@ func (e *executor) execute(ctx context.Context, ast *sql.Query) {
defer func() { defer func() {
if e.dbo.TX != nil { if e.dbo.TX != nil {
e.dbo.Cancel() e.dbo.Cancel()
clear() clear(id)
} }
}() }()
@ -145,18 +151,18 @@ func (e *executor) execute(ctx context.Context, ast *sql.Query) {
case *sql.CancelStatement: case *sql.CancelStatement:
err, buf = e.cancel(buf, err, e.send) err, buf = e.cancel(buf, err, e.send)
if err != nil { if err != nil {
clear() clear(id)
} else { } else {
clear() clear(id)
} }
trc.Finish() trc.Finish()
continue continue
case *sql.CommitStatement: case *sql.CommitStatement:
err, buf = e.commit(buf, err, e.send) err, buf = e.commit(buf, err, e.send)
if err != nil { if err != nil {
clear() clear(id)
} else { } else {
flush() flush(id)
} }
trc.Finish() trc.Finish()
continue continue
@ -262,6 +268,12 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf
} }
} }
// Get the fibre context ID so that we can use
// it to clear or flush websocket notification
// changes linked to this context.
id := ctx.Value(ctxKeyId).(string)
// Execute the defined statement, receiving the // Execute the defined statement, receiving the
// result set, and any errors which occured // result set, and any errors which occured
// while processing the query. // while processing the query.
@ -357,7 +369,7 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf
e.dbo.Cancel() e.dbo.Cancel()
e.dbo.Reset() e.dbo.Reset()
clear() clear(id)
default: default:
@ -379,7 +391,7 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf
if err != nil { if err != nil {
e.dbo.Cancel() e.dbo.Cancel()
clear() clear(id)
return return
} }
@ -389,15 +401,15 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf
if !trw { if !trw {
if err = e.dbo.Cancel(); err != nil { if err = e.dbo.Cancel(); err != nil {
clear() clear(id)
} else { } else {
clear() clear(id)
} }
} else { } else {
if err = e.dbo.Commit(); err != nil { if err = e.dbo.Commit(); err != nil {
clear() clear(id)
} else { } else {
flush() shift(id)
} }
} }
} }

View file

@ -39,7 +39,8 @@ func register(fib *fibre.Context, id string) func() {
sockets[id] = &socket{ sockets[id] = &socket{
fibre: fib, fibre: fib,
waits: make([]interface{}, 0), holds: make(map[string][]interface{}),
waits: make(map[string][]interface{}),
lives: make(map[string]*sql.LiveStatement), lives: make(map[string]*sql.LiveStatement),
} }

View file

@ -25,6 +25,12 @@ import (
// this table, and executes them in name order. // this table, and executes them in name order.
func (d *document) lives(ctx context.Context, when method) (err error) { func (d *document) lives(ctx context.Context, when method) (err error) {
// Get the ID of the current fibre
// connection so that we can check
// against the ID of live queries.
id := ctx.Value(ctxKeyId).(string)
// If this document has not changed // If this document has not changed
// then there is no need to update // then there is no need to update
// any registered live queries. // any registered live queries.
@ -54,6 +60,14 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
ctx = con.ctx(d.ns, d.db) ctx = con.ctx(d.ns, d.db)
// Check whether the change was made by
// the same connection as the live query,
// and if it is then don't notify changes.
if id == lv.FB {
continue
}
// Check whether this live query has the // Check whether this live query has the
// necessary permissions to view this // necessary permissions to view this
// document, or continue to the next query. // document, or continue to the next query.
@ -120,11 +134,11 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
switch when { switch when {
case _CREATE: case _CREATE:
con.queue(lv.ID, "CREATE", doc.Data()) con.queue(id, lv.ID, "CREATE", doc.Data())
case _UPDATE: case _UPDATE:
con.queue(lv.ID, "UPDATE", doc.Data()) con.queue(id, lv.ID, "UPDATE", doc.Data())
case _DELETE: case _DELETE:
con.queue(lv.ID, "DELETE", d.id) con.queue(id, lv.ID, "DELETE", d.id)
} }
} }

View file

@ -32,19 +32,26 @@ import (
type socket struct { type socket struct {
mutex sync.Mutex mutex sync.Mutex
fibre *fibre.Context fibre *fibre.Context
waits []interface{} holds map[string][]interface{}
waits map[string][]interface{}
lives map[string]*sql.LiveStatement lives map[string]*sql.LiveStatement
} }
func clear() { func clear(id string) {
for _, s := range sockets { for _, s := range sockets {
s.clear() s.clear(id)
} }
} }
func flush() { func shift(id string) {
for _, s := range sockets { for _, s := range sockets {
s.flush() s.shift(id)
}
}
func flush(id string) {
for _, s := range sockets {
s.flush(id)
} }
} }
@ -70,12 +77,12 @@ func (s *socket) ctx(ns, db string) (ctx context.Context) {
} }
func (s *socket) queue(query, action string, result interface{}) { func (s *socket) queue(id, query, action string, result interface{}) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.waits = append(s.waits, &Dispatch{ s.holds[id] = append(s.holds[id], &Dispatch{
Query: query, Query: query,
Action: action, Action: action,
Result: result, Result: result,
@ -83,18 +90,31 @@ func (s *socket) queue(query, action string, result interface{}) {
} }
func (s *socket) clear() (err error) { func (s *socket) clear(id string) (err error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.waits = nil s.holds[id] = nil
return return
} }
func (s *socket) flush() (err error) { func (s *socket) shift(id string) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.waits[id] = s.holds[id]
s.holds[id] = nil
return
}
func (s *socket) flush(id string) (err error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -103,7 +123,7 @@ func (s *socket) flush() (err error) {
// notifications for this socket // notifications for this socket
// then ignore this method call. // then ignore this method call.
if len(s.waits) == 0 { if len(s.waits[id]) == 0 {
return nil return nil
} }
@ -113,7 +133,7 @@ func (s *socket) flush() (err error) {
obj := &fibre.RPCNotification{ obj := &fibre.RPCNotification{
Method: "notify", Method: "notify",
Params: s.waits, Params: s.waits[id],
} }
// Check the websocket subprotocol // Check the websocket subprotocol
@ -137,7 +157,7 @@ func (s *socket) flush() (err error) {
// pending message notifications // pending message notifications
// for this socket when done. // for this socket when done.
s.waits = nil s.waits[id] = nil
return return