diff --git a/db/check.go b/db/check.go index a63979bd..53b31d48 100644 --- a/db/check.go +++ b/db/check.go @@ -48,7 +48,7 @@ func (d *document) check(ctx context.Context, cond sql.Expr) (ok bool, err error // Grant checks to see if the table permissions allow // this record to be accessed for live queries, and // if not then it errors accordingly. -func (d *document) grant(ctx context.Context, when method) (ok bool, err error) { +func (d *document) grant(ctx context.Context, met method) (ok bool, err error) { var val interface{} @@ -86,16 +86,8 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error) // for this table, then because this is // a scoped request, return an error. - switch p := tb.Perms.(type) { - case *sql.PermExpression: - switch when { - case _CREATE: - val, err = d.i.e.fetch(ctx, p.Select, d.current) - case _UPDATE: - val, err = d.i.e.fetch(ctx, p.Select, d.current) - case _DELETE: - val, err = d.i.e.fetch(ctx, p.Select, d.initial) - } + if p, ok := tb.Perms.(*sql.PermExpression); ok { + val, err = d.i.e.fetch(ctx, p.Select, d.current) } // If the permissions expressions @@ -103,8 +95,8 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error) // return this, dictating whether the // document is able to be viewed. - if val, ok := val.(bool); ok { - return val, err + if v, ok := val.(bool); ok { + return v, err } // Otherwise as this request is scoped, @@ -118,7 +110,7 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error) // Query checks to see if the table permissions allow // this record to be accessed for normal queries, and // if not then it errors accordingly. -func (d *document) allow(ctx context.Context, when method) (ok bool, err error) { +func (d *document) allow(ctx context.Context, met method) (ok bool, err error) { var val interface{} @@ -156,9 +148,8 @@ func (d *document) allow(ctx context.Context, when method) (ok bool, err error) // for this table, then because this is // a scoped request, return an error. - switch p := tb.Perms.(type) { - case *sql.PermExpression: - switch when { + if p, ok := tb.Perms.(*sql.PermExpression); ok { + switch met { case _SELECT: val, err = d.i.e.fetch(ctx, p.Select, d.current) case _CREATE: @@ -175,8 +166,8 @@ func (d *document) allow(ctx context.Context, when method) (ok bool, err error) // return this, dictating whether the // document is able to be viewed. - if val, ok := val.(bool); ok { - return val, err + if v, ok := val.(bool); ok { + return v, err } // Otherwise as this request is scoped, diff --git a/db/create.go b/db/create.go index 6d7b5ff8..a88a0d7c 100644 --- a/db/create.go +++ b/db/create.go @@ -114,7 +114,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int return nil, &ExistError{exist: d.id} } - if err = d.merge(ctx, stm.Data); err != nil { + if err = d.merge(ctx, met, stm.Data); err != nil { return nil, err } diff --git a/db/define_test.go b/db/define_test.go index 14ebfac7..e61d432a 100644 --- a/db/define_test.go +++ b/db/define_test.go @@ -212,6 +212,33 @@ func TestDefine(t *testing.T) { }) + Convey("Convert a schemaless to schemafull table, and ensure schemaless fields are still output", t, func() { + + setupDB() + + txt := ` + USE NS test DB test; + DEFINE TABLE person SCHEMALESS; + UPDATE person:test SET test=true, other="text"; + DEFINE TABLE person SCHEMAFULL; + DEFINE FIELD test ON person TYPE boolean; + SELECT * FROM person; + DEFINE FIELD other ON person TYPE string; + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 8) + So(data.Consume(res[2].Result[0]).Get("test").Data(), ShouldEqual, true) + So(data.Consume(res[2].Result[0]).Get("other").Data(), ShouldEqual, "text") + So(data.Consume(res[5].Result[0]).Get("test").Data(), ShouldEqual, true) + So(data.Consume(res[5].Result[0]).Get("other").Data(), ShouldEqual, "text") + So(data.Consume(res[7].Result[0]).Get("test").Data(), ShouldEqual, true) + So(data.Consume(res[7].Result[0]).Get("other").Data(), ShouldEqual, "text") + + }) + Convey("Define a drop table", t, func() { setupDB() @@ -372,6 +399,101 @@ func TestDefine(t *testing.T) { }) + Convey("Specify the permissions of a field so that it is only visible to the correct authentication levels", t, func() { + + setupDB() + + func() { + + txt := ` + USE NS test DB test; + DEFINE TABLE person PERMISSIONS FULL; + DEFINE FIELD name ON person PERMISSIONS FULL; + DEFINE FIELD pass ON person PERMISSIONS NONE; + DEFINE FIELD test ON person PERMISSIONS FOR CREATE, UPDATE FULL FOR SELECT NONE; + DEFINE FIELD temp ON person PERMISSIONS NONE; + DEFINE FIELD temp.test ON person PERMISSIONS FULL; + UPDATE person:test SET name="Tobias", pass="qhmyjahdc4", test="k5n87urq8l", temp.test="zw3wf5ls39"; + SELECT * FROM person; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 9) + So(res[7].Result, ShouldHaveLength, 1) + So(data.Consume(res[7].Result[0]).Get("name").Data(), ShouldEqual, "Tobias") + So(data.Consume(res[7].Result[0]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4") + So(data.Consume(res[7].Result[0]).Get("test").Data(), ShouldEqual, "k5n87urq8l") + So(data.Consume(res[7].Result[0]).Get("temp.test").Data(), ShouldEqual, "zw3wf5ls39") + So(res[8].Result, ShouldHaveLength, 1) + So(data.Consume(res[8].Result[0]).Get("name").Data(), ShouldEqual, "Tobias") + So(data.Consume(res[8].Result[0]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4") + So(data.Consume(res[8].Result[0]).Get("test").Data(), ShouldEqual, "k5n87urq8l") + So(data.Consume(res[8].Result[0]).Get("temp.test").Data(), ShouldEqual, "zw3wf5ls39") + + }() + + func() { + + txt := ` + USE NS test DB test; + CREATE person:1 SET name="Silvana", pass="1f65flhfvq", test="35aptguqoj", temp.test="h08ryx3519"; + UPDATE person:2 SET name="Jonathan", pass="8k796m5mmj", test="1lzdhd6wzg", temp.test="xurnxp8a1e"; + SELECT * FROM person ORDER BY name; + ` + + res, err := Execute(setupSC(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 4) + So(res[1].Result, ShouldHaveLength, 1) + So(data.Consume(res[1].Result[0]).Get("name").Data(), ShouldEqual, "Silvana") + So(data.Consume(res[1].Result[0]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[1].Result[0]).Get("test").Data(), ShouldEqual, nil) + So(data.Consume(res[1].Result[0]).Get("temp.test").Data(), ShouldEqual, nil) + So(res[2].Result, ShouldHaveLength, 1) + So(data.Consume(res[2].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan") + So(data.Consume(res[2].Result[0]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[2].Result[0]).Get("test").Data(), ShouldEqual, nil) + So(data.Consume(res[2].Result[0]).Get("temp.test").Data(), ShouldEqual, nil) + So(res[3].Result, ShouldHaveLength, 3) + So(data.Consume(res[3].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan") + So(data.Consume(res[3].Result[0]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[1]).Get("name").Data(), ShouldEqual, "Silvana") + So(data.Consume(res[3].Result[1]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[1]).Get("test").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[2]).Get("name").Data(), ShouldEqual, "Tobias") + So(data.Consume(res[3].Result[2]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[2]).Get("test").Data(), ShouldEqual, nil) + So(data.Consume(res[3].Result[2]).Get("temp.test").Data(), ShouldEqual, nil) + + }() + + func() { + + txt := ` + USE NS test DB test; + SELECT * FROM person ORDER BY name; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 2) + So(res[1].Result, ShouldHaveLength, 3) + So(data.Consume(res[1].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan") + So(data.Consume(res[1].Result[0]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[1].Result[0]).Get("test").Data(), ShouldEqual, "1lzdhd6wzg") + So(data.Consume(res[1].Result[1]).Get("name").Data(), ShouldEqual, "Silvana") + So(data.Consume(res[1].Result[1]).Get("pass").Data(), ShouldEqual, nil) + So(data.Consume(res[1].Result[1]).Get("test").Data(), ShouldEqual, "35aptguqoj") + So(data.Consume(res[1].Result[2]).Get("name").Data(), ShouldEqual, "Tobias") + So(data.Consume(res[1].Result[2]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4") + So(data.Consume(res[1].Result[2]).Get("test").Data(), ShouldEqual, "k5n87urq8l") + + }() + + }) + Convey("Define an event when a value changes", t, func() { setupDB() diff --git a/db/document.go b/db/document.go index e480972b..3535632e 100644 --- a/db/document.go +++ b/db/document.go @@ -282,15 +282,6 @@ func (d *document) changed() bool { return len(c) > 0 } -func (d *document) diff() *data.Doc { - a, _ := d.initial.Data().(map[string]interface{}) - b, _ := d.current.Data().(map[string]interface{}) - if c := diff.Diff(a, b); len(c) > 0 { - return data.Consume(c) - } - return data.Consume(nil) -} - func (d *document) shouldDrop() (bool, error) { // Check whether it is specified diff --git a/db/event.go b/db/event.go index ed1f139e..c857679b 100644 --- a/db/event.go +++ b/db/event.go @@ -23,7 +23,7 @@ import ( // Event checks if any triggers are specified for this // table, and executes them in name order. -func (d *document) event(ctx context.Context, when method) (err error) { +func (d *document) event(ctx context.Context, met method) (err error) { // Get the event values specified // for this table, loop through @@ -38,7 +38,7 @@ func (d *document) event(ctx context.Context, when method) (err error) { kind := "" - switch when { + switch met { case _CREATE: kind = "CREATE" case _UPDATE: diff --git a/db/insert.go b/db/insert.go index 5aec28d9..c58d089b 100644 --- a/db/insert.go +++ b/db/insert.go @@ -92,7 +92,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int return nil, &ExistError{exist: d.id} } - if err = d.merge(ctx, nil); err != nil { + if err = d.merge(ctx, met, nil); err != nil { return nil, err } diff --git a/db/lives.go b/db/lives.go index 641feefd..50f9f81c 100644 --- a/db/lives.go +++ b/db/lives.go @@ -18,7 +18,6 @@ import ( "context" "github.com/abcum/surreal/sql" - "github.com/abcum/surreal/util/data" ) // Lives checks if any table views are specified for @@ -54,7 +53,7 @@ func (d *document) lives(ctx context.Context, when method) (err error) { var ok bool var con *socket - var doc *data.Doc + var out interface{} if con, ok = sockets[lv.FB]; ok { @@ -98,7 +97,7 @@ func (d *document) lives(ctx context.Context, when method) (err error) { case true: - doc = d.diff() + out, _ = d.yield(ctx, lv, sql.DIFF) // If the query has projected fields which it // wants to receive, then let's fetch these @@ -106,39 +105,21 @@ func (d *document) lives(ctx context.Context, when method) (err error) { case false: - for _, v := range lv.Expr { - if _, ok := v.Expr.(*sql.All); ok { - doc = d.current - break - } - } - - if doc == nil { - doc = data.New() - } - - for _, e := range lv.Expr { - switch v := e.Expr.(type) { - case *sql.All: - break - default: - v, err := d.i.e.fetch(ctx, v, d.current) - if err != nil { - continue - } - doc.Set(v, e.Field) - } - } + out, _ = d.yield(ctx, lv, sql.ILLEGAL) } switch when { - case _CREATE: - con.queue(id, lv.ID, "CREATE", doc.Data()) - case _UPDATE: - con.queue(id, lv.ID, "UPDATE", doc.Data()) case _DELETE: con.queue(id, lv.ID, "DELETE", d.id) + case _CREATE: + if out != nil { + con.queue(id, lv.ID, "CREATE", out) + } + case _UPDATE: + if out != nil { + con.queue(id, lv.ID, "UPDATE", out) + } } } diff --git a/db/merge.go b/db/merge.go index 8feef851..0712d0fd 100644 --- a/db/merge.go +++ b/db/merge.go @@ -19,6 +19,7 @@ import ( "context" + "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/conv" "github.com/abcum/surreal/util/data" @@ -32,44 +33,44 @@ var main = map[string]struct{}{ "meta.id": {}, } -func (d *document) merge(ctx context.Context, data sql.Expr) (err error) { +func (d *document) merge(ctx context.Context, met method, data sql.Expr) (err error) { - if err = d.defFld(ctx); err != nil { + if err = d.defFld(ctx, met); err != nil { return } switch expr := data.(type) { case *sql.DataExpression: - if err = d.mrgSet(ctx, expr); err != nil { + if err = d.mrgSet(ctx, met, expr); err != nil { return err } case *sql.DiffExpression: - if err = d.mrgDpm(ctx, expr); err != nil { + if err = d.mrgDpm(ctx, met, expr); err != nil { return err } case *sql.MergeExpression: - if err = d.mrgAny(ctx, expr); err != nil { + if err = d.mrgAny(ctx, met, expr); err != nil { return err } case *sql.ContentExpression: - if err = d.mrgAll(ctx, expr); err != nil { + if err = d.mrgAll(ctx, met, expr); err != nil { return err } } - if err = d.defFld(ctx); err != nil { + if err = d.defFld(ctx, met); err != nil { return } - if err = d.mrgFld(ctx); err != nil { + if err = d.mrgFld(ctx, met); err != nil { return } - if err = d.defFld(ctx); err != nil { + if err = d.defFld(ctx, met); err != nil { return } - if err = d.delFld(ctx); err != nil { + if err = d.delFld(ctx, met); err != nil { return } @@ -77,7 +78,7 @@ func (d *document) merge(ctx context.Context, data sql.Expr) (err error) { } -func (d *document) defFld(ctx context.Context) (err error) { +func (d *document) defFld(ctx context.Context, met method) (err error) { d.current.Set(d.id, "id") d.current.Set(d.md, "meta") @@ -86,7 +87,7 @@ func (d *document) defFld(ctx context.Context) (err error) { } -func (d *document) delFld(ctx context.Context) (err error) { +func (d *document) delFld(ctx context.Context, met method) (err error) { tb, err := d.getTB() if err != nil { @@ -130,7 +131,7 @@ func (d *document) delFld(ctx context.Context) (err error) { } -func (d *document) mrgAll(ctx context.Context, expr *sql.ContentExpression) (err error) { +func (d *document) mrgAll(ctx context.Context, met method, expr *sql.ContentExpression) (err error) { var obj map[string]interface{} @@ -161,7 +162,7 @@ func (d *document) mrgAll(ctx context.Context, expr *sql.ContentExpression) (err } -func (d *document) mrgAny(ctx context.Context, expr *sql.MergeExpression) (err error) { +func (d *document) mrgAny(ctx context.Context, met method, expr *sql.MergeExpression) (err error) { var obj map[string]interface{} @@ -190,7 +191,7 @@ func (d *document) mrgAny(ctx context.Context, expr *sql.MergeExpression) (err e } -func (d *document) mrgDpm(ctx context.Context, expr *sql.DiffExpression) (err error) { +func (d *document) mrgDpm(ctx context.Context, met method, expr *sql.DiffExpression) (err error) { var obj []interface{} var old map[string]interface{} @@ -222,7 +223,7 @@ func (d *document) mrgDpm(ctx context.Context, expr *sql.DiffExpression) (err er } -func (d *document) mrgSet(ctx context.Context, expr *sql.DataExpression) (err error) { +func (d *document) mrgSet(ctx context.Context, met method, expr *sql.DataExpression) (err error) { for _, v := range expr.Data { @@ -255,7 +256,7 @@ func (d *document) mrgSet(ctx context.Context, expr *sql.DataExpression) (err er } -func (d *document) mrgFld(ctx context.Context) (err error) { +func (d *document) mrgFld(ctx context.Context, met method) (err error) { fds, err := d.getFD() if err != nil { @@ -306,40 +307,95 @@ func (d *document) mrgFld(ctx context.Context) (err error) { } } - // Reset the variables - - vars.Set(val, varKeyValue) - vars.Set(val, varKeyAfter) - vars.Set(old, varKeyBefore) - ctx = context.WithValue(ctx, ctxKeySpec, vars) - // We are setting the value of the field if fd.Value != nil { + + // Reset the variables + + vars.Set(val, varKeyValue) + vars.Set(val, varKeyAfter) + vars.Set(old, varKeyBefore) + ctx = context.WithValue(ctx, ctxKeySpec, vars) + if now, err := d.i.e.fetch(ctx, fd.Value, d.current); err != nil { return err } else { val = now } + } - // Reset the variables - - vars.Set(val, varKeyValue) - vars.Set(val, varKeyAfter) - vars.Set(old, varKeyBefore) - ctx = context.WithValue(ctx, ctxKeySpec, vars) - // We are checking the value of the field if fd.Assert != nil { + + // Reset the variables + + vars.Set(val, varKeyValue) + vars.Set(val, varKeyAfter) + vars.Set(old, varKeyBefore) + ctx = context.WithValue(ctx, ctxKeySpec, vars) + if chk, err := d.i.e.fetch(ctx, fd.Assert, d.current); err != nil { return err } else if chk, ok := chk.(bool); ok && !chk { return &FieldError{field: key, found: val, check: fd.Assert} } + } + // We are checking the permissions of the field + + if fd.Perms != nil { + + if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok { + if k > cnf.AuthDB { + + // Reset the variables + + vars.Set(val, varKeyValue) + vars.Set(val, varKeyAfter) + vars.Set(old, varKeyBefore) + ctx = context.WithValue(ctx, ctxKeySpec, vars) + + switch p := fd.Perms.(type) { + case *sql.PermExpression: + switch met { + case _CREATE: + if v, err := d.i.e.fetch(ctx, p.Create, d.current); err != nil { + return err + } else { + if b, ok := v.(bool); !ok || !b { + val = old + } + } + case _UPDATE: + if v, err := d.i.e.fetch(ctx, p.Update, d.current); err != nil { + return err + } else { + if b, ok := v.(bool); !ok || !b { + val = old + } + } + case _DELETE: + if v, err := d.i.e.fetch(ctx, p.Delete, d.current); err != nil { + return err + } else { + if b, ok := v.(bool); !ok || !b { + val = old + } + } + } + } + + } + } + + } + + // We are setting the value of the field + switch val.(type) { default: d.current.Iff(val, key) diff --git a/db/perms.go b/db/perms.go new file mode 100644 index 00000000..4e6ed3b6 --- /dev/null +++ b/db/perms.go @@ -0,0 +1,90 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" + + "github.com/abcum/surreal/cnf" + "github.com/abcum/surreal/sql" + "github.com/abcum/surreal/util/data" +) + +func (d *document) perms(ctx context.Context, doc *data.Doc) (err error) { + + // Get the field definitions so we can + // check if the permissions allow us + // to view each field. + + fds, err := d.getFD() + if err != nil { + return err + } + + // Once we have the table we reset the + // context to DB level so that no other + // embedded permissions are checked on + // records within these permissions. + + ctx = context.WithValue(ctx, ctxKeyKind, cnf.AuthDB) + + // We then try to process the relevant + // permissions dependent on the query + // that we are currently processing. If + // there are no permissions specified + // for this table, then because this is + // a scoped request, return an error. + + for _, fd := range fds { + + if fd.Perms != nil { + + err = doc.Walk(func(key string, val interface{}) error { + + // We are checking the permissions of the field + + if p, ok := fd.Perms.(*sql.PermExpression); ok { + + // Get the old value + + old := d.initial.Get(key).Data() + + // Reset the variables + + vars := data.New() + vars.Set(val, varKeyValue) + vars.Set(val, varKeyAfter) + vars.Set(old, varKeyBefore) + ctx = context.WithValue(ctx, ctxKeySpec, vars) + + if v, err := d.i.e.fetch(ctx, p.Select, doc); err != nil { + return err + } else if b, ok := v.(bool); !ok || !b { + doc.Del(key) + } + + } + + return nil + + }, fd.Name.ID) + + } + + } + + return nil + +} diff --git a/db/relate.go b/db/relate.go index 47fa4b63..86ea0f08 100644 --- a/db/relate.go +++ b/db/relate.go @@ -106,7 +106,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int met = _UPDATE } - if err = d.merge(ctx, stm.Data); err != nil { + if err = d.merge(ctx, met, stm.Data); err != nil { return nil, err } diff --git a/db/select.go b/db/select.go index 918af588..7f23b5a6 100644 --- a/db/select.go +++ b/db/select.go @@ -137,6 +137,7 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int var ok bool var err error + var met = _SELECT defer d.close() @@ -148,12 +149,10 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int return nil, nil } - if d.doc == nil { - if ok, err = d.allow(ctx, _SELECT); err != nil { - return nil, err - } else if ok == false { - return nil, nil - } + if ok, err = d.allow(ctx, met); err != nil { + return nil, err + } else if ok == false { + return nil, nil } if ok, err = d.check(ctx, stm.Cond); err != nil { diff --git a/db/table.go b/db/table.go index a16c7d83..3af792cc 100644 --- a/db/table.go +++ b/db/table.go @@ -19,7 +19,6 @@ import ( "context" - "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/keys" ) @@ -135,7 +134,7 @@ func (d *document) table(ctx context.Context, when method) (err error) { func (d *document) tableDelete(ctx context.Context, id *sql.Thing, exp sql.Fields) (err error) { stm := &sql.DeleteStatement{ - KV: cnf.Settings.DB.Base, + KV: d.key.KV, NS: d.key.NS, DB: d.key.DB, What: sql.Exprs{id}, @@ -163,7 +162,7 @@ func (d *document) tableUpdate(ctx context.Context, id *sql.Thing, exp sql.Field } stm := &sql.UpdateStatement{ - KV: cnf.Settings.DB.Base, + KV: d.key.KV, NS: d.key.NS, DB: d.key.DB, What: sql.Exprs{id}, diff --git a/db/update.go b/db/update.go index be6baf09..71e3b525 100644 --- a/db/update.go +++ b/db/update.go @@ -125,7 +125,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int return nil, nil } - if err = d.merge(ctx, stm.Data); err != nil { + if err = d.merge(ctx, met, stm.Data); err != nil { return nil, err } diff --git a/db/upsert.go b/db/upsert.go index 7e802178..f75e151c 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -98,7 +98,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int return nil, nil } - if err = d.merge(ctx, nil); err != nil { + if err = d.merge(ctx, met, nil); err != nil { return nil, err } diff --git a/db/yield.go b/db/yield.go index 9377efd7..9f922aef 100644 --- a/db/yield.go +++ b/db/yield.go @@ -17,92 +17,209 @@ package db import ( "context" + "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/data" + "github.com/abcum/surreal/util/diff" ) +func (d *document) cold(ctx context.Context) (doc *data.Doc, err error) { + + // If we are authenticated using DB, NS, + // or KV permissions level, then we can + // return the document without copying. + + if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok { + if k < cnf.AuthSC { + return d.initial, nil + } + } + + // Otherwise, we need to create a copy + // of the document so that we can add + // and remove fields before outputting. + + doc = d.initial.Copy() + + err = d.perms(ctx, doc) + + return + +} + +func (d *document) cnow(ctx context.Context) (doc *data.Doc, err error) { + + // If we are authenticated using DB, NS, + // or KV permissions level, then we can + // return the document without copying. + + if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok { + if k < cnf.AuthSC { + return d.current, nil + } + } + + // Otherwise, we need to create a copy + // of the document so that we can add + // and remove fields before outputting. + + doc = d.current.Copy() + + err = d.perms(ctx, doc) + + return + +} + +func (d *document) diffs(initial, current *data.Doc) *data.Doc { + + a, _ := initial.Data().(map[string]interface{}) + b, _ := current.Data().(map[string]interface{}) + + if c := diff.Diff(a, b); len(c) > 0 { + return data.Consume(c) + } + + return data.Consume(nil) + +} + func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Token) (interface{}, error) { + var exps sql.Fields + var grps sql.Groups + switch stm := stm.(type) { - + case *sql.LiveStatement: + exps = stm.Expr case *sql.SelectStatement: + exps = stm.Expr + grps = stm.Group + } - var doc *data.Doc + // If there are no field expressions + // then this was not a LIVE or SELECT + // query, and therefore the query will + // have an output format specified. - for _, v := range stm.Expr { - if _, ok := v.Expr.(*sql.All); ok { - doc = d.current - break - } - } - - if doc == nil { - doc = data.New() - } - - for _, e := range stm.Expr { - - switch v := e.Expr.(type) { - case *sql.All: - break - default: - - // If the query has a GROUP BY expression - // then let's check if this is an aggregate - // function, and if it is then pass the - // first argument directly through. - - if len(stm.Group) > 0 { - if f, ok := e.Expr.(*sql.FuncExpression); ok && f.Aggr { - v, err := d.i.e.fetch(ctx, f.Args[0], d.current) - if err != nil { - return nil, err - } - doc.Set(v, f.String()) - continue - } - } - - // Otherwise treat the field normally, and - // calculate the value to be inserted into - // the final output document. - - v, err := d.i.e.fetch(ctx, v, d.current) - if err != nil { - return nil, err - } - - switch v { - case d.current: - doc.Set(nil, e.Field) - default: - doc.Set(v, e.Field) - } - - } - - } - - return doc.Data(), nil - - default: + if len(exps) == 0 { switch output { default: return nil, nil case sql.DIFF: - return d.diff().Data(), nil + + old, err := d.cold(ctx) + if err != nil { + return nil, err + } + + now, err := d.cnow(ctx) + if err != nil { + return nil, err + } + + return d.diffs(old, now).Data(), nil + case sql.AFTER: - return d.current.Data(), nil + + doc, err := d.cnow(ctx) + if err != nil { + return nil, err + } + return doc.Data(), nil + case sql.BEFORE: - return d.initial.Data(), nil + + doc, err := d.cold(ctx) + if err != nil { + return nil, err + } + return doc.Data(), nil + case sql.BOTH: + + old, err := d.cold(ctx) + if err != nil { + return nil, err + } + + now, err := d.cnow(ctx) + if err != nil { + return nil, err + } + return map[string]interface{}{ - "after": d.current.Data(), - "before": d.initial.Data(), + "after": now.Data(), + "before": old.Data(), }, nil + } } + // But if there are field expresions + // then this query is a LIVE or SELECT + // query, and we must output only the + // desired fields in the output. + + var out = data.New() + + doc, err := d.cnow(ctx) + if err != nil { + return nil, err + } + + for _, e := range exps { + if _, ok := e.Expr.(*sql.All); ok { + out = doc + break + } + } + + for _, e := range exps { + + switch v := e.Expr.(type) { + case *sql.All: + break + default: + + // If the query has a GROUP BY expression + // then let's check if this is an aggregate + // function, and if it is then pass the + // first argument directly through. + + if len(grps) > 0 { + if f, ok := e.Expr.(*sql.FuncExpression); ok && f.Aggr { + v, err := d.i.e.fetch(ctx, f.Args[0], doc) + if err != nil { + return nil, err + } + out.Set(v, f.String()) + continue + } + } + + // Otherwise treat the field normally, and + // calculate the value to be inserted into + // the final output document. + + v, err := d.i.e.fetch(ctx, v, doc) + if err != nil { + return nil, err + } + + switch v { + case doc: + out.Set(nil, e.Field) + default: + out.Set(v, e.Field) + } + + } + + } + + return out.Data(), nil + }