diff --git a/db/context.go b/db/context.go index 9f2da37c..720bd3bd 100644 --- a/db/context.go +++ b/db/context.go @@ -18,7 +18,7 @@ import ( "context" ) -func vers(ctx context.Context) int { +func vers(ctx context.Context) uint32 { v := ctx.Value(ctxKeyDive) @@ -26,7 +26,7 @@ func vers(ctx context.Context) int { case nil: return 0 default: - return v.(int) + return v.(uint32) } } @@ -37,12 +37,12 @@ func dive(ctx context.Context) context.Context { switch v { case nil: - return context.WithValue(ctx, ctxKeyDive, 1) + return context.WithValue(ctx, ctxKeyDive, uint32(1)) default: - if v.(int) > maxRecursiveQueries { + if v.(uint32) > maxRecursiveQueries { panic(errRecursiveOverload) } - return context.WithValue(ctx, ctxKeyDive, v.(int)+1) + return context.WithValue(ctx, ctxKeyDive, v.(uint32)+1) } } diff --git a/db/create.go b/db/create.go index fae7b153..94462b56 100644 --- a/db/create.go +++ b/db/create.go @@ -108,7 +108,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/delete.go b/db/delete.go index 12646701..1341d0b0 100644 --- a/db/delete.go +++ b/db/delete.go @@ -107,7 +107,7 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/document.go b/db/document.go index 4f6f94aa..e96875a9 100644 --- a/db/document.go +++ b/db/document.go @@ -36,14 +36,11 @@ type document struct { id *sql.Thing key *keys.Thing val kvs.KV + lck bool doc *data.Doc initial *data.Doc current *data.Doc - locks struct { - r bool - w bool - } - store struct { + store struct { id int tb bool ev bool @@ -71,9 +68,7 @@ func newDocument(i *iterator, key *keys.Thing, val kvs.KV, doc *data.Doc) (d *do d.key = key d.val = val d.doc = doc - - d.locks.r = false - d.locks.w = false + d.lck = false return @@ -199,10 +194,10 @@ func (d *document) init(ctx context.Context) (err error) { } -func (d *document) wlock(ctx context.Context) (err error) { +func (d *document) lock(ctx context.Context) (err error) { if d.key != nil { - d.locks.w = true + d.lck = true d.i.e.lock.Lock(ctx, d.key) } @@ -210,29 +205,13 @@ func (d *document) wlock(ctx context.Context) (err error) { } -func (d *document) rlock(ctx context.Context) (err error) { - - if d.key != nil { - d.locks.r = true - d.i.e.lock.RLock(ctx, d.key) - } - - return - -} - func (d *document) ulock(ctx context.Context) (err error) { - if d.key != nil && d.locks.w { - d.locks.w = false + if d.key != nil && d.lck { + d.lck = false d.i.e.lock.Unlock(ctx, d.key) } - if d.key != nil && d.locks.r { - d.locks.r = false - d.i.e.lock.RUnlock(ctx, d.key) - } - return } diff --git a/db/insert.go b/db/insert.go index d05a1625..cfc64d72 100644 --- a/db/insert.go +++ b/db/insert.go @@ -86,7 +86,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/mutex.go b/db/mutex.go index abcab734..25894cb3 100644 --- a/db/mutex.go +++ b/db/mutex.go @@ -23,62 +23,58 @@ import ( type mutex struct { m sync.Map - l sync.Mutex } type value struct { - v int - r int64 - w int64 - l sync.RWMutex + v uint32 + q chan struct{} + l chan struct{} } func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) { - m.l.Lock() - _, v := m.item(ctx, key) - if v.v < vers(ctx) { - m.l.Unlock() - panic(errRaceCondition) - } - atomic.AddInt64(&v.w, 1) - m.l.Unlock() - v.l.Lock() -} -func (m *mutex) RLock(ctx context.Context, key fmt.Stringer) { - m.l.Lock() _, v := m.item(ctx, key) - atomic.AddInt64(&v.r, 1) - m.l.Unlock() - v.l.RLock() + + for { + select { + default: + if atomic.LoadUint32(&v.v) < vers(ctx) { + close(v.q) + panic(errRaceCondition) + } + case <-ctx.Done(): + return + case <-v.q: + return + case v.l <- struct{}{}: + atomic.StoreUint32(&v.v, vers(ctx)) + return + } + } + } func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) { - m.l.Lock() - defer m.l.Unlock() - k, v := m.item(ctx, key) - if w := atomic.LoadInt64(&v.w); w > 0 { - if w := atomic.AddInt64(&v.w, -1); w <= 0 { - m.m.Delete(k) - } - v.l.Unlock() - } -} -func (m *mutex) RUnlock(ctx context.Context, key fmt.Stringer) { - m.l.Lock() - defer m.l.Unlock() - k, v := m.item(ctx, key) - if r := atomic.LoadInt64(&v.r); r > 0 { - if r := atomic.AddInt64(&v.r, -1); r <= 0 { - m.m.Delete(k) - } - v.l.RUnlock() + _, v := m.item(ctx, key) + + select { + case <-ctx.Done(): + return + case <-v.q: + return + case <-v.l: + return } + } func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) { k := key.String() - v, _ := m.m.LoadOrStore(k, &value{v: vers(ctx)}) + v, _ := m.m.LoadOrStore(k, &value{ + v: vers(ctx), + q: make(chan struct{}), + l: make(chan struct{}, 1), + }) return k, v.(*value) } diff --git a/db/mutex_test.go b/db/mutex_test.go index 5cf9513f..39e7cf64 100644 --- a/db/mutex_test.go +++ b/db/mutex_test.go @@ -100,13 +100,11 @@ func TestMutex(t *testing.T) { m := new(mutex) ctx := context.Background() - for i := 0; i < n; i++ { - m.Lock(ctx, new(stringer)) - ctx = dive(ctx) - So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic) - So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) - So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) - } + m.Lock(ctx, new(stringer)) + ctx = dive(ctx) + So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic) + So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) + So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) So(nil, ShouldBeNil) @@ -191,7 +189,7 @@ func TestMutex(t *testing.T) { }) - Convey("Inability to update the same document in a SELECT subquery", t, func() { + Convey("Ability to update the same document in a SELECT subquery", t, func() { setupDB(20) @@ -207,12 +205,12 @@ func TestMutex(t *testing.T) { So(res, ShouldHaveLength, 4) So(res[1].Status, ShouldEqual, "OK") So(res[1].Result, ShouldHaveLength, 1) - So(res[2].Status, ShouldEqual, "ERR") - So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively") + So(res[2].Status, ShouldEqual, "OK") + So(res[2].Result, ShouldHaveLength, 1) So(res[3].Status, ShouldEqual, "OK") So(res[3].Result, ShouldHaveLength, 1) So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil) - So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true) }) diff --git a/db/relate.go b/db/relate.go index c7575c76..3bb0d683 100644 --- a/db/relate.go +++ b/db/relate.go @@ -100,7 +100,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/select.go b/db/select.go index 0c28dd0e..5ffcaf90 100644 --- a/db/select.go +++ b/db/select.go @@ -147,10 +147,6 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int return nil, err } - if err = d.rlock(ctx); err != nil { - return nil, err - } - if err = d.setup(ctx); err != nil { return nil, err } diff --git a/db/update.go b/db/update.go index d747d30b..6e716ca0 100644 --- a/db/update.go +++ b/db/update.go @@ -107,7 +107,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/upsert.go b/db/upsert.go index c30cb47b..88454448 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -86,7 +86,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int return nil, err } - if err = d.wlock(ctx); err != nil { + if err = d.lock(ctx); err != nil { return nil, err } diff --git a/db/vars.go b/db/vars.go index 966fcf06..9c2bd732 100644 --- a/db/vars.go +++ b/db/vars.go @@ -75,7 +75,7 @@ var ( // maxRecursiveQueries specifies how many queries will be // processed recursively before the query is cancelled. - maxRecursiveQueries = 16 + maxRecursiveQueries = uint32(16) // queryIdentFailed occurs when a permission query asks // for a field, meaning a document has to be fetched.