diff --git a/db/create.go b/db/create.go index a27c34f5..fae7b153 100644 --- a/db/create.go +++ b/db/create.go @@ -104,9 +104,15 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int var err error var met = _CREATE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -124,11 +130,11 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int return nil, nil } - if err = d.storeIndex(); err != nil { + if err = d.storeIndex(ctx); err != nil { return nil, err } - if err = d.storeThing(); err != nil { + if err = d.storeThing(ctx); err != nil { return nil, err } diff --git a/db/db_test.go b/db/db_test.go index 3e0d1f18..a50fb423 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -22,14 +22,19 @@ import ( "github.com/abcum/surreal/util/data" ) -func setupDB() { +func setupDB(workers ...int) { cnf.Settings = &cnf.Options{} cnf.Settings.DB.Path = "memory" cnf.Settings.DB.Base = "*" cnf.Settings.DB.Proc.Size = 5 - workerCount = 1 + switch len(workers) { + default: + workerCount = workers[0] + case 0: + workerCount = 1 + } Setup(cnf.Settings) diff --git a/db/define_test.go b/db/define_test.go index 9ae48316..a673c90b 100644 --- a/db/define_test.go +++ b/db/define_test.go @@ -26,7 +26,7 @@ func TestDefine(t *testing.T) { Convey("Define a namespace", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -42,7 +42,7 @@ func TestDefine(t *testing.T) { Convey("Define a database", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -58,7 +58,7 @@ func TestDefine(t *testing.T) { Convey("Define a scope", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -91,7 +91,7 @@ func TestDefine(t *testing.T) { Convey("Define a schemaless table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -110,7 +110,7 @@ func TestDefine(t *testing.T) { Convey("Define a schemafull table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -135,7 +135,7 @@ func TestDefine(t *testing.T) { Convey("Define a schemafull table with nil values", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -160,7 +160,7 @@ func TestDefine(t *testing.T) { Convey("Define a schemafull table with nested records", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -185,7 +185,7 @@ func TestDefine(t *testing.T) { Convey("Define a schemafull table with nested set records", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -214,7 +214,7 @@ func TestDefine(t *testing.T) { Convey("Convert a schemaless to schemafull table, and ensure schemaless fields are still output", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -241,7 +241,7 @@ func TestDefine(t *testing.T) { Convey("Define a drop table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -260,7 +260,7 @@ func TestDefine(t *testing.T) { Convey("Define a foreign table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -289,7 +289,7 @@ func TestDefine(t *testing.T) { Convey("Define a table with permission specified so only specified records are visible", t, func() { - setupDB() + setupDB(20) func() { @@ -327,7 +327,7 @@ func TestDefine(t *testing.T) { Convey("Assert the value of a field", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -350,7 +350,7 @@ func TestDefine(t *testing.T) { Convey("Assert the value of a field if it has been set", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -373,7 +373,7 @@ func TestDefine(t *testing.T) { Convey("Specify the priority of a field so that it is processed after any dependent fields", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -401,7 +401,7 @@ func TestDefine(t *testing.T) { Convey("Specify the permissions of a field so that it is only visible to the correct authentication levels", t, func() { - setupDB() + setupDB(20) func() { @@ -496,7 +496,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value changes", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -523,7 +523,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value increases", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -550,7 +550,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value increases beyond a threshold", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -577,7 +577,7 @@ func TestDefine(t *testing.T) { Convey("Define an event for both CREATE and UPDATE events separately", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -607,7 +607,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value changes and set a foreign key on another table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -627,7 +627,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value changes and update a foreign key array on another table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -652,7 +652,7 @@ func TestDefine(t *testing.T) { Convey("Define an event when a value changes and update and delete from a foreign key array on another table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -665,24 +665,26 @@ func TestDefine(t *testing.T) { ); UPDATE person:one SET fk = other:test; UPDATE person:two SET fk = other:test; + UPDATE person:tre SET fk = other:test; DELETE FROM person; SELECT * FROM other; ` res, err := Execute(setupKV(), txt, nil) So(err, ShouldBeNil) - So(res, ShouldHaveLength, 6) + So(res, ShouldHaveLength, 7) So(res[2].Result, ShouldHaveLength, 1) So(res[3].Result, ShouldHaveLength, 1) - So(res[4].Result, ShouldHaveLength, 0) - So(res[5].Result, ShouldHaveLength, 1) - So(data.Consume(res[5].Result[0]).Get("fks").Data(), ShouldHaveLength, 0) + So(res[4].Result, ShouldHaveLength, 1) + So(res[5].Result, ShouldHaveLength, 0) + So(res[6].Result, ShouldHaveLength, 1) + So(data.Consume(res[6].Result[0]).Get("fks").Data(), ShouldHaveLength, 0) }) Convey("Define an event on a table, and ensure it is not output with records", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -699,7 +701,7 @@ func TestDefine(t *testing.T) { Convey("Define an field on a table, and ensure it is not output with records", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -716,7 +718,7 @@ func TestDefine(t *testing.T) { Convey("Define an index on a table, and ensure it is not output with records", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -733,7 +735,7 @@ func TestDefine(t *testing.T) { Convey("Define an index on a table, and ensure it allows duplicate record values", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -769,7 +771,7 @@ func TestDefine(t *testing.T) { Convey("Define a unique index on a table, and ensure it prevents duplicate record values", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -803,7 +805,7 @@ func TestDefine(t *testing.T) { Convey("Redefine a unique index on a table, and ensure it prevents duplicate record values", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; diff --git a/db/delete.go b/db/delete.go index 5a9be43c..12646701 100644 --- a/db/delete.go +++ b/db/delete.go @@ -103,9 +103,15 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int var err error var met = _DELETE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -129,16 +135,16 @@ func (d *document) runDelete(ctx context.Context, stm *sql.DeleteStatement) (int return nil, err } - if err = d.purgeIndex(); err != nil { + if err = d.purgeIndex(ctx); err != nil { return nil, err } if stm.Hard { - if err = d.eraseThing(); err != nil { + if err = d.eraseThing(ctx); err != nil { return nil, err } } else { - if err = d.purgeThing(); err != nil { + if err = d.purgeThing(ctx); err != nil { return nil, err } } diff --git a/db/document.go b/db/document.go index 2b1fda2d..51ff52e7 100644 --- a/db/document.go +++ b/db/document.go @@ -15,6 +15,7 @@ package db import ( + "fmt" "context" @@ -38,7 +39,11 @@ type document struct { doc *data.Doc initial *data.Doc current *data.Doc - store struct { + locks struct { + r bool + w bool + } + store struct { id int tb bool ev bool @@ -67,6 +72,8 @@ func newDocument(i *iterator, key *keys.Thing, val kvs.KV, doc *data.Doc) (d *do d.val = val d.doc = doc + d.locks.r = false + d.locks.w = false return @@ -138,7 +145,22 @@ func (d *document) getLV() (out []*sql.LiveStatement, err error) { return d.cache.lv, err } -func (d *document) query(ctx context.Context, stm sql.Statement) (interface{}, error) { +func (d *document) query(ctx context.Context, stm sql.Statement) (val interface{}, err error) { + + defer func() { + + if r := recover(); r != nil { + var ok bool + if err, ok = r.(error); !ok { + err = fmt.Errorf("%v", r) + } + } + + d.ulock(ctx) + + d.close() + + }() switch stm := stm.(type) { default: @@ -161,7 +183,7 @@ func (d *document) query(ctx context.Context, stm sql.Statement) (interface{}, e } -func (d *document) setup() (err error) { +func (d *document) init(ctx context.Context) (err error) { // A table of records were requested // so we have the values, but no key @@ -173,6 +195,50 @@ func (d *document) setup() (err error) { d.key.Decode(d.val.Key()) } + return + +} + +func (d *document) wlock(ctx context.Context) (err error) { + + if d.key != nil { + d.locks.w = true + d.i.e.lock.Lock(ctx, d.key) + } + + return + +} + +func (d *document) rlock(ctx context.Context) (err error) { + + if d.key != nil { + d.locks.r = true + d.i.e.lock.RLock(ctx, d.key) + } + + return + +} + +func (d *document) ulock(ctx context.Context) (err error) { + + if d.key != nil && d.locks.w { + d.locks.w = false + d.i.e.lock.Unlock(ctx, d.key) + } + + if d.key != nil && d.locks.r { + d.locks.r = false + d.i.e.lock.RUnlock(ctx, d.key) + } + + return + +} + +func (d *document) setup(ctx context.Context) (err error) { + // A specific record has been requested // and we have a key, but no value has // been loaded yet, so the record needs @@ -294,7 +360,9 @@ func (d *document) shouldDrop() (bool, error) { } -func (d *document) storeThing() (err error) { +func (d *document) storeThing(ctx context.Context) (err error) { + + defer d.ulock(ctx) // Check that the table should // drop data being written. @@ -319,7 +387,9 @@ func (d *document) storeThing() (err error) { } -func (d *document) purgeThing() (err error) { +func (d *document) purgeThing(ctx context.Context) (err error) { + + defer d.ulock(ctx) // Check that the table should // drop data being written. @@ -337,7 +407,9 @@ func (d *document) purgeThing() (err error) { } -func (d *document) eraseThing() (err error) { +func (d *document) eraseThing(ctx context.Context) (err error) { + + defer d.ulock(ctx) // Check that the table should // drop data being written. @@ -355,7 +427,7 @@ func (d *document) eraseThing() (err error) { } -func (d *document) storeIndex() (err error) { +func (d *document) storeIndex(ctx context.Context) (err error) { // Check that the table should // drop data being written. @@ -426,7 +498,7 @@ func (d *document) storeIndex() (err error) { } -func (d *document) purgeIndex() (err error) { +func (d *document) purgeIndex(ctx context.Context) (err error) { // Check that the table should // drop data being written. diff --git a/db/executor.go b/db/executor.go index 92c89518..0b75e707 100644 --- a/db/executor.go +++ b/db/executor.go @@ -30,6 +30,7 @@ import ( type executor struct { dbo *mem.Cache time int64 + lock *mutex send chan *Response } @@ -139,6 +140,7 @@ func (e *executor) execute(ctx context.Context, ast *sql.Query) { switch stm.(type) { case *sql.BeginStatement: + e.lock = new(mutex) err = e.begin(ctx, true) continue case *sql.CancelStatement: @@ -239,6 +241,12 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf defer e.dbo.Cancel() + // Let's create a new mutex for just this + // local transaction, so we can track any + // recursive queries and race errors. + + e.lock = new(mutex) + } // Mark the beginning of this statement so we @@ -421,6 +429,7 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf func (e *executor) begin(ctx context.Context, rw bool) (err error) { if e.dbo.TX == nil { + e.dbo = mem.New() e.dbo.TX, err = db.Begin(ctx, rw) } return diff --git a/db/fetch.go b/db/fetch.go index c81e5b60..fd476126 100644 --- a/db/fetch.go +++ b/db/fetch.go @@ -379,6 +379,27 @@ func (e *executor) fetchArray(ctx context.Context, val []interface{}, doc *data. } +func (e *executor) fetchPerms(ctx context.Context, val sql.Expr, tb *sql.Ident) error { + + res, err := e.fetch(ctx, val, ign) + + // If we receive an 'ident failed' error + // it is because the table permission + // expression contains a field check, + // and therefore we must check each + // record individually to see if it can + // be accessed or not. + + if err != queryIdentFailed { + if res, ok := res.(bool); ok && !res { + return &PermsError{table: tb.ID} + } + } + + return nil + +} + func (e *executor) fetchLimit(ctx context.Context, val sql.Expr) (int, error) { v, err := e.fetch(ctx, val, nil) diff --git a/db/insert.go b/db/insert.go index e5064967..d05a1625 100644 --- a/db/insert.go +++ b/db/insert.go @@ -82,9 +82,15 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int var err error var met = _CREATE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -102,11 +108,11 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int return nil, nil } - if err = d.storeIndex(); err != nil { + if err = d.storeIndex(ctx); err != nil { return nil, err } - if err = d.storeThing(); err != nil { + if err = d.storeThing(ctx); err != nil { return nil, err } diff --git a/db/iterator.go b/db/iterator.go index 338aac38..3026ab48 100644 --- a/db/iterator.go +++ b/db/iterator.go @@ -45,6 +45,12 @@ type iterator struct { stm sql.Statement res []interface{} + wait sync.WaitGroup + fail chan error + stop chan struct{} + jobs chan *workable + vals chan *doneable + expr sql.Fields what sql.Exprs cond sql.Expr @@ -54,14 +60,6 @@ type iterator struct { start int versn int64 tasks int - - wait sync.WaitGroup - fail chan error - full chan struct{} - stop chan struct{} - done chan struct{} - jobs chan *workable - vals chan *doneable } type workable struct { @@ -99,8 +97,7 @@ func newIterator(e *executor, ctx context.Context, stm sql.Statement, vir bool) i.res = make([]interface{}, 0) i.wait = sync.WaitGroup{} - i.fail = make(chan error) - i.full = make(chan struct{}) + i.fail = make(chan error, 1) i.stop = make(chan struct{}) i.jobs = make(chan *workable, 1000) i.vals = make(chan *doneable, 1000) @@ -115,7 +112,7 @@ func newIterator(e *executor, ctx context.Context, stm sql.Statement, vir bool) // Comment here ... - i.checkWorker(ctx) + i.watchVals(ctx) return @@ -128,6 +125,11 @@ func (i *iterator) Close() { i.stm = nil i.res = nil + i.fail = nil + i.stop = nil + i.jobs = nil + i.vals = nil + i.expr = nil i.what = nil i.cond = nil @@ -222,8 +224,6 @@ func (i *iterator) checkState(ctx context.Context) bool { return false case <-i.stop: return false - case <-i.full: - return false default: return true } @@ -269,35 +269,17 @@ func (i *iterator) submitTask(key *keys.Thing, val kvs.KV, doc *data.Doc) { } -func (i *iterator) checkWorker(ctx context.Context) { +func (i *iterator) watchVals(ctx context.Context) { - go func(fail chan error) { - for err := range fail { - i.receivedError(err) - } - }(i.fail) - - go func(vals chan *doneable) { + go func(vals <-chan *doneable) { for val := range vals { - i.receivedResult(val) + i.receive(val) } }(i.vals) } -func (i *iterator) receivedError(err error) { - - select { - case <-i.stop: - return - default: - i.err = err - close(i.stop) - } - -} - -func (i *iterator) receivedResult(val *doneable) { +func (i *iterator) receive(val *doneable) { defer i.wait.Done() @@ -310,8 +292,9 @@ func (i *iterator) receivedResult(val *doneable) { case <-i.stop: return default: - i.err = val.err + i.fail <- val.err close(i.stop) + return } } @@ -356,16 +339,16 @@ func (i *iterator) receivedResult(val *doneable) { // query statement. select { - case <-i.full: + case <-i.stop: return default: if i.start >= 0 { if len(i.res) == i.limit+i.start { - close(i.full) + close(i.stop) } } else { if len(i.res) == i.limit { - close(i.full) + close(i.stop) } } } @@ -374,8 +357,6 @@ func (i *iterator) receivedResult(val *doneable) { func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { - var err error - var tb *sql.DefineTableStatement // If we are authenticated using DB, NS, @@ -407,35 +388,39 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { // we need to fetch the table to ensure // that the table is not a view table. - tb, err = i.e.dbo.AddTB(nsv, dbv, tbv) - if err != nil { - i.fail <- err + tb, i.err = i.e.dbo.AddTB(nsv, dbv, tbv) + if i.err != nil { + close(i.stop) return } // If the table is locked (because it // has been specified as a view), then // check to see what query type it is - // and return an error, if it attempts + // and return an error if it attempts // to alter the table in any way. if tb.Lock && i.vir == false { switch i.stm.(type) { case *sql.CreateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.UpdateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.DeleteStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.RelateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.InsertStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.UpsertStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} } } + if i.err != nil { + close(i.stop) + } + return } @@ -444,9 +429,9 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { // otherwise, the scoped authentication // request can not do anything. - _, err = i.e.dbo.GetNS(nsv) - if err != nil { - i.fail <- err + _, i.err = i.e.dbo.GetNS(nsv) + if i.err != nil { + close(i.stop) return } @@ -454,9 +439,9 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { // otherwise, the scoped authentication // request can not do anything. - _, err = i.e.dbo.GetDB(nsv, dbv) - if err != nil { - i.fail <- err + _, i.err = i.e.dbo.GetDB(nsv, dbv) + if i.err != nil { + close(i.stop) return } @@ -473,9 +458,9 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { // otherwise, the scoped authentication // request can not do anything. - tb, err = i.e.dbo.GetTB(nsv, dbv, tbv) - if err != nil { - i.fail <- err + tb, i.err = i.e.dbo.GetTB(nsv, dbv, tbv) + if i.err != nil { + close(i.stop) return } @@ -488,20 +473,25 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { if tb.Lock && i.vir == false { switch i.stm.(type) { case *sql.CreateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.UpdateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.DeleteStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.RelateStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.InsertStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} case *sql.UpsertStatement: - i.fail <- &TableError{table: tb.Name.ID} + i.err = &TableError{table: tb.Name.ID} } } + if i.err != nil { + close(i.stop) + return + } + // If the table does exist we reset the // context to DB level so that no other // embedded permissions are checked on @@ -514,43 +504,31 @@ func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) { // expression, but only if they don't // reference any document fields. - var val interface{} - switch p := tb.Perms.(type) { + default: + i.err = &PermsError{table: tb.Name.ID} case *sql.PermExpression: switch i.stm.(type) { case *sql.SelectStatement: - val, err = i.e.fetch(ctx, p.Select, ign) + i.err = i.e.fetchPerms(ctx, p.Select, tb.Name) case *sql.CreateStatement: - val, err = i.e.fetch(ctx, p.Create, ign) + i.err = i.e.fetchPerms(ctx, p.Create, tb.Name) case *sql.UpdateStatement: - val, err = i.e.fetch(ctx, p.Update, ign) + i.err = i.e.fetchPerms(ctx, p.Update, tb.Name) case *sql.DeleteStatement: - val, err = i.e.fetch(ctx, p.Delete, ign) + i.err = i.e.fetchPerms(ctx, p.Delete, tb.Name) case *sql.RelateStatement: - val, err = i.e.fetch(ctx, p.Create, ign) + i.err = i.e.fetchPerms(ctx, p.Create, tb.Name) case *sql.InsertStatement: - val, err = i.e.fetch(ctx, p.Create, ign) + i.err = i.e.fetchPerms(ctx, p.Create, tb.Name) case *sql.UpsertStatement: - val, err = i.e.fetch(ctx, p.Update, ign) + i.err = i.e.fetchPerms(ctx, p.Update, tb.Name) } - default: - i.fail <- &PermsError{table: tb.Name.ID} - return } - // If we receive an 'ident failed' error - // it is because the table permission - // expression contains a field check, - // and therefore we must check each - // record individually to see if it can - // be accessed or not. - - if err != queryIdentFailed { - if val, ok := val.(bool); ok && !val { - i.fail <- &PermsError{table: tb.Name.ID} - return - } + if i.err != nil { + close(i.stop) + return } return @@ -585,13 +563,15 @@ func (i *iterator) processTable(ctx context.Context, key *keys.Table) { for x := 0; ; x = 1 { + var vals []kvs.KV + if !i.checkState(ctx) { return } - vals, err := i.e.dbo.GetR(i.versn, min, max, 10000) - if err != nil { - i.fail <- err + vals, i.err = i.e.dbo.GetR(i.versn, min, max, 10000) + if i.err != nil { + close(i.stop) return } @@ -840,14 +820,22 @@ func (i *iterator) Yield(ctx context.Context) (out []interface{}, err error) { defer i.Close() i.wait.Wait() + close(i.jobs) - close(i.fail) close(i.vals) if i.err != nil { return nil, i.err } + if i.err == nil { + select { + default: + case i.err = <-i.fail: + return nil, i.err + } + } + if len(i.group) > 0 { i.res = i.Group(ctx, i.res) } diff --git a/db/live.go b/db/live.go index 258bee2b..87714e12 100644 --- a/db/live.go +++ b/db/live.go @@ -23,7 +23,7 @@ import ( "github.com/abcum/surreal/sql" ) -var lock sync.Mutex +var locker sync.Mutex var sockets map[string]*socket @@ -34,8 +34,8 @@ func init() { func register(fib *fibre.Context, id string) func() { return func() { - lock.Lock() - defer lock.Unlock() + locker.Lock() + defer locker.Unlock() sockets[id] = &socket{ fibre: fib, @@ -49,8 +49,8 @@ func register(fib *fibre.Context, id string) func() { func deregister(fib *fibre.Context, id string) func() { return func() { - lock.Lock() - defer lock.Unlock() + locker.Lock() + defer locker.Unlock() if sck, ok := sockets[id]; ok { sck.deregister(id) diff --git a/db/mutex.go b/db/mutex.go new file mode 100644 index 00000000..abcab734 --- /dev/null +++ b/db/mutex.go @@ -0,0 +1,84 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" + "fmt" + "sync" + "sync/atomic" +) + +type mutex struct { + m sync.Map + l sync.Mutex +} + +type value struct { + v int + r int64 + w int64 + l sync.RWMutex +} + +func (m *mutex) Lock(ctx context.Context, key fmt.Stringer) { + m.l.Lock() + _, v := m.item(ctx, key) + if v.v < vers(ctx) { + m.l.Unlock() + panic(errRaceCondition) + } + atomic.AddInt64(&v.w, 1) + m.l.Unlock() + v.l.Lock() +} + +func (m *mutex) RLock(ctx context.Context, key fmt.Stringer) { + m.l.Lock() + _, v := m.item(ctx, key) + atomic.AddInt64(&v.r, 1) + m.l.Unlock() + v.l.RLock() +} + +func (m *mutex) Unlock(ctx context.Context, key fmt.Stringer) { + m.l.Lock() + defer m.l.Unlock() + k, v := m.item(ctx, key) + if w := atomic.LoadInt64(&v.w); w > 0 { + if w := atomic.AddInt64(&v.w, -1); w <= 0 { + m.m.Delete(k) + } + v.l.Unlock() + } +} + +func (m *mutex) RUnlock(ctx context.Context, key fmt.Stringer) { + m.l.Lock() + defer m.l.Unlock() + k, v := m.item(ctx, key) + if r := atomic.LoadInt64(&v.r); r > 0 { + if r := atomic.AddInt64(&v.r, -1); r <= 0 { + m.m.Delete(k) + } + v.l.RUnlock() + } +} + +func (m *mutex) item(ctx context.Context, key fmt.Stringer) (string, *value) { + k := key.String() + v, _ := m.m.LoadOrStore(k, &value{v: vers(ctx)}) + return k, v.(*value) +} diff --git a/db/mutex_test.go b/db/mutex_test.go new file mode 100644 index 00000000..5cf9513f --- /dev/null +++ b/db/mutex_test.go @@ -0,0 +1,388 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" + "sync" + "testing" + + "github.com/abcum/surreal/util/data" + . "github.com/smartystreets/goconvey/convey" +) + +type stringer struct{} + +func (this stringer) String() string { + return "test" +} + +func TestMutex(t *testing.T) { + + var n = 10 + + Convey("Context diving works correctly", t, func() { + + ctx := context.Background() + + So(vers(ctx), ShouldEqual, 0) + + for i := vers(ctx); i <= maxRecursiveQueries; i++ { + So(func() { ctx = dive(ctx) }, ShouldNotPanic) + So(vers(ctx), ShouldEqual, i+1) + } + + So(func() { dive(ctx) }, ShouldPanicWith, errRecursiveOverload) + + }) + + Convey("Allow basic mutex", t, func() { + + m := new(mutex) + ctx := context.Background() + + m.Lock(ctx, new(stringer)) + m.Unlock(ctx, new(stringer)) + + }) + + Convey("Allow concurrent mutex", t, func() { + + m := new(mutex) + wg := new(sync.WaitGroup) + ctx := context.Background() + + wg.Add(n) + + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + m.Lock(ctx, new(stringer)) + m.Unlock(ctx, new(stringer)) + }() + } + + wg.Wait() + + So(nil, ShouldBeNil) + + }) + + Convey("Allow fixed-level mutex", t, func() { + + m := new(mutex) + ctx := context.Background() + + for i := 0; i < n; i++ { + ctx = dive(ctx) + So(func() { m.Lock(ctx, new(stringer)) }, ShouldNotPanic) + So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) + } + + So(nil, ShouldBeNil) + + }) + + Convey("Prevent nested-recursive mutex", t, func() { + + m := new(mutex) + ctx := context.Background() + + for i := 0; i < n; i++ { + m.Lock(ctx, new(stringer)) + ctx = dive(ctx) + So(func() { m.Lock(ctx, new(stringer)) }, ShouldPanic) + So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) + So(func() { m.Unlock(ctx, new(stringer)) }, ShouldNotPanic) + } + + So(nil, ShouldBeNil) + + }) + + Convey("Ensure document locking when multiple events attempt to write to the same document", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + DEFINE EVENT created ON person WHEN $method = "CREATE" THEN (UPDATE $after.fk SET fks += $this); + DEFINE EVENT deleted ON person WHEN $method = "DELETE" THEN (UPDATE $before.fk SET fks -= $this); + UPDATE |person:1..100| SET fk = other:test; + SELECT * FROM other; + DELETE FROM person; + SELECT * FROM other; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 7) + So(res[1].Status, ShouldEqual, "OK") + So(res[2].Status, ShouldEqual, "OK") + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 100) + So(res[4].Status, ShouldEqual, "OK") + So(res[4].Result, ShouldHaveLength, 1) + So(data.Consume(res[4].Result[0]).Get("fks").Data(), ShouldHaveLength, 100) + So(res[5].Status, ShouldEqual, "OK") + So(res[5].Result, ShouldHaveLength, 0) + So(res[6].Status, ShouldEqual, "OK") + So(res[6].Result, ShouldHaveLength, 1) + So(data.Consume(res[6].Result[0]).Get("fks").Data(), ShouldHaveLength, 0) + + }) + + Convey("Ability to select the same document in a SELECT subquery", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + CREATE person:test; + SELECT * FROM (SELECT * FROM (SELECT * FROM person)); + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Status, ShouldEqual, "OK") + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Status, ShouldEqual, "OK") + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 1) + + }) + + Convey("Ability to update the same document in a SELECT subquery", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + CREATE person:test; + SELECT * FROM (UPDATE person SET test=true); + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Status, ShouldEqual, "OK") + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Status, ShouldEqual, "OK") + So(data.Consume(res[2].Result[0]).Get("temp").Data(), ShouldBeNil) + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 1) + So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true) + + }) + + Convey("Inability to update the same document in a SELECT subquery", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + CREATE person:test; + SELECT *, (UPDATE $parent.id SET test=true) AS test FROM person; + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Status, ShouldEqual, "OK") + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Status, ShouldEqual, "ERR") + So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively") + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 1) + So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil) + + }) + + Convey("Inability to update the same document in an UPDATE subquery", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + CREATE person:test; + UPDATE person SET temp = (UPDATE person SET test=true); + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Status, ShouldEqual, "OK") + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Status, ShouldEqual, "ERR") + So(res[2].Detail, ShouldEqual, "Failed to update the same document recursively") + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 1) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldBeNil) + So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldBeNil) + + }) + + Convey("Ability to update the same document in an event", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + DEFINE EVENT test ON person WHEN $before.test != $after.test THEN (UPDATE $this SET temp = true); + UPDATE person:test SET test=true; + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Status, ShouldEqual, "OK") + So(res[2].Status, ShouldEqual, "OK") + So(data.Consume(res[2].Result[0]).Get("temp").Data(), ShouldBeNil) + So(data.Consume(res[2].Result[0]).Get("test").Data(), ShouldEqual, true) + So(res[3].Status, ShouldEqual, "OK") + So(data.Consume(res[3].Result[0]).Get("temp").Data(), ShouldEqual, true) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, true) + + }) + + Convey("Subqueries for an event should be on the same level", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + DEFINE EVENT test ON person WHEN $method = "CREATE" THEN (CREATE tester); + CREATE |person:100|; + SELECT * FROM person; + SELECT * FROM tester; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 5) + So(res[1].Status, ShouldEqual, "OK") + So(res[2].Status, ShouldEqual, "OK") + So(res[2].Result, ShouldHaveLength, 100) + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 100) + So(res[4].Status, ShouldEqual, "OK") + So(res[4].Result, ShouldHaveLength, 100) + + }) + + Convey("Subqueries for an event on a different level create an infinite loop", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + DEFINE EVENT test ON person WHEN $method = "CREATE" THEN (CREATE person); + CREATE person:test; + SELECT * FROM person; + SELECT * FROM tester; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 5) + So(res[1].Status, ShouldEqual, "OK") + So(res[2].Status, ShouldEqual, "ERR") + So(res[2].Detail, ShouldEqual, "Infinite loop when running recursive subqueries") + So(res[3].Status, ShouldEqual, "OK") + So(res[3].Result, ShouldHaveLength, 0) + So(res[4].Status, ShouldEqual, "OK") + So(res[4].Result, ShouldHaveLength, 0) + + }) + + Convey("Subqueries for recursive events on a different level create an infinite loop", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + DEFINE EVENT test ON person WHEN $method = "UPDATE" THEN (UPDATE tester SET temp=time.now()); + DEFINE EVENT test ON tester WHEN $method = "UPDATE" THEN (UPDATE person SET temp=time.now()); + CREATE person:test, tester:test SET temp=time.now(); + UPDATE person:test SET temp=time.now(); + SELECT * FROM person; + SELECT * FROM tester; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 7) + So(res[1].Status, ShouldEqual, "OK") + So(res[2].Status, ShouldEqual, "OK") + So(res[3].Status, ShouldEqual, "OK") + So(res[4].Status, ShouldEqual, "ERR") + So(res[4].Detail, ShouldEqual, "Infinite loop when running recursive subqueries") + So(res[5].Status, ShouldEqual, "OK") + So(res[5].Result, ShouldHaveLength, 1) + So(res[6].Status, ShouldEqual, "OK") + So(res[6].Result, ShouldHaveLength, 1) + + }) + + Convey("Ability to define complex dependent events which should run consecutively and succeed", t, func() { + + setupDB(20) + + txt := ` + USE NS test DB test; + CREATE global:test SET tests=[], temps=[]; + DEFINE EVENT test ON tester WHEN $after.global != EMPTY THEN ( + UPDATE $after.global SET tests+=$this; + UPDATE temper SET tester=$this, global=$after.global; + ); + DEFINE EVENT test ON temper WHEN $after.global != EMPTY THEN ( + UPDATE $after.global SET temps+=$this; + ); + CREATE |temper:1..5|; + CREATE tester:test SET global=global:test; + SELECT * FROM global; + SELECT * FROM tester; + SELECT * FROM temper; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 9) + So(res[1].Status, ShouldEqual, "OK") + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Status, ShouldEqual, "OK") + So(res[3].Status, ShouldEqual, "OK") + So(res[4].Status, ShouldEqual, "OK") + So(res[4].Result, ShouldHaveLength, 5) + So(res[5].Status, ShouldEqual, "OK") + So(res[5].Result, ShouldHaveLength, 1) + So(res[6].Status, ShouldEqual, "OK") + So(res[6].Result, ShouldHaveLength, 1) + So(res[7].Status, ShouldEqual, "OK") + So(res[7].Result, ShouldHaveLength, 1) + So(res[8].Status, ShouldEqual, "OK") + So(res[8].Result, ShouldHaveLength, 5) + + }) + +} diff --git a/db/relate.go b/db/relate.go index ef8442aa..c7575c76 100644 --- a/db/relate.go +++ b/db/relate.go @@ -96,9 +96,15 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int var err error var met = _CREATE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -116,11 +122,11 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int return nil, nil } - if err = d.storeIndex(); err != nil { + if err = d.storeIndex(ctx); err != nil { return nil, err } - if err = d.storeThing(); err != nil { + if err = d.storeThing(ctx); err != nil { return nil, err } diff --git a/db/remove_test.go b/db/remove_test.go index f3406708..516edd45 100644 --- a/db/remove_test.go +++ b/db/remove_test.go @@ -25,7 +25,7 @@ func TestRemove(t *testing.T) { Convey("Remove a namespace", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -46,7 +46,7 @@ func TestRemove(t *testing.T) { Convey("Remove a database", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; @@ -71,7 +71,7 @@ func TestRemove(t *testing.T) { Convey("Remove a table", t, func() { - setupDB() + setupDB(20) txt := ` USE NS test DB test; diff --git a/db/select.go b/db/select.go index 231e2d2f..80b4c5b7 100644 --- a/db/select.go +++ b/db/select.go @@ -141,9 +141,15 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int var err error var met = _SELECT - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.rlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } diff --git a/db/update.go b/db/update.go index 0e93d3b2..d747d30b 100644 --- a/db/update.go +++ b/db/update.go @@ -103,9 +103,15 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int var err error var met = _UPDATE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -129,11 +135,11 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int return nil, err } - if err = d.storeIndex(); err != nil { + if err = d.storeIndex(ctx); err != nil { return nil, err } - if err = d.storeThing(); err != nil { + if err = d.storeThing(ctx); err != nil { return nil, err } diff --git a/db/upsert.go b/db/upsert.go index 2b9fc4dd..c30cb47b 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -82,9 +82,15 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int var err error var met = _UPDATE - defer d.close() + if err = d.init(ctx); err != nil { + return nil, err + } - if err = d.setup(); err != nil { + if err = d.wlock(ctx); err != nil { + return nil, err + } + + if err = d.setup(ctx); err != nil { return nil, err } @@ -102,11 +108,11 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int return nil, err } - if err = d.storeIndex(); err != nil { + if err = d.storeIndex(ctx); err != nil { return nil, err } - if err = d.storeThing(); err != nil { + if err = d.storeThing(ctx); err != nil { return nil, err } diff --git a/db/vars.go b/db/vars.go index 029e724c..e8fc7b83 100644 --- a/db/vars.go +++ b/db/vars.go @@ -66,7 +66,7 @@ const ( var ( // maxWorkers enables limiting the maximum number of // workers to start, regardless of the CPU count. - maxWorkers = 1 + maxWorkers = 20 // workerCount specifies how many workers should be used // to process each query statement concurrently. @@ -74,7 +74,7 @@ var ( // maxRecursiveQueries specifies how many queries will be // processed recursively before the query is cancelled. - maxRecursiveQueries = 50 + maxRecursiveQueries = 16 // queryIdentFailed occurs when a permission query asks // for a field, meaning a document has to be fetched. diff --git a/mem/mem.go b/mem/mem.go index a5a91b2d..69414891 100644 --- a/mem/mem.go +++ b/mem/mem.go @@ -25,8 +25,13 @@ import ( type Cache struct { kvs.TX - lock sync.RWMutex - data map[string]interface{} + lock sync.RWMutex + data map[string]interface{} + locks struct { + ns sync.RWMutex + db sync.RWMutex + tb sync.RWMutex + } } func New() (c *Cache) { @@ -46,22 +51,22 @@ func (c *Cache) Reset() { c.TX = nil } -func (c *Cache) get(idx string) (out interface{}, ok bool) { +func (c *Cache) get(key keys.Key) (out interface{}, ok bool) { c.lock.RLock() - out, ok = c.data[idx] + out, ok = c.data[key.String()] c.lock.RUnlock() return } -func (c *Cache) put(idx string, val interface{}) { +func (c *Cache) put(key keys.Key, val interface{}) { c.lock.Lock() - c.data[idx] = val + c.data[key.String()] = val c.lock.Unlock() } -func (c *Cache) del(idx string) { +func (c *Cache) del(key keys.Key) { c.lock.Lock() - delete(c.data, idx) + delete(c.data, key.String()) c.lock.Unlock() } @@ -69,15 +74,17 @@ func (c *Cache) del(idx string) { func (c *Cache) AllNS() (out []*sql.DefineNamespaceStatement, err error) { - idx := (&keys.KV{}).String() + var kvs []kvs.KV - if out, ok := c.get(idx); ok { + c.locks.ns.RLock() + defer c.locks.ns.RUnlock() + + key := &keys.NS{KV: cnf.Settings.DB.Base, NS: keys.Ignore} + + if out, ok := c.get(key); ok { return out.([]*sql.DefineNamespaceStatement), nil } - var kvs []kvs.KV - - key := &keys.NS{KV: cnf.Settings.DB.Base, NS: keys.Ignore} if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -88,7 +95,7 @@ func (c *Cache) AllNS() (out []*sql.DefineNamespaceStatement, err error) { out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -96,15 +103,17 @@ func (c *Cache) AllNS() (out []*sql.DefineNamespaceStatement, err error) { func (c *Cache) GetNS(ns string) (val *sql.DefineNamespaceStatement, err error) { - idx := (&keys.NS{NS: ns}).String() + var kv kvs.KV - if out, ok := c.get(idx); ok { + c.locks.ns.RLock() + defer c.locks.ns.RUnlock() + + key := &keys.NS{KV: cnf.Settings.DB.Base, NS: ns} + + if out, ok := c.get(key); ok { return out.(*sql.DefineNamespaceStatement), nil } - var kv kvs.KV - - key := &keys.NS{KV: cnf.Settings.DB.Base, NS: ns} if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -116,43 +125,46 @@ func (c *Cache) GetNS(ns string) (val *sql.DefineNamespaceStatement, err error) val = &sql.DefineNamespaceStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return } -func (c *Cache) AddNS(ns string) (*sql.DefineNamespaceStatement, error) { +func (c *Cache) AddNS(ns string) (val *sql.DefineNamespaceStatement, err error) { - idx := (&keys.NS{NS: ns}).String() + var kv kvs.KV - if out, ok := c.get(idx); ok { + c.locks.ns.Lock() + defer c.locks.ns.Unlock() + + key := &keys.NS{KV: cnf.Settings.DB.Base, NS: ns} + + if out, ok := c.get(key); ok { return out.(*sql.DefineNamespaceStatement), nil } - if out, err := c.GetNS(ns); err == nil { - return out, nil + if kv, _ = c.TX.Get(0, key.Encode()); kv.Exi() { + val = &sql.DefineNamespaceStatement{} + val.Decode(kv.Val()) + c.put(key, val) + return } - key := &keys.NS{KV: cnf.Settings.DB.Base, NS: ns} - val := &sql.DefineNamespaceStatement{Name: sql.NewIdent(ns)} - if _, err := c.TX.PutC(0, key.Encode(), val.Encode(), nil); err != nil { - return nil, err - } + val = &sql.DefineNamespaceStatement{Name: sql.NewIdent(ns)} + c.TX.PutC(0, key.Encode(), val.Encode(), nil) - c.put(idx, val) + c.put(key, val) - return val, nil + return } func (c *Cache) DelNS(ns string) { - c.del((&keys.NS{NS: keys.Ignore}).String()) + c.del(&keys.NS{KV: cnf.Settings.DB.Base, NS: keys.Ignore}) - c.del((&keys.NS{NS: ns}).String()) - - return + c.del(&keys.NS{KV: cnf.Settings.DB.Base, NS: ns}) } @@ -242,15 +254,17 @@ func (c *Cache) GetNU(ns, us string) (val *sql.DefineLoginStatement, err error) func (c *Cache) AllDB(ns string) (out []*sql.DefineDatabaseStatement, err error) { - idx := (&keys.DB{NS: ns, DB: keys.Ignore}).String() + var kvs []kvs.KV - if out, ok := c.get(idx); ok { + c.locks.db.RLock() + defer c.locks.db.RUnlock() + + key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: keys.Ignore} + + if out, ok := c.get(key); ok { return out.([]*sql.DefineDatabaseStatement), nil } - var kvs []kvs.KV - - key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: keys.Ignore} if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -261,7 +275,7 @@ func (c *Cache) AllDB(ns string) (out []*sql.DefineDatabaseStatement, err error) out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -269,15 +283,17 @@ func (c *Cache) AllDB(ns string) (out []*sql.DefineDatabaseStatement, err error) func (c *Cache) GetDB(ns, db string) (val *sql.DefineDatabaseStatement, err error) { - idx := (&keys.DB{NS: ns, DB: db}).String() + var kv kvs.KV - if out, ok := c.get(idx); ok { + c.locks.db.RLock() + defer c.locks.db.RUnlock() + + key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: db} + + if out, ok := c.get(key); ok { return out.(*sql.DefineDatabaseStatement), nil } - var kv kvs.KV - - key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: db} if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -289,47 +305,50 @@ func (c *Cache) GetDB(ns, db string) (val *sql.DefineDatabaseStatement, err erro val = &sql.DefineDatabaseStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return } -func (c *Cache) AddDB(ns, db string) (*sql.DefineDatabaseStatement, error) { +func (c *Cache) AddDB(ns, db string) (val *sql.DefineDatabaseStatement, err error) { - idx := (&keys.DB{NS: ns, DB: db}).String() + if _, err = c.AddNS(ns); err != nil { + return + } - if out, ok := c.get(idx); ok { + var kv kvs.KV + + c.locks.db.Lock() + defer c.locks.db.Unlock() + + key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: db} + + if out, ok := c.get(key); ok { return out.(*sql.DefineDatabaseStatement), nil } - if out, err := c.GetDB(ns, db); err == nil { - return out, nil + if kv, _ = c.TX.Get(0, key.Encode()); kv.Exi() { + val = &sql.DefineDatabaseStatement{} + val.Decode(kv.Val()) + c.put(key, val) + return } - if _, err := c.AddNS(ns); err != nil { - return nil, err - } + val = &sql.DefineDatabaseStatement{Name: sql.NewIdent(db)} + c.TX.PutC(0, key.Encode(), val.Encode(), nil) - key := &keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: db} - val := &sql.DefineDatabaseStatement{Name: sql.NewIdent(db)} - if _, err := c.TX.PutC(0, key.Encode(), val.Encode(), nil); err != nil { - return nil, err - } + c.put(key, val) - c.put(idx, val) - - return val, nil + return } func (c *Cache) DelDB(ns, db string) { - c.del((&keys.DB{NS: ns, DB: keys.Ignore}).String()) + c.del(&keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: keys.Ignore}) - c.del((&keys.DB{NS: ns, DB: db}).String()) - - return + c.del(&keys.DB{KV: cnf.Settings.DB.Base, NS: ns, DB: db}) } @@ -501,15 +520,17 @@ func (c *Cache) GetST(ns, db, sc, tk string) (val *sql.DefineTokenStatement, err func (c *Cache) AllTB(ns, db string) (out []*sql.DefineTableStatement, err error) { - idx := (&keys.TB{NS: ns, DB: db, TB: keys.Ignore}).String() + var kvs []kvs.KV - if out, ok := c.get(idx); ok { + c.locks.tb.RLock() + defer c.locks.tb.RUnlock() + + key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: keys.Ignore} + + if out, ok := c.get(key); ok { return out.([]*sql.DefineTableStatement), nil } - var kvs []kvs.KV - - key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: keys.Ignore} if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -520,7 +541,7 @@ func (c *Cache) AllTB(ns, db string) (out []*sql.DefineTableStatement, err error out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -528,15 +549,17 @@ func (c *Cache) AllTB(ns, db string) (out []*sql.DefineTableStatement, err error func (c *Cache) GetTB(ns, db, tb string) (val *sql.DefineTableStatement, err error) { - idx := (&keys.TB{NS: ns, DB: db, TB: tb}).String() + var kv kvs.KV - if out, ok := c.get(idx); ok { + c.locks.tb.RLock() + defer c.locks.tb.RUnlock() + + key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb} + + if out, ok := c.get(key); ok { return out.(*sql.DefineTableStatement), nil } - var kv kvs.KV - - key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb} if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -548,49 +571,50 @@ func (c *Cache) GetTB(ns, db, tb string) (val *sql.DefineTableStatement, err err val = &sql.DefineTableStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return } -func (c *Cache) AddTB(ns, db, tb string) (*sql.DefineTableStatement, error) { +func (c *Cache) AddTB(ns, db, tb string) (val *sql.DefineTableStatement, err error) { - // var exi bool + if _, err = c.AddDB(ns, db); err != nil { + return + } - idx := (&keys.TB{NS: ns, DB: db, TB: tb}).String() + var kv kvs.KV - if out, ok := c.get(idx); ok { + c.locks.tb.Lock() + defer c.locks.tb.Unlock() + + key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb} + + if out, ok := c.get(key); ok { return out.(*sql.DefineTableStatement), nil } - if out, err := c.GetTB(ns, db, tb); err == nil { - return out, nil + if kv, _ = c.TX.Get(0, key.Encode()); kv.Exi() { + val = &sql.DefineTableStatement{} + val.Decode(kv.Val()) + c.put(key, val) + return } - if _, err := c.AddDB(ns, db); err != nil { - return nil, err - } + val = &sql.DefineTableStatement{Name: sql.NewIdent(tb)} + c.TX.PutC(0, key.Encode(), val.Encode(), nil) - key := &keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb} - val := &sql.DefineTableStatement{Name: sql.NewIdent(tb)} - if _, err := c.TX.PutC(0, key.Encode(), val.Encode(), nil); err != nil { - return nil, err - } + c.put(key, val) - c.put(idx, val) - - return val, nil + return } func (c *Cache) DelTB(ns, db, tb string) { - c.del((&keys.TB{NS: ns, DB: db, TB: keys.Ignore}).String()) + c.del(&keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: keys.Ignore}) - c.del((&keys.TB{NS: ns, DB: db, TB: tb}).String()) - - return + c.del(&keys.TB{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb}) } @@ -598,15 +622,14 @@ func (c *Cache) DelTB(ns, db, tb string) { func (c *Cache) AllEV(ns, db, tb string) (out []*sql.DefineEventStatement, err error) { - idx := (&keys.EV{NS: ns, DB: db, TB: tb, EV: keys.Ignore}).String() - - if out, ok := c.get(idx); ok { - return out.([]*sql.DefineEventStatement), nil - } - var kvs []kvs.KV key := &keys.EV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, EV: keys.Ignore} + + if out, ok := c.get(key); ok { + return out.([]*sql.DefineEventStatement), nil + } + if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -617,7 +640,7 @@ func (c *Cache) AllEV(ns, db, tb string) (out []*sql.DefineEventStatement, err e out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -625,15 +648,14 @@ func (c *Cache) AllEV(ns, db, tb string) (out []*sql.DefineEventStatement, err e func (c *Cache) GetEV(ns, db, tb, ev string) (val *sql.DefineEventStatement, err error) { - idx := (&keys.EV{NS: ns, DB: db, TB: tb, EV: ev}).String() - - if out, ok := c.get(idx); ok { - return out.(*sql.DefineEventStatement), nil - } - var kv kvs.KV key := &keys.EV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, EV: ev} + + if out, ok := c.get(key); ok { + return out.(*sql.DefineEventStatement), nil + } + if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -645,7 +667,7 @@ func (c *Cache) GetEV(ns, db, tb, ev string) (val *sql.DefineEventStatement, err val = &sql.DefineEventStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return @@ -653,11 +675,9 @@ func (c *Cache) GetEV(ns, db, tb, ev string) (val *sql.DefineEventStatement, err func (c *Cache) DelEV(ns, db, tb, ev string) { - c.del((&keys.EV{NS: ns, DB: db, TB: tb, EV: keys.Ignore}).String()) + c.del(&keys.EV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, EV: keys.Ignore}) - c.del((&keys.EV{NS: ns, DB: db, TB: tb, EV: ev}).String()) - - return + c.del(&keys.EV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, EV: ev}) } @@ -665,15 +685,14 @@ func (c *Cache) DelEV(ns, db, tb, ev string) { func (c *Cache) AllFD(ns, db, tb string) (out []*sql.DefineFieldStatement, err error) { - idx := (&keys.FD{NS: ns, DB: db, TB: tb, FD: keys.Ignore}).String() - - if out, ok := c.get(idx); ok { - return out.([]*sql.DefineFieldStatement), nil - } - var kvs []kvs.KV key := &keys.FD{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FD: keys.Ignore} + + if out, ok := c.get(key); ok { + return out.([]*sql.DefineFieldStatement), nil + } + if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -684,7 +703,7 @@ func (c *Cache) AllFD(ns, db, tb string) (out []*sql.DefineFieldStatement, err e out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -692,15 +711,14 @@ func (c *Cache) AllFD(ns, db, tb string) (out []*sql.DefineFieldStatement, err e func (c *Cache) GetFD(ns, db, tb, fd string) (val *sql.DefineFieldStatement, err error) { - idx := (&keys.FD{NS: ns, DB: db, TB: tb, FD: fd}).String() - - if out, ok := c.get(idx); ok { - return out.(*sql.DefineFieldStatement), nil - } - var kv kvs.KV key := &keys.FD{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FD: fd} + + if out, ok := c.get(key); ok { + return out.(*sql.DefineFieldStatement), nil + } + if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -712,7 +730,7 @@ func (c *Cache) GetFD(ns, db, tb, fd string) (val *sql.DefineFieldStatement, err val = &sql.DefineFieldStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return @@ -720,11 +738,9 @@ func (c *Cache) GetFD(ns, db, tb, fd string) (val *sql.DefineFieldStatement, err func (c *Cache) DelFD(ns, db, tb, fd string) { - c.del((&keys.FD{NS: ns, DB: db, TB: tb, FD: keys.Ignore}).String()) + c.del(&keys.FD{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FD: keys.Ignore}) - c.del((&keys.FD{NS: ns, DB: db, TB: tb, FD: fd}).String()) - - return + c.del(&keys.FD{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FD: fd}) } @@ -732,15 +748,14 @@ func (c *Cache) DelFD(ns, db, tb, fd string) { func (c *Cache) AllIX(ns, db, tb string) (out []*sql.DefineIndexStatement, err error) { - idx := (&keys.IX{NS: ns, DB: db, TB: tb, IX: keys.Ignore}).String() - - if out, ok := c.get(idx); ok { - return out.([]*sql.DefineIndexStatement), nil - } - var kvs []kvs.KV key := &keys.IX{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, IX: keys.Ignore} + + if out, ok := c.get(key); ok { + return out.([]*sql.DefineIndexStatement), nil + } + if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -751,7 +766,7 @@ func (c *Cache) AllIX(ns, db, tb string) (out []*sql.DefineIndexStatement, err e out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -759,15 +774,14 @@ func (c *Cache) AllIX(ns, db, tb string) (out []*sql.DefineIndexStatement, err e func (c *Cache) GetIX(ns, db, tb, ix string) (val *sql.DefineIndexStatement, err error) { - idx := (&keys.IX{NS: ns, DB: db, TB: tb, IX: ix}).String() - - if out, ok := c.get(idx); ok { - return out.(*sql.DefineIndexStatement), nil - } - var kv kvs.KV key := &keys.IX{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, IX: ix} + + if out, ok := c.get(key); ok { + return out.(*sql.DefineIndexStatement), nil + } + if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -779,7 +793,7 @@ func (c *Cache) GetIX(ns, db, tb, ix string) (val *sql.DefineIndexStatement, err val = &sql.DefineIndexStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return @@ -787,11 +801,9 @@ func (c *Cache) GetIX(ns, db, tb, ix string) (val *sql.DefineIndexStatement, err func (c *Cache) DelIX(ns, db, tb, ix string) { - c.del((&keys.IX{NS: ns, DB: db, TB: tb, IX: keys.Ignore}).String()) + c.del(&keys.IX{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, IX: keys.Ignore}) - c.del((&keys.IX{NS: ns, DB: db, TB: tb, IX: ix}).String()) - - return + c.del(&keys.IX{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, IX: ix}) } @@ -799,15 +811,14 @@ func (c *Cache) DelIX(ns, db, tb, ix string) { func (c *Cache) AllFT(ns, db, tb string) (out []*sql.DefineTableStatement, err error) { - idx := (&keys.FT{NS: ns, DB: db, TB: tb, FT: keys.Ignore}).String() - - if out, ok := c.get(idx); ok { - return out.([]*sql.DefineTableStatement), nil - } - var kvs []kvs.KV key := &keys.FT{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FT: keys.Ignore} + + if out, ok := c.get(key); ok { + return out.([]*sql.DefineTableStatement), nil + } + if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -818,7 +829,7 @@ func (c *Cache) AllFT(ns, db, tb string) (out []*sql.DefineTableStatement, err e out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -826,15 +837,14 @@ func (c *Cache) AllFT(ns, db, tb string) (out []*sql.DefineTableStatement, err e func (c *Cache) GetFT(ns, db, tb, ft string) (val *sql.DefineTableStatement, err error) { - idx := (&keys.FT{NS: ns, DB: db, TB: tb, FT: ft}).String() - - if out, ok := c.get(idx); ok { - return out.(*sql.DefineTableStatement), nil - } - var kv kvs.KV key := &keys.FT{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FT: ft} + + if out, ok := c.get(key); ok { + return out.(*sql.DefineTableStatement), nil + } + if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -846,7 +856,7 @@ func (c *Cache) GetFT(ns, db, tb, ft string) (val *sql.DefineTableStatement, err val = &sql.DefineTableStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return @@ -854,11 +864,9 @@ func (c *Cache) GetFT(ns, db, tb, ft string) (val *sql.DefineTableStatement, err func (c *Cache) DelFT(ns, db, tb, ft string) { - c.del((&keys.FT{NS: ns, DB: db, TB: tb, FT: keys.Ignore}).String()) + c.del(&keys.FT{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FT: keys.Ignore}) - c.del((&keys.FT{NS: ns, DB: db, TB: tb, FT: ft}).String()) - - return + c.del(&keys.FT{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, FT: ft}) } @@ -866,15 +874,14 @@ func (c *Cache) DelFT(ns, db, tb, ft string) { func (c *Cache) AllLV(ns, db, tb string) (out []*sql.LiveStatement, err error) { - idx := (&keys.LV{NS: ns, DB: db, TB: tb, LV: keys.Ignore}).String() - - if out, ok := c.get(idx); ok { - return out.([]*sql.LiveStatement), nil - } - var kvs []kvs.KV key := &keys.LV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, LV: keys.Ignore} + + if out, ok := c.get(key); ok { + return out.([]*sql.LiveStatement), nil + } + if kvs, err = c.TX.GetP(0, key.Encode(), 0); err != nil { return } @@ -885,7 +892,7 @@ func (c *Cache) AllLV(ns, db, tb string) (out []*sql.LiveStatement, err error) { out = append(out, val) } - c.put(idx, out) + c.put(key, out) return @@ -893,15 +900,14 @@ func (c *Cache) AllLV(ns, db, tb string) (out []*sql.LiveStatement, err error) { func (c *Cache) GetLV(ns, db, tb, lv string) (val *sql.LiveStatement, err error) { - idx := (&keys.LV{NS: ns, DB: db, TB: tb, LV: lv}).String() - - if out, ok := c.get(idx); ok { - return out.(*sql.LiveStatement), nil - } - var kv kvs.KV key := &keys.LV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, LV: lv} + + if out, ok := c.get(key); ok { + return out.(*sql.LiveStatement), nil + } + if kv, err = c.TX.Get(0, key.Encode()); err != nil { return nil, err } @@ -913,7 +919,7 @@ func (c *Cache) GetLV(ns, db, tb, lv string) (val *sql.LiveStatement, err error) val = &sql.LiveStatement{} val.Decode(kv.Val()) - c.put(idx, val) + c.put(key, val) return @@ -921,10 +927,8 @@ func (c *Cache) GetLV(ns, db, tb, lv string) (val *sql.LiveStatement, err error) func (c *Cache) DelLV(ns, db, tb, lv string) { - c.del((&keys.LV{NS: ns, DB: db, TB: tb, LV: keys.Ignore}).String()) + c.del(&keys.LV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, LV: keys.Ignore}) - c.del((&keys.LV{NS: ns, DB: db, TB: tb, LV: lv}).String()) - - return + c.del(&keys.LV{KV: cnf.Settings.DB.Base, NS: ns, DB: db, TB: tb, LV: lv}) }