Prevent race conditions in live query socket notifications

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-25 00:01:29 +01:00
parent 672d298e7e
commit a62797e7da
5 changed files with 42 additions and 55 deletions

View file

@ -66,9 +66,11 @@ func Exit() (err error) {
log.WithPrefix("db").Infof("Gracefully shutting down database") log.WithPrefix("db").Infof("Gracefully shutting down database")
for id, so := range sockets { sockets.Range(func(key, val interface{}) bool {
id, so := key.(string), val.(*socket)
deregister(so.fibre, id)() deregister(so.fibre, id)()
} return true
})
return db.Close() return db.Close()

View file

@ -23,25 +23,16 @@ import (
"github.com/abcum/surreal/sql" "github.com/abcum/surreal/sql"
) )
var locker sync.Mutex var sockets sync.Map
var sockets map[string]*socket
func init() {
sockets = make(map[string]*socket)
}
func register(fib *fibre.Context, id string) func() { func register(fib *fibre.Context, id string) func() {
return func() { return func() {
locker.Lock() sockets.LoadOrStore(id, &socket{
defer locker.Unlock()
sockets[id] = &socket{
fibre: fib, fibre: fib,
items: make(map[string][]interface{}), items: make(map[string][]interface{}),
lives: make(map[string]*sql.LiveStatement), lives: make(map[string]*sql.LiveStatement),
} })
} }
} }
@ -49,11 +40,8 @@ func register(fib *fibre.Context, id string) func() {
func deregister(fib *fibre.Context, id string) func() { func deregister(fib *fibre.Context, id string) func() {
return func() { return func() {
locker.Lock() if sck, ok := sockets.Load(id); ok {
defer locker.Unlock() sck.(*socket).deregister(id)
if sck, ok := sockets[id]; ok {
sck.deregister(id)
} }
} }
@ -63,8 +51,8 @@ func (e *executor) executeLive(ctx context.Context, stm *sql.LiveStatement) (out
stm.FB = ctx.Value(ctxKeyId).(string) stm.FB = ctx.Value(ctxKeyId).(string)
if sck, ok := sockets[stm.FB]; ok { if sck, ok := sockets.Load(stm.FB); ok {
return sck.executeLive(e, ctx, stm) return sck.(*socket).executeLive(e, ctx, stm)
} }
return nil, &QueryError{} return nil, &QueryError{}
@ -75,8 +63,8 @@ func (e *executor) executeKill(ctx context.Context, stm *sql.KillStatement) (out
stm.FB = ctx.Value(ctxKeyId).(string) stm.FB = ctx.Value(ctxKeyId).(string)
if sck, ok := sockets[stm.FB]; ok { if sck, ok := sockets.Load(stm.FB); ok {
return sck.executeKill(e, ctx, stm) return sck.(*socket).executeKill(e, ctx, stm)
} }
return nil, &QueryError{} return nil, &QueryError{}

View file

@ -51,13 +51,15 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
for _, lv := range lvs { for _, lv := range lvs {
var ok bool if sck, ok := sockets.Load(lv.FB); ok {
var con *socket
var out interface{}
if con, ok = sockets[lv.FB]; ok { var out interface{}
ctx = con.ctx(d.ns, d.db) // Create a new context for this socket
// which has the correct connection
// variables, and auth levels.
ctx = sck.(*socket).ctx(d.ns, d.db)
// Check whether the change was made by // Check whether the change was made by
// the same connection as the live query, // the same connection as the live query,
@ -111,14 +113,14 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
switch when { switch when {
case _DELETE: case _DELETE:
con.queue(id, lv.ID, "DELETE", d.id) sck.(*socket).queue(id, lv.ID, "DELETE", d.id)
case _CREATE: case _CREATE:
if out != nil { if out != nil {
con.queue(id, lv.ID, "CREATE", out) sck.(*socket).queue(id, lv.ID, "CREATE", out)
} }
case _UPDATE: case _UPDATE:
if out != nil { if out != nil {
con.queue(id, lv.ID, "UPDATE", out) sck.(*socket).queue(id, lv.ID, "UPDATE", out)
} }
} }

View file

@ -37,15 +37,21 @@ type socket struct {
} }
func clear(id string) { func clear(id string) {
for _, s := range sockets { go func() {
s.clear(id) sockets.Range(func(key, val interface{}) bool {
} val.(*socket).clear(id)
return true
})
}()
} }
func flush(id string) { func flush(id string) {
for _, s := range sockets { go func() {
s.flush(id) sockets.Range(func(key, val interface{}) bool {
} val.(*socket).flush(id)
return true
})
}()
} }
func (s *socket) ctx(ns, db string) (ctx context.Context) { func (s *socket) ctx(ns, db string) (ctx context.Context) {
@ -119,22 +125,11 @@ func (s *socket) flush(id string) (err error) {
Params: s.items[id], Params: s.items[id],
} }
// Check the websocket subprotocol // Notify the websocket connection
// and send the relevant message // y sending an RPCNotification type
// type containing the notification. // to the notify channel.
sock := s.fibre.Socket() s.fibre.Socket().Notify(obj)
switch sock.Subprotocol() {
default:
err = sock.SendJSON(obj)
case "json":
err = sock.SendJSON(obj)
case "cbor":
err = sock.SendCBOR(obj)
case "pack":
err = sock.SendPACK(obj)
}
// Make sure that we clear all the // Make sure that we clear all the
// pending message notifications // pending message notifications
@ -225,7 +220,7 @@ func (s *socket) check(e *executor, ctx context.Context, ns, db, tb string) (err
func (s *socket) deregister(id string) { func (s *socket) deregister(id string) {
delete(sockets, id) sockets.Delete(id)
txn, _ := db.Begin(context.Background(), true) txn, _ := db.Begin(context.Background(), true)

4
glide.lock generated
View file

@ -1,12 +1,12 @@
hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed
updated: 2018-04-24T13:32:49.661918+01:00 updated: 2018-04-24T23:43:27.310737+01:00
imports: imports:
- name: github.com/abcum/bump - name: github.com/abcum/bump
version: 526934c541e071b5a330671c76434b9e32d55638 version: 526934c541e071b5a330671c76434b9e32d55638
- name: github.com/abcum/cork - name: github.com/abcum/cork
version: c246208017d0b81f2e9a3fc2fb7a993c89153839 version: c246208017d0b81f2e9a3fc2fb7a993c89153839
- name: github.com/abcum/fibre - name: github.com/abcum/fibre
version: 1b1947da964c0c0a244868279f9476df093eef34 version: 24b2157453a929f7a86616c415d01b94916b3ed5
subpackages: subpackages:
- mw - mw
- name: github.com/abcum/ptree - name: github.com/abcum/ptree