Move to channel based mutex

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-28 20:32:05 +01:00
parent 99d050b238
commit f335d71aba
12 changed files with 64 additions and 95 deletions

View file

@ -18,7 +18,7 @@ import (
"context" "context"
) )
func vers(ctx context.Context) int { func vers(ctx context.Context) uint32 {
v := ctx.Value(ctxKeyDive) v := ctx.Value(ctxKeyDive)
@ -26,7 +26,7 @@ func vers(ctx context.Context) int {
case nil: case nil:
return 0 return 0
default: default:
return v.(int) return v.(uint32)
} }
} }
@ -37,12 +37,12 @@ func dive(ctx context.Context) context.Context {
switch v { switch v {
case nil: case nil:
return context.WithValue(ctx, ctxKeyDive, 1) return context.WithValue(ctx, ctxKeyDive, uint32(1))
default: default:
if v.(int) > maxRecursiveQueries { if v.(uint32) > maxRecursiveQueries {
panic(errRecursiveOverload) panic(errRecursiveOverload)
} }
return context.WithValue(ctx, ctxKeyDive, v.(int)+1) return context.WithValue(ctx, ctxKeyDive, v.(uint32)+1)
} }
} }

View file

@ -108,7 +108,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -107,7 +107,7 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -36,13 +36,10 @@ type document struct {
id *sql.Thing id *sql.Thing
key *keys.Thing key *keys.Thing
val kvs.KV val kvs.KV
lck bool
doc *data.Doc doc *data.Doc
initial *data.Doc initial *data.Doc
current *data.Doc current *data.Doc
locks struct {
r bool
w bool
}
store struct { store struct {
id int id int
tb bool tb bool
@ -71,9 +68,7 @@ func newDocument(i *iterator, key *keys.Thing, val kvs.KV, doc *data.Doc) (d *do
d.key = key d.key = key
d.val = val d.val = val
d.doc = doc d.doc = doc
d.lck = false
d.locks.r = false
d.locks.w = false
return return
@ -199,10 +194,10 @@ func (d *document) init(ctx context.Context) (err error) {
} }
func (d *document) wlock(ctx context.Context) (err error) { func (d *document) lock(ctx context.Context) (err error) {
if d.key != nil { if d.key != nil {
d.locks.w = true d.lck = true
d.i.e.lock.Lock(ctx, d.key) d.i.e.lock.Lock(ctx, d.key)
} }
@ -210,29 +205,13 @@ func (d *document) wlock(ctx context.Context) (err error) {
} }
func (d *document) rlock(ctx context.Context) (err error) {
if d.key != nil {
d.locks.r = true
d.i.e.lock.RLock(ctx, d.key)
}
return
}
func (d *document) ulock(ctx context.Context) (err error) { func (d *document) ulock(ctx context.Context) (err error) {
if d.key != nil && d.locks.w { if d.key != nil && d.lck {
d.locks.w = false d.lck = false
d.i.e.lock.Unlock(ctx, d.key) d.i.e.lock.Unlock(ctx, d.key)
} }
if d.key != nil && d.locks.r {
d.locks.r = false
d.i.e.lock.RUnlock(ctx, d.key)
}
return return
} }

View file

@ -86,7 +86,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -23,62 +23,58 @@ import (
type mutex struct { type mutex struct {
m sync.Map m sync.Map
l sync.Mutex
} }
type value struct { type value struct {
v int v uint32
r int64 q chan struct{}
w int64 l chan struct{}
l sync.RWMutex
} }
func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) { func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) {
m.l.Lock()
_, v := m.item(ctx, key) _, v := m.item(ctx, key)
if v.v < vers(ctx) {
m.l.Unlock() for {
select {
default:
if atomic.LoadUint32(&v.v) < vers(ctx) {
close(v.q)
panic(errRaceCondition) panic(errRaceCondition)
} }
atomic.AddInt64(&v.w, 1) case <-ctx.Done():
m.l.Unlock() return
v.l.Lock() case <-v.q:
return
case v.l <- struct{}{}:
atomic.StoreUint32(&v.v, vers(ctx))
return
}
} }
func (m *mutex) RLock(ctx context.Context, key fmt.Stringer) {
m.l.Lock()
_, v := m.item(ctx, key)
atomic.AddInt64(&v.r, 1)
m.l.Unlock()
v.l.RLock()
} }
func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) { func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) {
m.l.Lock()
defer m.l.Unlock() _, v := m.item(ctx, key)
k, v := m.item(ctx, key)
if w := atomic.LoadInt64(&v.w); w > 0 { select {
if w := atomic.AddInt64(&v.w, -1); w <= 0 { case <-ctx.Done():
m.m.Delete(k) return
} case <-v.q:
v.l.Unlock() return
} case <-v.l:
return
} }
func (m *mutex) RUnlock(ctx context.Context, key fmt.Stringer) {
m.l.Lock()
defer m.l.Unlock()
k, v := m.item(ctx, key)
if r := atomic.LoadInt64(&v.r); r > 0 {
if r := atomic.AddInt64(&v.r, -1); r <= 0 {
m.m.Delete(k)
}
v.l.RUnlock()
}
} }
func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) { func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) {
k := key.String() k := key.String()
v, _ := m.m.LoadOrStore(k, &value{v: vers(ctx)}) v, _ := m.m.LoadOrStore(k, &value{
v: vers(ctx),
q: make(chan struct{}),
l: make(chan struct{}, 1),
})
return k, v.(*value) return k, v.(*value)
} }

