Prevent race conditions in live query socket notifications
This commit is contained in:
parent
672d298e7e
commit
a62797e7da
5 changed files with 42 additions and 55 deletions
6
db/db.go
6
db/db.go
|
@ -66,9 +66,11 @@ func Exit() (err error) {
|
||||||
|
|
||||||
log.WithPrefix("db").Infof("Gracefully shutting down database")
|
log.WithPrefix("db").Infof("Gracefully shutting down database")
|
||||||
|
|
||||||
for id, so := range sockets {
|
sockets.Range(func(key, val interface{}) bool {
|
||||||
|
id, so := key.(string), val.(*socket)
|
||||||
deregister(so.fibre, id)()
|
deregister(so.fibre, id)()
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
return db.Close()
|
return db.Close()
|
||||||
|
|
||||||
|
|
30
db/live.go
30
db/live.go
|
@ -23,25 +23,16 @@ import (
|
||||||
"github.com/abcum/surreal/sql"
|
"github.com/abcum/surreal/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
var locker sync.Mutex
|
var sockets sync.Map
|
||||||
|
|
||||||
var sockets map[string]*socket
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
sockets = make(map[string]*socket)
|
|
||||||
}
|
|
||||||
|
|
||||||
func register(fib *fibre.Context, id string) func() {
|
func register(fib *fibre.Context, id string) func() {
|
||||||
return func() {
|
return func() {
|
||||||
|
|
||||||
locker.Lock()
|
sockets.LoadOrStore(id, &socket{
|
||||||
defer locker.Unlock()
|
|
||||||
|
|
||||||
sockets[id] = &socket{
|
|
||||||
fibre: fib,
|
fibre: fib,
|
||||||
items: make(map[string][]interface{}),
|
items: make(map[string][]interface{}),
|
||||||
lives: make(map[string]*sql.LiveStatement),
|
lives: make(map[string]*sql.LiveStatement),
|
||||||
}
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,11 +40,8 @@ func register(fib *fibre.Context, id string) func() {
|
||||||
func deregister(fib *fibre.Context, id string) func() {
|
func deregister(fib *fibre.Context, id string) func() {
|
||||||
return func() {
|
return func() {
|
||||||
|
|
||||||
locker.Lock()
|
if sck, ok := sockets.Load(id); ok {
|
||||||
defer locker.Unlock()
|
sck.(*socket).deregister(id)
|
||||||
|
|
||||||
if sck, ok := sockets[id]; ok {
|
|
||||||
sck.deregister(id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -63,8 +51,8 @@ func (e *executor) executeLive(ctx context.Context, stm *sql.LiveStatement) (out
|
||||||
|
|
||||||
stm.FB = ctx.Value(ctxKeyId).(string)
|
stm.FB = ctx.Value(ctxKeyId).(string)
|
||||||
|
|
||||||
if sck, ok := sockets[stm.FB]; ok {
|
if sck, ok := sockets.Load(stm.FB); ok {
|
||||||
return sck.executeLive(e, ctx, stm)
|
return sck.(*socket).executeLive(e, ctx, stm)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &QueryError{}
|
return nil, &QueryError{}
|
||||||
|
@ -75,8 +63,8 @@ func (e *executor) executeKill(ctx context.Context, stm *sql.KillStatement) (out
|
||||||
|
|
||||||
stm.FB = ctx.Value(ctxKeyId).(string)
|
stm.FB = ctx.Value(ctxKeyId).(string)
|
||||||
|
|
||||||
if sck, ok := sockets[stm.FB]; ok {
|
if sck, ok := sockets.Load(stm.FB); ok {
|
||||||
return sck.executeKill(e, ctx, stm)
|
return sck.(*socket).executeKill(e, ctx, stm)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &QueryError{}
|
return nil, &QueryError{}
|
||||||
|
|
18
db/lives.go
18
db/lives.go
|
@ -51,13 +51,15 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
|
||||||
|
|
||||||
for _, lv := range lvs {
|
for _, lv := range lvs {
|
||||||
|
|
||||||
var ok bool
|
if sck, ok := sockets.Load(lv.FB); ok {
|
||||||
var con *socket
|
|
||||||
var out interface{}
|
|
||||||
|
|
||||||
if con, ok = sockets[lv.FB]; ok {
|
var out interface{}
|
||||||
|
|
||||||
ctx = con.ctx(d.ns, d.db)
|
// Create a new context for this socket
|
||||||
|
// which has the correct connection
|
||||||
|
// variables, and auth levels.
|
||||||
|
|
||||||
|
ctx = sck.(*socket).ctx(d.ns, d.db)
|
||||||
|
|
||||||
// Check whether the change was made by
|
// Check whether the change was made by
|
||||||
// the same connection as the live query,
|
// the same connection as the live query,
|
||||||
|
@ -111,14 +113,14 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
|
||||||
|
|
||||||
switch when {
|
switch when {
|
||||||
case _DELETE:
|
case _DELETE:
|
||||||
con.queue(id, lv.ID, "DELETE", d.id)
|
sck.(*socket).queue(id, lv.ID, "DELETE", d.id)
|
||||||
case _CREATE:
|
case _CREATE:
|
||||||
if out != nil {
|
if out != nil {
|
||||||
con.queue(id, lv.ID, "CREATE", out)
|
sck.(*socket).queue(id, lv.ID, "CREATE", out)
|
||||||
}
|
}
|
||||||
case _UPDATE:
|
case _UPDATE:
|
||||||
if out != nil {
|
if out != nil {
|
||||||
con.queue(id, lv.ID, "UPDATE", out)
|
sck.(*socket).queue(id, lv.ID, "UPDATE", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
39
db/socket.go
39
db/socket.go
|
@ -37,15 +37,21 @@ type socket struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func clear(id string) {
|
func clear(id string) {
|
||||||
for _, s := range sockets {
|
go func() {
|
||||||
s.clear(id)
|
sockets.Range(func(key, val interface{}) bool {
|
||||||
}
|
val.(*socket).clear(id)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func flush(id string) {
|
func flush(id string) {
|
||||||
for _, s := range sockets {
|
go func() {
|
||||||
s.flush(id)
|
sockets.Range(func(key, val interface{}) bool {
|
||||||
}
|
val.(*socket).flush(id)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *socket) ctx(ns, db string) (ctx context.Context) {
|
func (s *socket) ctx(ns, db string) (ctx context.Context) {
|
||||||
|
@ -119,22 +125,11 @@ func (s *socket) flush(id string) (err error) {
|
||||||
Params: s.items[id],
|
Params: s.items[id],
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the websocket subprotocol
|
// Notify the websocket connection
|
||||||
// and send the relevant message
|
// y sending an RPCNotification type
|
||||||
// type containing the notification.
|
// to the notify channel.
|
||||||
|
|
||||||
sock := s.fibre.Socket()
|
s.fibre.Socket().Notify(obj)
|
||||||
|
|
||||||
switch sock.Subprotocol() {
|
|
||||||
default:
|
|
||||||
err = sock.SendJSON(obj)
|
|
||||||
case "json":
|
|
||||||
err = sock.SendJSON(obj)
|
|
||||||
case "cbor":
|
|
||||||
err = sock.SendCBOR(obj)
|
|
||||||
case "pack":
|
|
||||||
err = sock.SendPACK(obj)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure that we clear all the
|
// Make sure that we clear all the
|
||||||
// pending message notifications
|
// pending message notifications
|
||||||
|
@ -225,7 +220,7 @@ func (s *socket) check(e *executor, ctx context.Context, ns, db, tb string) (err
|
||||||
|
|
||||||
func (s *socket) deregister(id string) {
|
func (s *socket) deregister(id string) {
|
||||||
|
|
||||||
delete(sockets, id)
|
sockets.Delete(id)
|
||||||
|
|
||||||
txn, _ := db.Begin(context.Background(), true)
|
txn, _ := db.Begin(context.Background(), true)
|
||||||
|
|
||||||
|
|
4
glide.lock
generated
4
glide.lock
generated
|
@ -1,12 +1,12 @@
|
||||||
hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed
|
hash: c56e5bd935dd1933a6e7370fe3fc67ea26169ec91aa05c171543866c6c2490ed
|
||||||
updated: 2018-04-24T13:32:49.661918+01:00
|
updated: 2018-04-24T23:43:27.310737+01:00
|
||||||
imports:
|
imports:
|
||||||
- name: github.com/abcum/bump
|
- name: github.com/abcum/bump
|
||||||
version: 526934c541e071b5a330671c76434b9e32d55638
|
version: 526934c541e071b5a330671c76434b9e32d55638
|
||||||
- name: github.com/abcum/cork
|
- name: github.com/abcum/cork
|
||||||
version: c246208017d0b81f2e9a3fc2fb7a993c89153839
|
version: c246208017d0b81f2e9a3fc2fb7a993c89153839
|
||||||
- name: github.com/abcum/fibre
|
- name: github.com/abcum/fibre
|
||||||
version: 1b1947da964c0c0a244868279f9476df093eef34
|
version: 24b2157453a929f7a86616c415d01b94916b3ed5
|
||||||
subpackages:
|
subpackages:
|
||||||
- mw
|
- mw
|
||||||
- name: github.com/abcum/ptree
|
- name: github.com/abcum/ptree
|
||||||
|
|
Loading…
Reference in a new issue