From a62797e7dab87339452c9c05215e44c06f9ebef2 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Wed, 25 Apr 2018 00:01:29 +0100 Subject: [PATCH] Prevent race conditions in live query socket notifications --- db/db.go | 6 ++++-- db/live.go | 30 +++++++++--------------------- db/lives.go | 18 ++++++++++-------- db/socket.go | 39 +++++++++++++++++---------------------- glide.lock | 4 ++-- 5 files changed, 42 insertions(+), 55 deletions(-) diff --git a/db/db.go b/db/db.go index c8e7a2d7..d7e25f4e 100644 --- a/db/db.go +++ b/db/db.go @@ -66,9 +66,11 @@ func Exit() (err error) { log.WithPrefix("db").Infof("Gracefully shutting down database") - for id, so := range sockets { + sockets.Range(func(key, val interface{}) bool { + id, so := key.(string), val.(*socket) deregister(so.fibre, id)() - } + return true + }) return db.Close() diff --git a/db/live.go b/db/live.go index 87714e12..b71cb2d4 100644 --- a/db/live.go +++ b/db/live.go @@ -23,25 +23,16 @@ import ( "github.com/abcum/surreal/sql" ) -var locker sync.Mutex - -var sockets map[string]*socket - -func init() { - sockets = make(map[string]*socket) -} +var sockets sync.Map func register(fib *fibre.Context, id string) func() { return func() { - locker.Lock() - defer locker.Unlock() - - sockets[id] = &socket{ + sockets.LoadOrStore(id, &socket{ fibre: fib, items: make(map[string][]interface{}), lives: make(map[string]*sql.LiveStatement), - } + }) } } @@ -49,11 +40,8 @@ func register(fib *fibre.Context, id string) func() { func deregister(fib *fibre.Context, id string) func() { return func() { - locker.Lock() - defer locker.Unlock() - - if sck, ok := sockets[id]; ok { - sck.deregister(id) + if sck, ok := sockets.Load(id); ok { + sck.(*socket).deregister(id) } } @@ -63,8 +51,8 @@ func (e *executor) executeLive(ctx context.Context, stm *sql.LiveStatement) (out stm.FB = ctx.Value(ctxKeyId).(string) - if sck, ok := sockets[stm.FB]; ok { - return sck.executeLive(e, ctx, stm) + if sck, ok := sockets.Load(stm.FB); ok { + return sck.(*socket).executeLive(e, ctx, stm) } return nil, &QueryError{} @@ -75,8 +63,8 @@ func (e *executor) executeKill(ctx context.Context, stm *sql.KillStatement) (out stm.FB = ctx.Value(ctxKeyId).(string) - if sck, ok := sockets[stm.FB]; ok { - return sck.executeKill(e, ctx, stm) + if sck, ok := sockets.Load(stm.FB); ok { + return sck.(*socket).executeKill(e, ctx, stm) } return nil, &QueryError{} diff --git a/db/lives.go b/db/lives.go index 1c7e0418..e9c0455a 100644 --- a/db/lives.go +++ b/db/lives.go @@ -51,13 +51,15 @@ func (d *document) lives(ctx context.Context, when method) (err error) { for _, lv := range lvs { - var ok bool - var con *socket - var out interface{} + if sck, ok := sockets.Load(lv.FB); ok { - if con, ok = sockets[lv.FB]; ok { + var out interface{} - ctx = con.ctx(d.ns, d.db) + // Create a new context for this socket + // which has the correct connection + // variables, and auth levels. + + ctx = sck.(*socket).ctx(d.ns, d.db) // Check whether the change was made by // the same connection as the live query, @@ -111,14 +113,14 @@ func (d *document) lives(ctx context.Context, when method) (err error) { switch when { case _DELETE: - con.queue(id, lv.ID, "DELETE", d.id) + sck.(*socket).queue(id, lv.ID, "DELETE", d.id) case _CREATE: if out != nil { - con.queue(id, lv.ID, "CREATE", out) + sck.(*socket).queue(id, lv.ID, "CREATE", out) } case _UPDATE: if out != nil { - con.queue(id, lv.ID, "UPDATE", out) + sck.(*socket).queue(id, lv.ID, "UPDATE", out) } } diff --git a/db/socket.go b/db/socket.go index afef30b4..eed5b8c6 100644 --- a/db/socket.go +++ b/db/socket.go @@ -37,15 +37,21 @@ type socket struct { } func clear(id string) { - for _, s := range sockets { - s.clear(id) - } + go func() { + sockets.Range(func(key, val interface{}) bool { + val.(*socket).clear(id) + return true + }) + }() } func flush(id string) { - for _, s := range sockets { - s.flush(id) - } + go func() { + sockets.Range(func(key, val interface{}) bool { + val.(*socket).flush(id) + return true + }) + }() } func (s *socket) ctx(ns, db string) (ctx context.Context) { @@ -119,22 +125,11 @@ func (s *socket) flush(id string) (err error) { Params: s.items[id], } - // Check the websocket subprotocol - // and send the relevant message - // type containing the notification. + // Notify the websocket connection + // y sending an RPCNotification type + // to the notify channel. - sock := s.fibre.Socket() - - switch sock.Subprotocol() { - default: - err = sock.SendJSON(obj) - case "json": - err = sock.SendJSON(obj) - case "cbor": - err = sock.SendCBOR(obj) - case "pack": - err = sock.SendPACK(obj) - } + s.fibre.Socket().Notify(obj) // Make sure that we clear all the // pending message notifications @@ -225,7 +220,7 @@ func (s *socket) check(e *executor, ctx context.Context, ns, db, tb string) (err func (s *socket) deregister(id string) { - delete(sockets, id) + sockets.Delete(id) txn, _ := db.Begin(context.Background(), true) diff --git a/glide.lock b/glide.lock index d60c9b3e..e0207773 100644 --- a/glide.lock +++ b/glide.lock @@ -1,12 +1,12 @@ hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed -updated: 2018-04-24T13:32:49.661918+01:00 +updated: 2018-04-24T23:43:27.310737+01:00 imports: - name: github.com/abcum/bump version: 526934c541e071b5a330671c76434b9e32d55638 - name: github.com/abcum/cork version: c246208017d0b81f2e9a3fc2fb7a993c89153839 - name: github.com/abcum/fibre - version: 1b1947da964c0c0a244868279f9476df093eef34 + version: 24b2157453a929f7a86616c415d01b94916b3ed5 subpackages: - mw - name: github.com/abcum/ptree