Enable permissions on individual document fields

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-14 19:14:47 +01:00
parent 2a74759a71
commit 67cfca04b9
15 changed files with 515 additions and 169 deletions

View file

@ -48,7 +48,7 @@ func (d *document) check(ctx context.Context, cond sql.Expr) (ok bool, err error
// Grant checks to see if the table permissions allow
// this record to be accessed for live queries, and
// if not then it errors accordingly.
func (d *document) grant(ctx context.Context, when method) (ok bool, err error) {
func (d *document) grant(ctx context.Context, met method) (ok bool, err error) {
var val interface{}
@ -86,16 +86,8 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error)
// for this table, then because this is
// a scoped request, return an error.
switch p := tb.Perms.(type) {
case *sql.PermExpression:
switch when {
case _CREATE:
val, err = d.i.e.fetch(ctx, p.Select, d.current)
case _UPDATE:
val, err = d.i.e.fetch(ctx, p.Select, d.current)
case _DELETE:
val, err = d.i.e.fetch(ctx, p.Select, d.initial)
}
if p, ok := tb.Perms.(*sql.PermExpression); ok {
val, err = d.i.e.fetch(ctx, p.Select, d.current)
}
// If the permissions expressions
@ -103,8 +95,8 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error)
// return this, dictating whether the
// document is able to be viewed.
if val, ok := val.(bool); ok {
return val, err
if v, ok := val.(bool); ok {
return v, err
}
// Otherwise as this request is scoped,
@ -118,7 +110,7 @@ func (d *document) grant(ctx context.Context, when method) (ok bool, err error)
// Query checks to see if the table permissions allow
// this record to be accessed for normal queries, and
// if not then it errors accordingly.
func (d *document) allow(ctx context.Context, when method) (ok bool, err error) {
func (d *document) allow(ctx context.Context, met method) (ok bool, err error) {
var val interface{}
@ -156,9 +148,8 @@ func (d *document) allow(ctx context.Context, when method) (ok bool, err error)
// for this table, then because this is
// a scoped request, return an error.
switch p := tb.Perms.(type) {
case *sql.PermExpression:
switch when {
if p, ok := tb.Perms.(*sql.PermExpression); ok {
switch met {
case _SELECT:
val, err = d.i.e.fetch(ctx, p.Select, d.current)
case _CREATE:
@ -175,8 +166,8 @@ func (d *document) allow(ctx context.Context, when method) (ok bool, err error)
// return this, dictating whether the
// document is able to be viewed.
if val, ok := val.(bool); ok {
return val, err
if v, ok := val.(bool); ok {
return v, err
}
// Otherwise as this request is scoped,

View file

@ -114,7 +114,7 @@ func (d *document) runCreate(ctx context.Context, stm *sql.CreateStatement) (int
return nil, &ExistError{exist: d.id}
}
if err = d.merge(ctx, stm.Data); err != nil {
if err = d.merge(ctx, met, stm.Data); err != nil {
return nil, err
}

View file

@ -212,6 +212,33 @@ func TestDefine(t *testing.T) {
})
Convey("Convert a schemaless to schemafull table, and ensure schemaless fields are still output", t, func() {
setupDB()
txt := `
USE NS test DB test;
DEFINE TABLE person SCHEMALESS;
UPDATE person:test SET test=true, other="text";
DEFINE TABLE person SCHEMAFULL;
DEFINE FIELD test ON person TYPE boolean;
SELECT * FROM person;
DEFINE FIELD other ON person TYPE string;
SELECT * FROM person;
`
res, err := Execute(setupKV(), txt, nil)
So(err, ShouldBeNil)
So(res, ShouldHaveLength, 8)
So(data.Consume(res[2].Result[0]).Get("test").Data(), ShouldEqual, true)
So(data.Consume(res[2].Result[0]).Get("other").Data(), ShouldEqual, "text")
So(data.Consume(res[5].Result[0]).Get("test").Data(), ShouldEqual, true)
So(data.Consume(res[5].Result[0]).Get("other").Data(), ShouldEqual, "text")
So(data.Consume(res[7].Result[0]).Get("test").Data(), ShouldEqual, true)
So(data.Consume(res[7].Result[0]).Get("other").Data(), ShouldEqual, "text")
})
Convey("Define a drop table", t, func() {
setupDB()
@ -372,6 +399,101 @@ func TestDefine(t *testing.T) {
})
Convey("Specify the permissions of a field so that it is only visible to the correct authentication levels", t, func() {
setupDB()
func() {
txt := `
USE NS test DB test;
DEFINE TABLE person PERMISSIONS FULL;
DEFINE FIELD name ON person PERMISSIONS FULL;
DEFINE FIELD pass ON person PERMISSIONS NONE;
DEFINE FIELD test ON person PERMISSIONS FOR CREATE, UPDATE FULL FOR SELECT NONE;
DEFINE FIELD temp ON person PERMISSIONS NONE;
DEFINE FIELD temp.test ON person PERMISSIONS FULL;
UPDATE person:test SET name="Tobias", pass="qhmyjahdc4", test="k5n87urq8l", temp.test="zw3wf5ls39";
SELECT * FROM person;
`
res, err := Execute(setupKV(), txt, nil)
So(err, ShouldBeNil)
So(res, ShouldHaveLength, 9)
So(res[7].Result, ShouldHaveLength, 1)
So(data.Consume(res[7].Result[0]).Get("name").Data(), ShouldEqual, "Tobias")
So(data.Consume(res[7].Result[0]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4")
So(data.Consume(res[7].Result[0]).Get("test").Data(), ShouldEqual, "k5n87urq8l")
So(data.Consume(res[7].Result[0]).Get("temp.test").Data(), ShouldEqual, "zw3wf5ls39")
So(res[8].Result, ShouldHaveLength, 1)
So(data.Consume(res[8].Result[0]).Get("name").Data(), ShouldEqual, "Tobias")
So(data.Consume(res[8].Result[0]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4")
So(data.Consume(res[8].Result[0]).Get("test").Data(), ShouldEqual, "k5n87urq8l")
So(data.Consume(res[8].Result[0]).Get("temp.test").Data(), ShouldEqual, "zw3wf5ls39")
}()
func() {
txt := `
USE NS test DB test;
CREATE person:1 SET name="Silvana", pass="1f65flhfvq", test="35aptguqoj", temp.test="h08ryx3519";
UPDATE person:2 SET name="Jonathan", pass="8k796m5mmj", test="1lzdhd6wzg", temp.test="xurnxp8a1e";
SELECT * FROM person ORDER BY name;
`
res, err := Execute(setupSC(), txt, nil)
So(err, ShouldBeNil)
So(res, ShouldHaveLength, 4)
So(res[1].Result, ShouldHaveLength, 1)
So(data.Consume(res[1].Result[0]).Get("name").Data(), ShouldEqual, "Silvana")
So(data.Consume(res[1].Result[0]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[1].Result[0]).Get("test").Data(), ShouldEqual, nil)
So(data.Consume(res[1].Result[0]).Get("temp.test").Data(), ShouldEqual, nil)
So(res[2].Result, ShouldHaveLength, 1)
So(data.Consume(res[2].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan")
So(data.Consume(res[2].Result[0]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[2].Result[0]).Get("test").Data(), ShouldEqual, nil)
So(data.Consume(res[2].Result[0]).Get("temp.test").Data(), ShouldEqual, nil)
So(res[3].Result, ShouldHaveLength, 3)
So(data.Consume(res[3].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan")
So(data.Consume(res[3].Result[0]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[0]).Get("test").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[1]).Get("name").Data(), ShouldEqual, "Silvana")
So(data.Consume(res[3].Result[1]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[1]).Get("test").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[2]).Get("name").Data(), ShouldEqual, "Tobias")
So(data.Consume(res[3].Result[2]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[2]).Get("test").Data(), ShouldEqual, nil)
So(data.Consume(res[3].Result[2]).Get("temp.test").Data(), ShouldEqual, nil)
}()
func() {
txt := `
USE NS test DB test;
SELECT * FROM person ORDER BY name;
`
res, err := Execute(setupKV(), txt, nil)
So(err, ShouldBeNil)
So(res, ShouldHaveLength, 2)
So(res[1].Result, ShouldHaveLength, 3)
So(data.Consume(res[1].Result[0]).Get("name").Data(), ShouldEqual, "Jonathan")
So(data.Consume(res[1].Result[0]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[1].Result[0]).Get("test").Data(), ShouldEqual, "1lzdhd6wzg")
So(data.Consume(res[1].Result[1]).Get("name").Data(), ShouldEqual, "Silvana")
So(data.Consume(res[1].Result[1]).Get("pass").Data(), ShouldEqual, nil)
So(data.Consume(res[1].Result[1]).Get("test").Data(), ShouldEqual, "35aptguqoj")
So(data.Consume(res[1].Result[2]).Get("name").Data(), ShouldEqual, "Tobias")
So(data.Consume(res[1].Result[2]).Get("pass").Data(), ShouldEqual, "qhmyjahdc4")
So(data.Consume(res[1].Result[2]).Get("test").Data(), ShouldEqual, "k5n87urq8l")
}()
})
Convey("Define an event when a value changes", t, func() {
setupDB()

View file

@ -282,15 +282,6 @@ func (d *document) changed() bool {
return len(c) > 0
}
func (d *document) diff() *data.Doc {
a, _ := d.initial.Data().(map[string]interface{})
b, _ := d.current.Data().(map[string]interface{})
if c := diff.Diff(a, b); len(c) > 0 {
return data.Consume(c)
}
return data.Consume(nil)
}
func (d *document) shouldDrop() (bool, error) {
// Check whether it is specified

View file

@ -23,7 +23,7 @@ import (
// Event checks if any triggers are specified for this
// table, and executes them in name order.
func (d *document) event(ctx context.Context, when method) (err error) {
func (d *document) event(ctx context.Context, met method) (err error) {
// Get the event values specified
// for this table, loop through
@ -38,7 +38,7 @@ func (d *document) event(ctx context.Context, when method) (err error) {
kind := ""
switch when {
switch met {
case _CREATE:
kind = "CREATE"
case _UPDATE:

View file

@ -92,7 +92,7 @@ func (d *document) runInsert(ctx context.Context, stm *sql.InsertStatement) (int
return nil, &ExistError{exist: d.id}
}
if err = d.merge(ctx, nil); err != nil {
if err = d.merge(ctx, met, nil); err != nil {
return nil, err
}

View file

@ -18,7 +18,6 @@ import (
"context"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/data"
)
// Lives checks if any table views are specified for
@ -54,7 +53,7 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
var ok bool
var con *socket
var doc *data.Doc
var out interface{}
if con, ok = sockets[lv.FB]; ok {
@ -98,7 +97,7 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
case true:
doc = d.diff()
out, _ = d.yield(ctx, lv, sql.DIFF)
// If the query has projected fields which it
// wants to receive, then let's fetch these
@ -106,39 +105,21 @@ func (d *document) lives(ctx context.Context, when method) (err error) {
case false:
for _, v := range lv.Expr {
if _, ok := v.Expr.(*sql.All); ok {
doc = d.current
break
}
}
if doc == nil {
doc = data.New()
}
for _, e := range lv.Expr {
switch v := e.Expr.(type) {
case *sql.All:
break
default:
v, err := d.i.e.fetch(ctx, v, d.current)
if err != nil {
continue
}
doc.Set(v, e.Field)
}
}
out, _ = d.yield(ctx, lv, sql.ILLEGAL)
}
switch when {
case _CREATE:
con.queue(id, lv.ID, "CREATE", doc.Data())
case _UPDATE:
con.queue(id, lv.ID, "UPDATE", doc.Data())
case _DELETE:
con.queue(id, lv.ID, "DELETE", d.id)
case _CREATE:
if out != nil {
con.queue(id, lv.ID, "CREATE", out)
}
case _UPDATE:
if out != nil {
con.queue(id, lv.ID, "UPDATE", out)
}
}
}

View file

@ -19,6 +19,7 @@ import (
"context"
"github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/conv"
"github.com/abcum/surreal/util/data"
@ -32,44 +33,44 @@ var main = map[string]struct{}{
"meta.id": {},
}
func (d *document) merge(ctx context.Context, data sql.Expr) (err error) {
func (d *document) merge(ctx context.Context, met method, data sql.Expr) (err error) {
if err = d.defFld(ctx); err != nil {
if err = d.defFld(ctx, met); err != nil {
return
}
switch expr := data.(type) {
case *sql.DataExpression:
if err = d.mrgSet(ctx, expr); err != nil {
if err = d.mrgSet(ctx, met, expr); err != nil {
return err
}
case *sql.DiffExpression:
if err = d.mrgDpm(ctx, expr); err != nil {
if err = d.mrgDpm(ctx, met, expr); err != nil {
return err
}
case *sql.MergeExpression:
if err = d.mrgAny(ctx, expr); err != nil {
if err = d.mrgAny(ctx, met, expr); err != nil {
return err
}
case *sql.ContentExpression:
if err = d.mrgAll(ctx, expr); err != nil {
if err = d.mrgAll(ctx, met, expr); err != nil {
return err
}
}
if err = d.defFld(ctx); err != nil {
if err = d.defFld(ctx, met); err != nil {
return
}
if err = d.mrgFld(ctx); err != nil {
if err = d.mrgFld(ctx, met); err != nil {
return
}
if err = d.defFld(ctx); err != nil {
if err = d.defFld(ctx, met); err != nil {
return
}
if err = d.delFld(ctx); err != nil {
if err = d.delFld(ctx, met); err != nil {
return
}
@ -77,7 +78,7 @@ func (d *document) merge(ctx context.Context, data sql.Expr) (err error) {
}
func (d *document) defFld(ctx context.Context) (err error) {
func (d *document) defFld(ctx context.Context, met method) (err error) {
d.current.Set(d.id, "id")
d.current.Set(d.md, "meta")
@ -86,7 +87,7 @@ func (d *document) defFld(ctx context.Context) (err error) {
}
func (d *document) delFld(ctx context.Context) (err error) {
func (d *document) delFld(ctx context.Context, met method) (err error) {
tb, err := d.getTB()
if err != nil {
@ -130,7 +131,7 @@ func (d *document) delFld(ctx context.Context) (err error) {
}
func (d *document) mrgAll(ctx context.Context, expr *sql.ContentExpression) (err error) {
func (d *document) mrgAll(ctx context.Context, met method, expr *sql.ContentExpression) (err error) {
var obj map[string]interface{}
@ -161,7 +162,7 @@ func (d *document) mrgAll(ctx context.Context, expr *sql.ContentExpression) (err
}
func (d *document) mrgAny(ctx context.Context, expr *sql.MergeExpression) (err error) {
func (d *document) mrgAny(ctx context.Context, met method, expr *sql.MergeExpression) (err error) {
var obj map[string]interface{}
@ -190,7 +191,7 @@ func (d *document) mrgAny(ctx context.Context, expr *sql.MergeExpression) (err e
}
func (d *document) mrgDpm(ctx context.Context, expr *sql.DiffExpression) (err error) {
func (d *document) mrgDpm(ctx context.Context, met method, expr *sql.DiffExpression) (err error) {
var obj []interface{}
var old map[string]interface{}
@ -222,7 +223,7 @@ func (d *document) mrgDpm(ctx context.Context, expr *sql.DiffExpression) (err er
}
func (d *document) mrgSet(ctx context.Context, expr *sql.DataExpression) (err error) {
func (d *document) mrgSet(ctx context.Context, met method, expr *sql.DataExpression) (err error) {
for _, v := range expr.Data {
@ -255,7 +256,7 @@ func (d *document) mrgSet(ctx context.Context, expr *sql.DataExpression) (err er
}
func (d *document) mrgFld(ctx context.Context) (err error) {
func (d *document) mrgFld(ctx context.Context, met method) (err error) {
fds, err := d.getFD()
if err != nil {
@ -306,40 +307,95 @@ func (d *document) mrgFld(ctx context.Context) (err error) {
}
}
// Reset the variables
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
// We are setting the value of the field
if fd.Value != nil {
// Reset the variables
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
if now, err := d.i.e.fetch(ctx, fd.Value, d.current); err != nil {
return err
} else {
val = now
}
}
// Reset the variables
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
// We are checking the value of the field
if fd.Assert != nil {
// Reset the variables
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
if chk, err := d.i.e.fetch(ctx, fd.Assert, d.current); err != nil {
return err
} else if chk, ok := chk.(bool); ok && !chk {
return &FieldError{field: key, found: val, check: fd.Assert}
}
}
// We are checking the permissions of the field
if fd.Perms != nil {
if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok {
if k > cnf.AuthDB {
// Reset the variables
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
switch p := fd.Perms.(type) {
case *sql.PermExpression:
switch met {
case _CREATE:
if v, err := d.i.e.fetch(ctx, p.Create, d.current); err != nil {
return err
} else {
if b, ok := v.(bool); !ok || !b {
val = old
}
}
case _UPDATE:
if v, err := d.i.e.fetch(ctx, p.Update, d.current); err != nil {
return err
} else {
if b, ok := v.(bool); !ok || !b {
val = old
}
}
case _DELETE:
if v, err := d.i.e.fetch(ctx, p.Delete, d.current); err != nil {
return err
} else {
if b, ok := v.(bool); !ok || !b {
val = old
}
}
}
}
}
}
}
// We are setting the value of the field
switch val.(type) {
default:
d.current.Iff(val, key)

90
db/perms.go Normal file
View file

@ -0,0 +1,90 @@
// Copyright © 2016 Abcum Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package db
import (
"context"
"github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/data"
)
func (d *document) perms(ctx context.Context, doc *data.Doc) (err error) {
// Get the field definitions so we can
// check if the permissions allow us
// to view each field.
fds, err := d.getFD()
if err != nil {
return err
}
// Once we have the table we reset the
// context to DB level so that no other
// embedded permissions are checked on
// records within these permissions.
ctx = context.WithValue(ctx, ctxKeyKind, cnf.AuthDB)
// We then try to process the relevant
// permissions dependent on the query
// that we are currently processing. If
// there are no permissions specified
// for this table, then because this is
// a scoped request, return an error.
for _, fd := range fds {
if fd.Perms != nil {
err = doc.Walk(func(key string, val interface{}) error {
// We are checking the permissions of the field
if p, ok := fd.Perms.(*sql.PermExpression); ok {
// Get the old value
old := d.initial.Get(key).Data()
// Reset the variables
vars := data.New()
vars.Set(val, varKeyValue)
vars.Set(val, varKeyAfter)
vars.Set(old, varKeyBefore)
ctx = context.WithValue(ctx, ctxKeySpec, vars)
if v, err := d.i.e.fetch(ctx, p.Select, doc); err != nil {
return err
} else if b, ok := v.(bool); !ok || !b {
doc.Del(key)
}
}
return nil
}, fd.Name.ID)
}
}
return nil
}

View file

@ -106,7 +106,7 @@ func (d *document) runRelate(ctx context.Context, stm *sql.RelateStatement) (int
met = _UPDATE
}
if err = d.merge(ctx, stm.Data); err != nil {
if err = d.merge(ctx, met, stm.Data); err != nil {
return nil, err
}

View file

@ -137,6 +137,7 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int
var ok bool
var err error
var met = _SELECT
defer d.close()
@ -148,12 +149,10 @@ func (d *document) runSelect(ctx context.Context, stm *sql.SelectStatement) (int
return nil, nil
}
if d.doc == nil {
if ok, err = d.allow(ctx, _SELECT); err != nil {
return nil, err
} else if ok == false {
return nil, nil
}
if ok, err = d.allow(ctx, met); err != nil {
return nil, err
} else if ok == false {
return nil, nil
}
if ok, err = d.check(ctx, stm.Cond); err != nil {

View file

@ -19,7 +19,6 @@ import (
"context"
"github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/keys"
)
@ -135,7 +134,7 @@ func (d *document) table(ctx context.Context, when method) (err error) {
func (d *document) tableDelete(ctx context.Context, id *sql.Thing, exp sql.Fields) (err error) {
stm := &sql.DeleteStatement{
KV: cnf.Settings.DB.Base,
KV: d.key.KV,
NS: d.key.NS,
DB: d.key.DB,
What: sql.Exprs{id},
@ -163,7 +162,7 @@ func (d *document) tableUpdate(ctx context.Context, id *sql.Thing, exp sql.Field
}
stm := &sql.UpdateStatement{
KV: cnf.Settings.DB.Base,
KV: d.key.KV,
NS: d.key.NS,
DB: d.key.DB,
What: sql.Exprs{id},

View file

@ -125,7 +125,7 @@ func (d *document) runUpdate(ctx context.Context, stm *sql.UpdateStatement) (int
return nil, nil
}
if err = d.merge(ctx, stm.Data); err != nil {
if err = d.merge(ctx, met, stm.Data); err != nil {
return nil, err
}

View file

@ -98,7 +98,7 @@ func (d *document) runUpsert(ctx context.Context, stm *sql.UpsertStatement) (int
return nil, nil
}
if err = d.merge(ctx, nil); err != nil {
if err = d.merge(ctx, met, nil); err != nil {
return nil, err
}

View file

@ -17,92 +17,209 @@ package db
import (
"context"
"github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/data"
"github.com/abcum/surreal/util/diff"
)
func (d *document) cold(ctx context.Context) (doc *data.Doc, err error) {
// If we are authenticated using DB, NS,
// or KV permissions level, then we can
// return the document without copying.
if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok {
if k < cnf.AuthSC {
return d.initial, nil
}
}
// Otherwise, we need to create a copy
// of the document so that we can add
// and remove fields before outputting.
doc = d.initial.Copy()
err = d.perms(ctx, doc)
return
}
func (d *document) cnow(ctx context.Context) (doc *data.Doc, err error) {
// If we are authenticated using DB, NS,
// or KV permissions level, then we can
// return the document without copying.
if k, ok := ctx.Value(ctxKeyKind).(cnf.Kind); ok {
if k < cnf.AuthSC {
return d.current, nil
}
}
// Otherwise, we need to create a copy
// of the document so that we can add
// and remove fields before outputting.
doc = d.current.Copy()
err = d.perms(ctx, doc)
return
}
func (d *document) diffs(initial, current *data.Doc) *data.Doc {
a, _ := initial.Data().(map[string]interface{})
b, _ := current.Data().(map[string]interface{})
if c := diff.Diff(a, b); len(c) > 0 {
return data.Consume(c)
}
return data.Consume(nil)
}
func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Token) (interface{}, error) {
var exps sql.Fields
var grps sql.Groups
switch stm := stm.(type) {
case *sql.LiveStatement:
exps = stm.Expr
case *sql.SelectStatement:
exps = stm.Expr
grps = stm.Group
}
var doc *data.Doc
// If there are no field expressions
// then this was not a LIVE or SELECT
// query, and therefore the query will
// have an output format specified.
for _, v := range stm.Expr {
if _, ok := v.Expr.(*sql.All); ok {
doc = d.current
break
}
}
if doc == nil {
doc = data.New()
}
for _, e := range stm.Expr {
switch v := e.Expr.(type) {
case *sql.All:
break
default:
// If the query has a GROUP BY expression
// then let's check if this is an aggregate
// function, and if it is then pass the
// first argument directly through.
if len(stm.Group) > 0 {
if f, ok := e.Expr.(*sql.FuncExpression); ok && f.Aggr {
v, err := d.i.e.fetch(ctx, f.Args[0], d.current)
if err != nil {
return nil, err
}
doc.Set(v, f.String())
continue
}
}
// Otherwise treat the field normally, and
// calculate the value to be inserted into
// the final output document.
v, err := d.i.e.fetch(ctx, v, d.current)
if err != nil {
return nil, err
}
switch v {
case d.current:
doc.Set(nil, e.Field)
default:
doc.Set(v, e.Field)
}
}
}
return doc.Data(), nil
default:
if len(exps) == 0 {
switch output {
default:
return nil, nil
case sql.DIFF:
return d.diff().Data(), nil
old, err := d.cold(ctx)
if err != nil {
return nil, err
}
now, err := d.cnow(ctx)
if err != nil {
return nil, err
}
return d.diffs(old, now).Data(), nil
case sql.AFTER:
return d.current.Data(), nil
doc, err := d.cnow(ctx)
if err != nil {
return nil, err
}
return doc.Data(), nil
case sql.BEFORE:
return d.initial.Data(), nil
doc, err := d.cold(ctx)
if err != nil {
return nil, err
}
return doc.Data(), nil
case sql.BOTH:
old, err := d.cold(ctx)
if err != nil {
return nil, err
}
now, err := d.cnow(ctx)
if err != nil {
return nil, err
}
return map[string]interface{}{
"after": d.current.Data(),
"before": d.initial.Data(),
"after": now.Data(),
"before": old.Data(),
}, nil
}
}
// But if there are field expresions
// then this query is a LIVE or SELECT
// query, and we must output only the
// desired fields in the output.
var out = data.New()
doc, err := d.cnow(ctx)
if err != nil {
return nil, err
}
for _, e := range exps {
if _, ok := e.Expr.(*sql.All); ok {
out = doc
break
}
}
for _, e := range exps {
switch v := e.Expr.(type) {
case *sql.All:
break
default:
// If the query has a GROUP BY expression
// then let's check if this is an aggregate
// function, and if it is then pass the
// first argument directly through.
if len(grps) > 0 {
if f, ok := e.Expr.(*sql.FuncExpression); ok && f.Aggr {
v, err := d.i.e.fetch(ctx, f.Args[0], doc)
if err != nil {
return nil, err
}
out.Set(v, f.String())
continue
}
}
// Otherwise treat the field normally, and
// calculate the value to be inserted into
// the final output document.
v, err := d.i.e.fetch(ctx, v, doc)
if err != nil {
return nil, err
}
switch v {
case doc:
out.Set(nil, e.Field)
default:
out.Set(v, e.Field)
}
}
}
return out.Data(), nil
}