Move to channel based mutex
This commit is contained in:
parent
99d050b238
commit
f335d71aba
12 changed files with 64 additions and 95 deletions
|
@ -18,7 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
func vers(ctx context.Context) int {
|
func vers(ctx context.Context) uint32 {
|
||||||
|
|
||||||
v := ctx.Value(ctxKeyDive)
|
v := ctx.Value(ctxKeyDive)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ func vers(ctx context.Context) int {
|
||||||
case nil:
|
case nil:
|
||||||
return 0
|
return 0
|
||||||
default:
|
default:
|
||||||
return v.(int)
|
return v.(uint32)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -37,12 +37,12 @@ func dive(ctx context.Context) context.Context {
|
||||||
|
|
||||||
switch v {
|
switch v {
|
||||||
case nil:
|
case nil:
|
||||||
return context.WithValue(ctx, ctxKeyDive, 1)
|
return context.WithValue(ctx, ctxKeyDive, uint32(1))
|
||||||
default:
|
default:
|
||||||
if v.(int) > maxRecursiveQueries {
|
if v.(uint32) > maxRecursiveQueries {
|
||||||
panic(errRecursiveOverload)
|
panic(errRecursiveOverload)
|
||||||
}
|
}
|
||||||
return context.WithValue(ctx, ctxKeyDive, v.(int)+1)
|
return context.WithValue(ctx, ctxKeyDive, v.(uint32)+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,13 +36,10 @@ type document struct {
|
||||||
id *sql.Thing
|
id *sql.Thing
|
||||||
key *keys.Thing
|
key *keys.Thing
|
||||||
val kvs.KV
|
val kvs.KV
|
||||||
|
lck bool
|
||||||
doc *data.Doc
|
doc *data.Doc
|
||||||
initial *data.Doc
|
initial *data.Doc
|
||||||
current *data.Doc
|
current *data.Doc
|
||||||
locks struct {
|
|
||||||
r bool
|
|
||||||
w bool
|
|
||||||
}
|
|
||||||
store struct {
|
store struct {
|
||||||
id int
|
id int
|
||||||
tb bool
|
tb bool
|
||||||
|
@ -71,9 +68,7 @@ func newDocument(i *iterator, key *keys.Thing, val kvs.KV, doc *data.Doc) (d *do
|
||||||
d.key = key
|
d.key = key
|
||||||
d.val = val
|
d.val = val
|
||||||
d.doc = doc
|
d.doc = doc
|
||||||
|
d.lck = false
|
||||||
d.locks.r = false
|
|
||||||
d.locks.w = false
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -199,10 +194,10 @@ func (d *document) init(ctx context.Context) (err error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *document) wlock(ctx context.Context) (err error) {
|
func (d *document) lock(ctx context.Context) (err error) {
|
||||||
|
|
||||||
if d.key != nil {
|
if d.key != nil {
|
||||||
d.locks.w = true
|
d.lck = true
|
||||||
d.i.e.lock.Lock(ctx, d.key)
|
d.i.e.lock.Lock(ctx, d.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,29 +205,13 @@ func (d *document) wlock(ctx context.Context) (err error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *document) rlock(ctx context.Context) (err error) {
|
|
||||||
|
|
||||||
if d.key != nil {
|
|
||||||
d.locks.r = true
|
|
||||||
d.i.e.lock.RLock(ctx, d.key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *document) ulock(ctx context.Context) (err error) {
|
func (d *document) ulock(ctx context.Context) (err error) {
|
||||||
|
|
||||||
if d.key != nil && d.locks.w {
|
if d.key != nil && d.lck {
|
||||||
d.locks.w = false
|
d.lck = false
|
||||||
d.i.e.lock.Unlock(ctx, d.key)
|
d.i.e.lock.Unlock(ctx, d.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.key != nil && d.locks.r {
|
|
||||||
d.locks.r = false
|
|
||||||
d.i.e.lock.RUnlock(ctx, d.key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
70
db/mutex.go
70
db/mutex.go
|
@ -23,62 +23,58 @@ import (
|
||||||
|
|
||||||
type mutex struct {
|
type mutex struct {
|
||||||
m sync.Map
|
m sync.Map
|
||||||
l sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type value struct {
|
type value struct {
|
||||||
v int
|
v uint32
|
||||||
r int64
|
q chan struct{}
|
||||||
w int64
|
l chan struct{}
|
||||||
l sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) {
|
func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) {
|
||||||
m.l.Lock()
|
|
||||||
_, v := m.item(ctx, key)
|
_, v := m.item(ctx, key)
|
||||||
if v.v < vers(ctx) {
|
|
||||||
m.l.Unlock()
|
for {
|
||||||
|
select {
|
||||||
|
default:
|
||||||
|
if atomic.LoadUint32(&v.v) < vers(ctx) {
|
||||||
|
close(v.q)
|
||||||
panic(errRaceCondition)
|
panic(errRaceCondition)
|
||||||
}
|
}
|
||||||
atomic.AddInt64(&v.w, 1)
|
case <-ctx.Done():
|
||||||
m.l.Unlock()
|
return
|
||||||
v.l.Lock()
|
case <-v.q:
|
||||||
|
return
|
||||||
|
case v.l <- struct{}{}:
|
||||||
|
atomic.StoreUint32(&v.v, vers(ctx))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mutex) RLock(ctx context.Context, key fmt.Stringer) {
|
|
||||||
m.l.Lock()
|
|
||||||
_, v := m.item(ctx, key)
|
|
||||||
atomic.AddInt64(&v.r, 1)
|
|
||||||
m.l.Unlock()
|
|
||||||
v.l.RLock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) {
|
func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) {
|
||||||
m.l.Lock()
|
|
||||||
defer m.l.Unlock()
|
_, v := m.item(ctx, key)
|
||||||
k, v := m.item(ctx, key)
|
|
||||||
if w := atomic.LoadInt64(&v.w); w > 0 {
|
select {
|
||||||
if w := atomic.AddInt64(&v.w, -1); w <= 0 {
|
case <-ctx.Done():
|
||||||
m.m.Delete(k)
|
return
|
||||||
}
|
case <-v.q:
|
||||||
v.l.Unlock()
|
return
|
||||||
}
|
case <-v.l:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mutex) RUnlock(ctx context.Context, key fmt.Stringer) {
|
|
||||||
m.l.Lock()
|
|
||||||
defer m.l.Unlock()
|
|
||||||
k, v := m.item(ctx, key)
|
|
||||||
if r := atomic.LoadInt64(&v.r); r > 0 {
|
|
||||||
if r := atomic.AddInt64(&v.r, -1); r <= 0 {
|
|
||||||
m.m.Delete(k)
|
|
||||||
}
|
|
||||||
v.l.RUnlock()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) {
|
func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) {
|
||||||
k := key.String()
|
k := key.String()
|
||||||
v, _ := m.m.LoadOrStore(k, &value{v: vers(ctx)})
|
v, _ := m.m.LoadOrStore(k, &value{
|
||||||
|
v: vers(ctx),
|
||||||
|
q: make(chan struct{}),
|
||||||
|
l: make(chan struct{}, 1),
|
||||||
|
})
|
||||||
return k, v.(*value)
|
return k, v.(*value)
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,13 +100,11 @@ func TestMutex(t *testing.T) {
|
||||||
m := new(mutex)
|
m := new(mutex)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
m.Lock(ctx, new(stringer))
|
m.Lock(ctx, new(stringer))
|
||||||
ctx = dive(ctx)
|
ctx = dive(ctx)
|
||||||
So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic)
|
So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic)
|
||||||
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
||||||
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic)
|
||||||
}
|
|
||||||
|
|
||||||
So(nil, ShouldBeNil)
|
So(nil, ShouldBeNil)
|
||||||
|
|
||||||
|
@ -191,7 +189,7 @@ func TestMutex(t *testing.T) {
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("Inability to update the same document in a SELECT subquery", t, func() {
|
Convey("Ability to update the same document in a SELECT subquery", t, func() {
|
||||||
|
|
||||||
setupDB(20)
|
setupDB(20)
|
||||||
|
|
||||||
|
@ -207,12 +205,12 @@ func TestMutex(t *testing.T) {
|
||||||
So(res, ShouldHaveLength, 4)
|
So(res, ShouldHaveLength, 4)
|
||||||
So(res[1].Status, ShouldEqual, "OK")
|
So(res[1].Status, ShouldEqual, "OK")
|
||||||
So(res[1].Result, ShouldHaveLength, 1)
|
So(res[1].Result, ShouldHaveLength, 1)
|
||||||
So(res[2].Status, ShouldEqual, "ERR")
|
So(res[2].Status, ShouldEqual, "OK")
|
||||||
So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively")
|
So(res[2].Result, ShouldHaveLength, 1)
|
||||||
So(res[3].Status, ShouldEqual, "OK")
|
So(res[3].Status, ShouldEqual, "OK")
|
||||||
So(res[3].Result, ShouldHaveLength, 1)
|
So(res[3].Result, ShouldHaveLength, 1)
|
||||||
So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil)
|
So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil)
|
||||||
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil)
|
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true)
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -147,10 +147,6 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.rlock(ctx); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = d.setup(ctx); err != nil {
|
if err = d.setup(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,7 +107,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.wlock(ctx); err != nil {
|
if err = d.lock(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@ var (
|
||||||
|
|
||||||
// maxRecursiveQueries specifies how many queries will be
|
// maxRecursiveQueries specifies how many queries will be
|
||||||
// processed recursively before the query is cancelled.
|
// processed recursively before the query is cancelled.
|
||||||
maxRecursiveQueries = 16
|
maxRecursiveQueries = uint32(16)
|
||||||
|
|
||||||
// queryIdentFailed occurs when a permission query asks
|
// queryIdentFailed occurs when a permission query asks
|
||||||
// for a field, meaning a document has to be fetched.
|
// for a field, meaning a document has to be fetched.
|
||||||
|
|
Loading…
Reference in a new issue