Prevent race conditions in live query socket notifications

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-25 00:01:29 +01:00
parent 672d298e7e
commit a62797e7da
5 changed files with 42 additions and 55 deletions

View file

@ -66,9 +66,11 @@ func Exit() (err error) {
log.WithPrefix("db").Infof("Gracefully shutting down database")
for id, so := range sockets {
sockets.Range(func(key, val interface{}) bool {
id, so := key.(string), val.(*socket)
deregister(so.fibre, id)()
}
return true
})
return db.Close()

View file

@ -23,25 +23,16 @@ import (
"github.com/abcum/surreal/sql"
)
var locker sync.Mutex
var sockets map[string]*socket
func init() {
sockets = make(map[string]*socket)
}
var sockets sync.Map
func register(fib *fibre.Context, id string) func() {
return func() {
locker.Lock()
defer locker.Unlock()
sockets[id] = &socket{
sockets.LoadOrStore(id, &socket{
fibre: fib,
items: make(map[string][]interface{}),
lives: make(map[string]*sql.LiveStatement),
}
})
}
}
@ -49,11 +40,8 @@ func register(fib *fibre.Context, id string) func() {
func deregister(fib *fibre.Context, id string) func() {
return func() {
locker.Lock()
defer locker.Unlock()
if sck, ok := sockets[id]; ok {
sck.deregister(id)
if sck, ok := sockets.Load(id); ok {
sck.(*socket).deregister(id)
}
}
@ -63,8 +51,8 @@ func (e *executor) executeLive(ctx context.Context, stm *sql.LiveStatement) (out
stm.FB = ctx.Value(ctxKeyId).(string)
if sck, ok := sockets[stm.FB]; ok {
return sck.executeLive(e, ctx, stm)
if sck, ok := sockets.Load(stm.FB); ok {
return sck.(*socket).executeLive(e, ctx, stm)
}
return nil, &QueryError{}
@ -75,8 +63,8 @@ func (e *executor) executeKill(ctx context.Context, stm *sql.KillStatement) (out
stm.FB = ctx.Value(ctxKeyId).(string)
if sck, ok := sockets[stm.FB]; ok {
return sck.executeKill(e, ctx, stm)
if sck, ok := sockets.Load(stm.FB); ok {
return sck.(*socket).executeKill(e, ctx, stm)
}
return nil, &QueryError{}

View file

@ -51,13 +51,15 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
for _, lv := range lvs {
var ok bool
var con *socket
var out interface{}
if sck, ok := sockets.Load(lv.FB); ok {
if con, ok = sockets[lv.FB]; ok {
var out interface{}
ctx = con.ctx(d.ns, d.db)
// Create a new context for this socket
// which has the correct connection
// variables, and auth levels.
ctx = sck.(*socket).ctx(d.ns, d.db)
// Check whether the change was made by
// the same connection as the live query,
@ -111,14 +113,14 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
switch when {
case _DELETE:
con.queue(id, lv.ID, "DELETE", d.id)
sck.(*socket).queue(id, lv.ID, "DELETE", d.id)
case _CREATE:
if out != nil {
con.queue(id, lv.ID, "CREATE", out)
sck.(*socket).queue(id, lv.ID, "CREATE", out)
}
case _UPDATE:
if out != nil {
con.queue(id, lv.ID, "UPDATE", out)
sck.(*socket).queue(id, lv.ID, "UPDATE", out)
}
}

View file

@ -37,15 +37,21 @@ type socket struct {
}
func clear(id string) {
for _, s := range sockets {
s.clear(id)
}
go func() {
sockets.Range(func(key, val interface{}) bool {
val.(*socket).clear(id)
return true
})
}()
}
func flush(id string) {
for _, s := range sockets {
s.flush(id)
}
go func() {
sockets.Range(func(key, val interface{}) bool {
val.(*socket).flush(id)
return true
})
}()
}
func (s *socket) ctx(ns, db string) (ctx context.Context) {
@ -119,22 +125,11 @@ func (s *socket) flush(id string) (err error) {
Params: s.items[id],
}
// Check the websocket subprotocol
// and send the relevant message
// type containing the notification.
// Notify the websocket connection
// y sending an RPCNotification type
// to the notify channel.
sock := s.fibre.Socket()
switch sock.Subprotocol() {
default:
err = sock.SendJSON(obj)
case "json":
err = sock.SendJSON(obj)
case "cbor":
err = sock.SendCBOR(obj)
case "pack":
err = sock.SendPACK(obj)
}
s.fibre.Socket().Notify(obj)
// Make sure that we clear all the
// pending message notifications
@ -225,7 +220,7 @@ func (s *socket) check(e *executor, ctx context.Context, ns, db, tb string) (err
func (s *socket) deregister(id string) {
delete(sockets, id)
sockets.Delete(id)
txn, _ := db.Begin(context.Background(), true)

4
glide.lock generated
View file

@ -1,12 +1,12 @@
hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed
updated: 2018-04-24T13:32:49.661918+01:00
updated: 2018-04-24T23:43:27.310737+01:00
imports:
- name: github.com/abcum/bump
version: 526934c541e071b5a330671c76434b9e32d55638
- name: github.com/abcum/cork
version: c246208017d0b81f2e9a3fc2fb7a993c89153839
- name: github.com/abcum/fibre
version: 1b1947da964c0c0a244868279f9476df093eef34
version: 24b2157453a929f7a86616c415d01b94916b3ed5
subpackages:
- mw
- name: github.com/abcum/ptree