Move to channel based mutex
This commit is contained in:
parent
99d050b238
commit
f335d71aba
12 changed files with 64 additions and 95 deletions
|
@ -18,7 +18,7 @@ import (
|
|||
"context"
|
||||
)
|
||||
|
||||
func vers(ctx context.Context) int {
|
||||
func vers(ctx context.Context) uint32 {
|
||||
|
||||
v := ctx.Value(ctxKeyDive)
|
||||
|
||||
|
@ -26,7 +26,7 @@ func vers(ctx context.Context) int {
|
|||
case nil:
|
||||
return 0
|
||||
default:
|
||||
return v.(int)
|
||||
return v.(uint32)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -37,12 +37,12 @@ func dive(ctx context.Context) context.Context {
|
|||
|
||||
switch v {
|
||||
case nil:
|
||||
return context.WithValue(ctx, ctxKeyDive, 1)
|
||||
return context.WithValue(ctx, ctxKeyDive, uint32(1))
|
||||
default:
|
||||
if v.(int) > maxRecursiveQueries {
|
||||
if v.(uint32) > maxRecursiveQueries {
|
||||
panic(errRecursiveOverload)
|
||||
}
|
||||
return context.WithValue(ctx, ctxKeyDive, v.(int)+1)
|
||||
return context.WithValue(ctx, ctxKeyDive, v.(uint32)+1)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -36,13 +36,10 @@ type document struct {
|
|||
id *sql.Thing
|
||||
key *keys.Thing
|
||||
val kvs.KV
|
||||
lck bool
|
||||
doc *data.Doc
|
||||
initial *data.Doc
|
||||
current *data.Doc
|
||||
locks struct {
|
||||
r bool
|
||||
w bool
|
||||
}
|
||||
store struct {
|
||||
id int
|
||||
tb bool
|
||||
|
@ -71,9 +68,7 @@ func newDocument(i *iterator, key *keys.Thing, val kvs.KV, doc *data.Doc) (d *do
|
|||
d.key = key
|
||||
d.val = val
|
||||
d.doc = doc
|
||||
|
||||
d.locks.r = false
|
||||
d.locks.w = false
|
||||
d.lck = false
|
||||
|
||||
return
|
||||
|
||||
|
@ -199,10 +194,10 @@ func (d *document) init(ctx context.Context) (err error) {
|
|||
|
||||
}
|
||||
|
||||
func (d *document) wlock(ctx context.Context) (err error) {
|
||||
func (d *document) lock(ctx context.Context) (err error) {
|
||||
|
||||
if d.key != nil {
|
||||
d.locks.w = true
|
||||
d.lck = true
|
||||
d.i.e.lock.Lock(ctx, d.key)
|
||||
}
|
||||
|
||||
|
@ -210,29 +205,13 @@ func (d *document) wlock(ctx context.Context) (err error) {
|
|||
|
||||
}
|
||||
|
||||
func (d *document) rlock(ctx context.Context) (err error) {
|
||||
|
||||
if d.key != nil {
|
||||
d.locks.r = true
|
||||
d.i.e.lock.RLock(ctx, d.key)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (d *document) ulock(ctx context.Context) (err error) {
|
||||
|
||||
if d.key != nil && d.locks.w {
|
||||
d.locks.w = false
|
||||
if d.key != nil && d.lck {
|
||||
d.lck = false
|
||||
d.i.e.lock.Unlock(ctx, d.key)
|
||||
}
|
||||
|
||||
if d.key != nil && d.locks.r {
|
||||
d.locks.r = false
|
||||
d.i.e.lock.RUnlock(ctx, d.key)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
70
db/mutex.go
70
db/mutex.go
|
@ -23,62 +23,58 @@ import (
|
|||
|
||||
type mutex struct {
|
||||
m sync.Map
|
||||
l sync.Mutex
|
||||
}
|
||||
|
||||
type value struct {
|
||||
v int
|
||||
r int64
|
||||
w int64
|
||||
l sync.RWMutex
|
||||
v uint32
|
||||
q chan struct{}
|
||||
l chan struct{}
|
||||
}
|
||||
|
||||
func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) {
|
||||
m.l.Lock()
|
||||
|
||||
_, v := m.item(ctx, key)
|
||||
if v.v < vers(ctx) {
|
||||
m.l.Unlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
if atomic.LoadUint32(&v.v) < vers(ctx) {
|
||||
close(v.q)
|
||||
panic(errRaceCondition)
|
||||
}
|
||||
atomic.AddInt64(&v.w, 1)
|
||||
m.l.Unlock()
|
||||
v.l.Lock()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-v.q:
|
||||
return
|
||||
case v.l <- struct{}{}:
|
||||
atomic.StoreUint32(&v.v, vers(ctx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mutex) RLock(ctx context.Context, key fmt.Stringer) {
|
||||
m.l.Lock()
|
||||
_, v := m.item(ctx, key)
|
||||
atomic.AddInt64(&v.r, 1)
|
||||
m.l.Unlock()
|
||||
v.l.RLock()
|
||||
}
|
||||
|
||||
func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
k, v := m.item(ctx, key)
|
||||
if w := atomic.LoadInt64(&v.w); w > 0 {
|
||||
if w := atomic.AddInt64(&v.w, -1); w <= 0 {
|
||||
m.m.Delete(k)
|
||||
}
|
||||
v.l.Unlock()
|
||||
}
|
||||
|
||||
_, v := m.item(ctx, key)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-v.q:
|
||||
return
|
||||
case <-v.l:
|
||||
return
|
||||
}
|
||||
|
||||
func (m *mutex) RUnlock(ctx context.Context, key fmt.Stringer) {
|
||||
m.l.Lock()
|
||||
defer m.l.Unlock()
|
||||
k, v := m.item(ctx, key)
|
||||
if r := atomic.LoadInt64(&v.r); r > 0 {
|
||||
if r := atomic.AddInt64(&v.r, -1); r <= 0 {
|
||||
m.m.Delete(k)
|
||||
}
|
||||
v.l.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) {
|
||||
k := key.String()
|
||||
v, _ := m.m.LoadOrStore(k, &value{v: vers(ctx)})
|
||||
v, _ := m.m.LoadOrStore(k, &value{
|
||||
v: vers(ctx),
|
||||
q: make(chan struct{}),
|
||||
l: make(chan struct{}, 1),
|
||||
})
|
||||
return k, v.(*value)
|
||||
}
|
||||
|
|
|
@ -100,13 +100,11 @@ func TestMutex(t *testing.T) {
|
|||
m := new(mutex)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
m.Lock(ctx, new(stringer))
|
||||
ctx = dive(ctx)
|
||||
So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic)
|
||||
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
||||
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
||||
}
|
||||
|
||||
So(nil, ShouldBeNil)
|
||||
|
||||
|
@ -191,7 +189,7 @@ func TestMutex(t *testing.T) {
|
|||
|
||||
})
|
||||
|
||||
Convey("Inability to update the same document in a SELECT subquery", t, func() {
|
||||
Convey("Ability to update the same document in a SELECT subquery", t, func() {
|
||||
|
||||
setupDB(20)
|
||||
|
||||
|
@ -207,12 +205,12 @@ func TestMutex(t *testing.T) {
|
|||
So(res, ShouldHaveLength, 4)
|
||||
So(res[1].Status, ShouldEqual, "OK")
|
||||
So(res[1].Result, ShouldHaveLength, 1)
|
||||
So(res[2].Status, ShouldEqual, "ERR")
|
||||
So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively")
|
||||
So(res[2].Status, ShouldEqual, "OK")
|
||||
So(res[2].Result, ShouldHaveLength, 1)
|
||||
So(res[3].Status, ShouldEqual, "OK")
|
||||
So(res[3].Result, ShouldHaveLength, 1)
|
||||
So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil)
|
||||
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil)
|
||||
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true)
|
||||
|
||||
})
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -147,10 +147,6 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.rlock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.setup(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = d.wlock(ctx); err != nil {
|
||||
if err = d.lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ var (
|
|||
|
||||
// maxRecursiveQueries specifies how many queries will be
|
||||
// processed recursively before the query is cancelled.
|
||||
maxRecursiveQueries = 16
|
||||
maxRecursiveQueries = uint32(16)
|
||||
|
||||
// queryIdentFailed occurs when a permission query asks
|
||||
// for a field, meaning a document has to be fetched.
|
||||
|
|
Loading…
Reference in a new issue