From 1f30035899af8296f403d66f9341d7102556dacf Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Sat, 21 Apr 2018 11:32:13 +0100 Subject: [PATCH] Add SQL FETCH functionality to SELECT statements --- db/fetch.go | 42 ++++++++++++++++++---------------- db/select_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++ db/yield.go | 46 ++++++++++++++++++++++++++++++++++--- sql/ast.go | 9 ++++++++ sql/cork.go | 22 ++++++++++++++++++ sql/select.go | 50 ++++++++++++++++++++++++++++++++++++++++ sql/string.go | 26 ++++++++++++++++++++- sql/tokens.go | 2 ++ util/data/data.go | 25 +++++++++++--------- util/data/data_test.go | 2 +- 10 files changed, 240 insertions(+), 36 deletions(-) diff --git a/db/fetch.go b/db/fetch.go index e831a7c1..fe0b95f1 100644 --- a/db/fetch.go +++ b/db/fetch.go @@ -84,17 +84,18 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o return val, queryIdentFailed case doc != nil: - doc.Fetch(func(key string, val interface{}) interface{} { - switch res := val.(type) { - case []interface{}: - val, _ = e.fetchArray(ctx, res, doc) - return val - case *sql.Thing: - val, _ = e.fetchThing(ctx, res, doc) - return val - default: - return val + doc.Fetch(func(key string, val interface{}, path []string) interface{} { + if len(path) > 0 { + switch res := val.(type) { + case []interface{}: + val, _ = e.fetchArray(ctx, res, doc) + return val + case *sql.Thing: + val, _ = e.fetchThing(ctx, res, doc) + return val + } } + return val }) return e.fetch(ctx, doc.Get(val.ID).Data(), doc) @@ -107,17 +108,18 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o if obj, ok := ctx.Value(s).(*data.Doc); ok { - obj.Fetch(func(key string, val interface{}) interface{} { - switch res := val.(type) { - case []interface{}: - val, _ = e.fetchArray(ctx, res, doc) - return val - case *sql.Thing: - val, _ = e.fetchThing(ctx, res, doc) - return val - default: - return val + obj.Fetch(func(key string, val interface{}, path []string) interface{} { + if len(path) > 0 { + switch res := val.(type) { + case []interface{}: + val, _ = e.fetchArray(ctx, res, doc) + return val + case *sql.Thing: + val, _ = e.fetchThing(ctx, res, doc) + return val + } } + return val }) if res := obj.Get(val.ID).Data(); res != nil { diff --git a/db/select_test.go b/db/select_test.go index a346cfb5..c514b5d6 100644 --- a/db/select_test.go +++ b/db/select_test.go @@ -1766,6 +1766,58 @@ func TestSelect(t *testing.T) { }) + Convey("Fetch records using a fetchplan to fetch remote records easily", t, func() { + + setupDB() + + txt := ` + USE NS test DB test; + CREATE person:test SET + one=tester:test, + mult=[], + mult+=temper:one, + mult+=temper:two, + mult+=temper:tre + ; + CREATE tester:test SET tags=["some","tags"]; + CREATE temper:one SET tester=tester:test; + CREATE temper:two SET tester=tester:test; + CREATE temper:tre SET tester=tester:test; + SELECT * FROM person:test FETCH one, mult, mult.*.tester; + ` + + res, err := Execute(setupKV(), txt, nil) + So(err, ShouldBeNil) + So(res, ShouldHaveLength, 7) + So(res[1].Result, ShouldHaveLength, 1) + So(res[2].Result, ShouldHaveLength, 1) + So(res[3].Result, ShouldHaveLength, 1) + So(res[4].Result, ShouldHaveLength, 1) + So(res[5].Result, ShouldHaveLength, 1) + So(res[6].Result, ShouldHaveLength, 1) + So(data.Consume(res[6].Result[0]).Get("meta.id").Data(), ShouldEqual, "test") + So(data.Consume(res[6].Result[0]).Get("meta.tb").Data(), ShouldEqual, "person") + So(data.Consume(res[6].Result[0]).Get("one.meta.id").Data(), ShouldEqual, "test") + So(data.Consume(res[6].Result[0]).Get("one.meta.tb").Data(), ShouldEqual, "tester") + So(data.Consume(res[6].Result[0]).Get("one.tags").Data(), ShouldResemble, []interface{}{"some", "tags"}) + So(data.Consume(res[6].Result[0]).Get("mult[0].meta.id").Data(), ShouldEqual, "one") + So(data.Consume(res[6].Result[0]).Get("mult[0].meta.tb").Data(), ShouldEqual, "temper") + So(data.Consume(res[6].Result[0]).Get("mult[0].tester.meta.id").Data(), ShouldEqual, "test") + So(data.Consume(res[6].Result[0]).Get("mult[0].tester.meta.tb").Data(), ShouldEqual, "tester") + So(data.Consume(res[6].Result[0]).Get("mult[0].tester.tags").Data(), ShouldResemble, []interface{}{"some", "tags"}) + So(data.Consume(res[6].Result[0]).Get("mult[1].meta.id").Data(), ShouldEqual, "two") + So(data.Consume(res[6].Result[0]).Get("mult[1].meta.tb").Data(), ShouldEqual, "temper") + So(data.Consume(res[6].Result[0]).Get("mult[1].tester.meta.id").Data(), ShouldEqual, "test") + So(data.Consume(res[6].Result[0]).Get("mult[1].tester.meta.tb").Data(), ShouldEqual, "tester") + So(data.Consume(res[6].Result[0]).Get("mult[1].tester.tags").Data(), ShouldResemble, []interface{}{"some", "tags"}) + So(data.Consume(res[6].Result[0]).Get("mult[2].meta.id").Data(), ShouldEqual, "tre") + So(data.Consume(res[6].Result[0]).Get("mult[2].meta.tb").Data(), ShouldEqual, "temper") + So(data.Consume(res[6].Result[0]).Get("mult[2].tester.meta.id").Data(), ShouldEqual, "test") + So(data.Consume(res[6].Result[0]).Get("mult[2].tester.meta.tb").Data(), ShouldEqual, "tester") + So(data.Consume(res[6].Result[0]).Get("mult[2].tester.tags").Data(), ShouldResemble, []interface{}{"some", "tags"}) + + }) + Convey("Version records using a datetime", t, func() { setupDB() diff --git a/db/yield.go b/db/yield.go index 9f922aef..f735a58f 100644 --- a/db/yield.go +++ b/db/yield.go @@ -88,6 +88,7 @@ func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Toke var exps sql.Fields var grps sql.Groups + var fchs sql.Fetchs switch stm := stm.(type) { case *sql.LiveStatement: @@ -95,6 +96,7 @@ func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Toke case *sql.SelectStatement: exps = stm.Expr grps = stm.Group + fchs = stm.Fetch } // If there are no field expressions @@ -170,6 +172,10 @@ func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Toke return nil, err } + // First of all, check to see if an ALL + // expression has been specified, and if + // it has then use the full document. + for _, e := range exps { if _, ok := e.Expr.(*sql.All); ok { out = doc @@ -177,6 +183,10 @@ func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Toke } } + // Next let's see the field expressions + // which have been requested, and add + // these to the output document. + for _, e := range exps { switch v := e.Expr.(type) { @@ -204,22 +214,52 @@ func (d *document) yield(ctx context.Context, stm sql.Statement, output sql.Toke // calculate the value to be inserted into // the final output document. - v, err := d.i.e.fetch(ctx, v, doc) + o, err := d.i.e.fetch(ctx, v, doc) if err != nil { return nil, err } - switch v { + switch o { case doc: out.Set(nil, e.Field) default: - out.Set(v, e.Field) + out.Set(o, e.Field) } } } + // Finally let's see if there are any + // FETCH expressions, so that we can + // follow links to other records. + + for _, e := range fchs { + + switch v := e.Expr.(type) { + case *sql.All: + break + case *sql.Ident: + + out.Walk(func(key string, val interface{}) error { + + switch res := val.(type) { + case []interface{}: + val, _ = d.i.e.fetchArray(ctx, res, doc) + out.Set(val, key) + case *sql.Thing: + val, _ = d.i.e.fetchThing(ctx, res, doc) + out.Set(val, key) + } + + return nil + + }, v.ID) + + } + + } + return out.Data(), nil } diff --git a/sql/ast.go b/sql/ast.go index 009944a8..488ad3bf 100644 --- a/sql/ast.go +++ b/sql/ast.go @@ -173,6 +173,7 @@ type SelectStatement struct { Order Orders `cork:"order" codec:"order"` Limit Expr `cork:"limit" codec:"limit"` Start Expr `cork:"start" codec:"start"` + Fetch Fetchs `cork:"fetch" codec:"fetch"` Version Expr `cork:"version" codec:"version"` Timeout time.Duration `cork:"timeout" codec:"timeout"` Parallel int `cork:"parallel" codec:"parallel"` @@ -524,6 +525,14 @@ type Order struct { // Orders represents multiple ORDER BY clauses. type Orders []*Order +// Fetch represents a FETCH AS clause. +type Fetch struct { + Expr Expr +} + +// Fetchs represents multiple FETCH AS clauses. +type Fetchs []*Fetch + // -------------------------------------------------- // Expressions // -------------------------------------------------- diff --git a/sql/cork.go b/sql/cork.go index 484b41c4..174abd7f 100644 --- a/sql/cork.go +++ b/sql/cork.go @@ -234,6 +234,28 @@ func (this *Order) UnmarshalCORK(r *cork.Reader) (err error) { return } +// -------------------------------------------------- +// FETCH +// -------------------------------------------------- + +func init() { + cork.Register(&Fetch{}) +} + +func (this *Fetch) ExtendCORK() byte { + return 0x11 +} + +func (this *Fetch) MarshalCORK(w *cork.Writer) (dst []byte, err error) { + w.EncodeAny(this.Expr) + return +} + +func (this *Fetch) UnmarshalCORK(r *cork.Reader) (err error) { + r.DecodeAny(&this.Expr) + return +} + // ################################################## // ################################################## // ################################################## diff --git a/sql/select.go b/sql/select.go index 56cc2e9a..32641d90 100644 --- a/sql/select.go +++ b/sql/select.go @@ -55,6 +55,10 @@ func (p *parser) parseSelectStatement() (stmt *SelectStatement, err error) { return nil, err } + if stmt.Fetch, err = p.parseFetch(); err != nil { + return nil, err + } + if stmt.Version, err = p.parseVersion(); err != nil { return nil, err } @@ -313,6 +317,52 @@ func (p *parser) parseStart() (Expr, error) { } +func (p *parser) parseFetch() (mul Fetchs, err error) { + + // The next token that we expect to see is a + // GROUP token, and if we don't find one then + // return nil, with no error. + + if _, _, exi := p.mightBe(FETCH); !exi { + return nil, nil + } + + for { + + var tok Token + var lit string + + one := &Fetch{} + + tok, lit, err = p.shouldBe(IDENT, EXPR) + if err != nil { + return nil, &ParseError{Found: lit, Expected: []string{"field name"}} + } + + one.Expr, err = p.declare(tok, lit) + if err != nil { + return nil, err + } + + // Append the single expression to the array + // of return statement expressions. + + mul = append(mul, one) + + // Check to see if the next token is a comma + // and if not, then break out of the loop, + // otherwise repeat until we find no comma. + + if _, _, exi := p.mightBe(COMMA); !exi { + break + } + + } + + return + +} + func (p *parser) parseVersion() (Expr, error) { if _, _, exi := p.mightBe(VERSION, ON); !exi { diff --git a/sql/string.go b/sql/string.go index 33ee5ef0..50f8bda7 100644 --- a/sql/string.go +++ b/sql/string.go @@ -214,7 +214,7 @@ func (this KillStatement) String() string { } func (this SelectStatement) String() string { - return print("SELECT %v FROM %v%v%v%v%v%v%v%v", + return print("SELECT %v FROM %v%v%v%v%v%v%v%v%v", this.Expr, this.What, maybe(this.Cond != nil, print(" WHERE %v", this.Cond)), @@ -222,6 +222,7 @@ func (this SelectStatement) String() string { this.Order, maybe(this.Limit != nil, print(" LIMIT %v", this.Limit)), maybe(this.Start != nil, print(" START %v", this.Start)), + this.Fetch, maybe(this.Version != nil, print(" VERSION %v", this.Version)), maybe(this.Timeout > 0, print(" TIMEOUT %v", this.Timeout.String())), ) @@ -526,6 +527,29 @@ func (this Order) String() string { ) } +// --------------------------------------------- +// Fetch +// --------------------------------------------- + +func (this Fetchs) String() string { + if len(this) == 0 { + return "" + } + m := make([]string, len(this)) + for k, v := range this { + m[k] = v.String() + } + return print(" FETCH %v", + strings.Join(m, ", "), + ) +} + +func (this Fetch) String() string { + return print("%v", + this.Expr, + ) +} + // --------------------------------------------- // Model // --------------------------------------------- diff --git a/sql/tokens.go b/sql/tokens.go index af6f4c80..d23f0cd3 100644 --- a/sql/tokens.go +++ b/sql/tokens.go @@ -133,6 +133,7 @@ const ( EVENT EXPUNGE FALSE + FETCH FIELD FOR FROM @@ -304,6 +305,7 @@ var tokens = [...]string{ EVENT: "EVENT", EXPUNGE: "EXPUNGE", FALSE: "FALSE", + FETCH: "FETCH", FIELD: "FIELD", FOR: "FOR", FROM: "FROM", diff --git a/util/data/data.go b/util/data/data.go index a67f3654..f9ed2457 100644 --- a/util/data/data.go +++ b/util/data/data.go @@ -43,7 +43,7 @@ type Doc struct { } // Fetcher is used when fetching values. -type Fetcher func(key string, val interface{}) interface{} +type Fetcher func(key string, val interface{}, path []string) interface{} // Iterator is used when iterating over items. type Iterator func(key string, val interface{}) error @@ -430,16 +430,16 @@ func (d *Doc) Exists(path ...string) bool { } if r == one { - if d.call != nil && len(path[k+1:]) > 0 { - c[0] = d.call(p, c[0]) + if d.call != nil { + c[0] = d.call(p, c[0], path[k+1:]) } return ConsumeWithFetch(c[0], d.call).Exists(path[k+1:]...) } if r == many { for _, v := range c { - if d.call != nil && len(path[k+1:]) > 0 { - v = d.call(p, v) + if d.call != nil { + v = d.call(p, v, path[k+1:]) } if !ConsumeWithFetch(v, d.call).Exists(path[k+1:]...) { return false @@ -494,8 +494,8 @@ func (d *Doc) Get(path ...string) *Doc { if m, ok := object.(map[string]interface{}); ok { switch p { default: - if d.call != nil && len(path[k+1:]) > 0 { - object = d.call(p, m[p]) + if d.call != nil { + object = d.call(p, m[p], path[k+1:]) } else { object = m[p] } @@ -518,8 +518,8 @@ func (d *Doc) Get(path ...string) *Doc { } if r == one { - if d.call != nil && len(path[k+1:]) > 0 { - c[0] = d.call(p, c[0]) + if d.call != nil { + c[0] = d.call(p, c[0], path[k+1:]) } return ConsumeWithFetch(c[0], d.call).Get(path[k+1:]...) } @@ -527,8 +527,8 @@ func (d *Doc) Get(path ...string) *Doc { if r == many { out := []interface{}{} for _, v := range c { - if d.call != nil && len(path[k+1:]) > 0 { - v = d.call(p, v) + if d.call != nil { + v = d.call(p, v, path[k+1:]) } res := ConsumeWithFetch(v, d.call).Get(path[k+1:]...) out = append(out, res.data) @@ -610,6 +610,7 @@ func (d *Doc) Set(value interface{}, path ...string) (*Doc, error) { if k == len(path)-1 { a[i[0]] = value object = a[i[0]] + continue } else { return ConsumeWithFetch(a[i[0]], d.call).Set(value, path[k+1:]...) } @@ -700,6 +701,7 @@ func (d *Doc) Del(path ...string) error { if r == one { if k == len(path)-1 { d.Set(c, path[:len(path)-1]...) + continue } else { if len(c) != 0 { return ConsumeWithFetch(c[0], d.call).Del(path[k+1:]...) @@ -710,6 +712,7 @@ func (d *Doc) Del(path ...string) error { if r == many { if k == len(path)-1 { d.Set(c, path[:len(path)-1]...) + continue } else { for _, v := range c { ConsumeWithFetch(v, d.call).Del(path[k+1:]...) diff --git a/util/data/data_test.go b/util/data/data_test.go index 2f070aa4..3bdee96e 100644 --- a/util/data/data_test.go +++ b/util/data/data_test.go @@ -303,7 +303,7 @@ func TestOperations(t *testing.T) { // ---------------------------------------------------------------------------------------------------- Convey("Can set fetcher function", t, func() { - doc.Fetch(func(key string, val interface{}) interface{} { + doc.Fetch(func(key string, val interface{}, path []string) interface{} { return val }) })