View file

@ -100,13 +100,11 @@ func TestMutex(t *testing.T) {
m := new(mutex) m := new(mutex)
ctx := context.Background() ctx := context.Background()
for i := 0; i < n; i++ {
m.Lock(ctx, new(stringer)) m.Lock(ctx, new(stringer))
ctx = dive(ctx) ctx = dive(ctx)
So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic) So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic)
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
}
So(nil, ShouldBeNil) So(nil, ShouldBeNil)
@ -191,7 +189,7 @@ func TestMutex(t *testing.T) {
}) })
Convey("Inability to update the same document in a SELECT subquery", t, func() { Convey("Ability to update the same document in a SELECT subquery", t, func() {
setupDB(20) setupDB(20)
@ -207,12 +205,12 @@ func TestMutex(t *testing.T) {
So(res, ShouldHaveLength, 4) So(res, ShouldHaveLength, 4)
So(res[1].Status, ShouldEqual, "OK") So(res[1].Status, ShouldEqual, "OK")
So(res[1].Result, ShouldHaveLength, 1) So(res[1].Result, ShouldHaveLength, 1)
So(res[2].Status, ShouldEqual, "ERR") So(res[2].Status, ShouldEqual, "OK")
So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively") So(res[2].Result, ShouldHaveLength, 1)
So(res[3].Status, ShouldEqual, "OK") So(res[3].Status, ShouldEqual, "OK")
So(res[3].Result, ShouldHaveLength, 1) So(res[3].Result, ShouldHaveLength, 1)
So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil) So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil)
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil) So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true)
}) })

View file

@ -100,7 +100,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -147,10 +147,6 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int
return nil, err return nil, err
} }
if err = d.rlock(ctx); err != nil {
return nil, err
}
if err = d.setup(ctx); err != nil { if err = d.setup(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -107,7 +107,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -86,7 +86,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int
return nil, err return nil, err
} }
if err = d.wlock(ctx); err != nil { if err = d.lock(ctx); err != nil {
return nil, err return nil, err
} }

View file

@ -75,7 +75,7 @@ var (
// maxRecursiveQueries specifies how many queries will be // maxRecursiveQueries specifies how many queries will be
// processed recursively before the query is cancelled. // processed recursively before the query is cancelled.
maxRecursiveQueries = 16 maxRecursiveQueries = uint32(16)
// queryIdentFailed occurs when a permission query asks // queryIdentFailed occurs when a permission query asks
// for a field, meaning a document has to be fetched. // for a field, meaning a document has to be fetched.