surrealpatch/db/iterator.go
2019-06-14 18:33:41 +01:00

1052 lines
20 KiB
Go

// Copyright © 2016 Abcum Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package db
import (
"fmt"
"math"
"sort"
"sync"
"context"
"github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/kvs"
"github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/comp"
"github.com/abcum/surreal/util/data"
"github.com/abcum/surreal/util/fncs"
"github.com/abcum/surreal/util/guid"
"github.com/abcum/surreal/util/ints"
"github.com/abcum/surreal/util/keys"
"github.com/abcum/surreal/util/nums"
"github.com/abcum/surreal/util/rand"
)
// iterator drives the execution of a single statement over its targets.
// Work items are handed to worker goroutines through jobs; their results
// come back through vals and are collected into res by a single receiver
// goroutine, and Yield post-processes them (SPLIT, GROUP, ORDER, START,
// LIMIT). Iterators are taken from and returned to iteratorPool.
type iterator struct {
	// e is the executor running the statement (used for dbo and fetch).
	e *executor
	// id is a random identifier assigned in newIterator; it is not read
	// in the code visible here — NOTE(review): confirm where it is used.
	id int
	// err is the first setup, permission or storage error; Yield returns it.
	err error
	// vir lets statements modify locked (view) tables; see processPerms.
	vir bool
	// stm is the statement being executed.
	stm sql.Statement
	// res collects the non-nil results received from workers.
	res []interface{}
	// wait counts jobs delivered but not yet received.
	wait sync.WaitGroup
	// fail carries the first worker error (buffered, capacity 1).
	fail chan error
	// stop is closed to signal that no further work should be scheduled.
	stop chan struct{}
	// jobs carries work items to the worker goroutines.
	jobs chan *workable
	// vals carries worker results to the receiver goroutine.
	vals chan *doneable
	// expr holds the SELECT fields, used by Group to build output records.
	expr sql.Fields
	// what holds the statement targets (or the INSERT/UPSERT data).
	what sql.Exprs
	// cond holds the statement's Cond expression, when it has one.
	cond sql.Expr
	// split, group and order hold the SELECT SPLIT, GROUP BY and ORDER BY
	// clauses applied by Yield.
	split sql.Idents
	group sql.Groups
	order sql.Orders
	// limit and start are -1 when unset.
	limit int
	start int
	// versn is the version data is read at; setupState defaults it to
	// math.MaxInt64.
	versn int64
	// async selects workerCount workers instead of a single one.
	async bool
}
// workable is one unit of work handed to a worker goroutine by deliver.
// Callers in this file pass a record key alone, a raw key-value pair alone
// (table scans), an already-loaded document alone (subquery results), or a
// key together with a document (processArray).
type workable struct {
	key *keys.Thing
	val kvs.KV
	doc *data.Doc
}
// doneable is the outcome of processing one workable: the value returned
// by the document query and any error it produced.
type doneable struct {
	res interface{}
	err error
}
// groupable pairs a record with its evaluated GROUP BY values; ats[k]
// holds the value of i.group[k] for doc.
type groupable struct {
	doc *data.Doc
	ats []interface{}
}
// orderable pairs a record with its ORDER BY values. ats is filled lazily
// while sorting, so ats[x] (once present) caches the value of i.order[x].
type orderable struct {
	doc *data.Doc
	ats []interface{}
}
// newIterator takes an iterator from iteratorPool and prepares it to run
// stm. It resets the per-query result and error state, creates the
// channels used between the caller, the workers and the receiver, copies
// the relevant clauses out of the statement (setupState) and starts the
// goroutines (setupWorkers). vir is stored for processPerms, where it
// allows writes to locked tables.
func newIterator(e *executor, ctx context.Context, stm sql.Statement, vir bool) (i *iterator) {
	i = iteratorPool.Get().(*iterator)

	// Who is running what.
	i.e, i.stm, i.vir = e, stm, vir
	i.id = rand.Int()

	// Fresh result and error state for this query.
	i.err = nil
	i.res = make([]interface{}, 0)

	// Coordination: jobs and vals are buffered to workerCount, fail holds
	// at most one error, and stop is closed to halt iteration.
	i.wait = sync.WaitGroup{}
	i.fail = make(chan error, 1)
	i.stop = make(chan struct{})
	i.jobs = make(chan *workable, workerCount)
	i.vals = make(chan *doneable, workerCount)

	// Copy the statement clauses first; the worker count depends on them.
	i.setupState(ctx)
	i.setupWorkers(ctx)

	return i
}
// Close drops every reference the iterator holds, so a pooled iterator
// does not keep query data alive. It resets the clause state to its
// "unset" values and returns the iterator to iteratorPool. The iterator
// must not be used after Close.
func (i *iterator) Close() {
	// Executor, statement, results and coordination channels.
	i.e, i.stm, i.err, i.res = nil, nil, nil, nil
	i.fail, i.stop, i.jobs, i.vals = nil, nil, nil, nil

	// Clauses copied in by setupState.
	i.expr, i.what, i.cond = nil, nil, nil
	i.split, i.group, i.order = nil, nil, nil

	// -1 marks "no LIMIT" / "no START".
	i.limit, i.start = -1, -1
	i.versn = 0
	i.async = false

	iteratorPool.Put(i)
}
// setupState copies the clauses the iterator needs out of i.stm.
//
// Defaults: limit and start are -1 (unset), versn is math.MaxInt64 and
// async is false. For SELECT statements the LIMIT, START and VERSION
// expressions are also evaluated, since each may be a parameter. If any of
// them fails, the error is stored in i.err and i.stop is closed, so no
// work is scheduled.
func (i *iterator) setupState(ctx context.Context) {
	i.expr, i.what, i.cond = nil, nil, nil
	i.split, i.group, i.order = nil, nil, nil
	i.limit, i.start = -1, -1
	i.versn = math.MaxInt64
	// Reset async here as well, so statement types not listed below
	// always run on a single worker.
	i.async = false

	switch stm := i.stm.(type) {
	case *sql.SelectStatement:
		i.expr, i.what, i.cond = stm.Expr, stm.What, stm.Cond
		i.split, i.group, i.order = stm.Split, stm.Group, stm.Order
		i.async = stm.Parallel
		// LIMIT, START and VERSION are only evaluated for SELECT.
		if i.limit, i.err = i.e.fetchLimit(ctx, stm.Limit); i.err != nil {
			close(i.stop)
			return
		}
		if i.start, i.err = i.e.fetchStart(ctx, stm.Start); i.err != nil {
			close(i.stop)
			return
		}
		if i.versn, i.err = i.e.fetchVersion(ctx, stm.Version); i.err != nil {
			close(i.stop)
			return
		}
	case *sql.CreateStatement:
		i.what = stm.What
		i.async = stm.Parallel
	case *sql.UpdateStatement:
		i.what, i.cond = stm.What, stm.Cond
		i.async = stm.Parallel
	case *sql.DeleteStatement:
		i.what, i.cond = stm.What, stm.Cond
		i.async = stm.Parallel
	case *sql.InsertStatement:
		i.what = sql.Exprs{stm.Data}
		i.async = stm.Parallel
	case *sql.UpsertStatement:
		i.what = sql.Exprs{stm.Data}
		i.async = stm.Parallel
	}
}
// checkState reports whether the iterator should keep scheduling work.
// It returns false once ctx is done or i.stop has been closed, and true
// otherwise. It never blocks.
func (i *iterator) checkState(ctx context.Context) bool {
	select {
	case <-ctx.Done():
	case <-i.stop:
	default:
		return true
	}
	return false
}
// setupWorkers starts the goroutines that process delivered work:
//
//   - a single receiver that passes every result to i.receive, and
//   - one worker, or workerCount workers when i.async is set, each running
//     jobs through newDocument(...).query and sending back a doneable.
//
// Nothing is started if the iterator is already stopped or ctx is done.
// The goroutines exit when Yield closes i.jobs and i.vals.
func (i *iterator) setupWorkers(ctx context.Context) {
	if !i.checkState(ctx) {
		return
	}

	// Results are applied one at a time on this goroutine.
	go func(results <-chan *doneable) {
		for r := range results {
			i.receive(r)
		}
	}(i.vals)

	worker := func(jobs <-chan *workable, results chan<- *doneable) {
		for job := range jobs {
			res, err := newDocument(i, job.key, job.val, job.doc).query(ctx, i.stm)
			results <- &doneable{res: res, err: err}
		}
	}

	n := 1
	if i.async {
		n = workerCount
	}
	for ; n > 0; n-- {
		go worker(i.jobs, i.vals)
	}
}
// deliver schedules one unit of work for the workers. The wait group is
// incremented before the send, so Yield's Wait covers this job; the
// matching Done happens in receive. The send blocks while i.jobs is full.
func (i *iterator) deliver(key *keys.Thing, val kvs.KV, doc *data.Doc) {
	job := &workable{key: key, val: val, doc: doc}
	i.wait.Add(1)
	i.jobs <- job
}
// receive handles one worker result. It is called only from the receiver
// goroutine started in setupWorkers, and always marks the job as done on
// i.wait.
//
// The first worker error is queued on i.fail and stops the iterator; later
// errors, or errors arriving after the iterator was stopped, are dropped.
// Non-nil results are appended to i.res. When the statement has a LIMIT
// and no GROUP BY or ORDER BY, the iterator is stopped once START+LIMIT
// results have been collected (START rows are dropped later by Yield).
func (i *iterator) receive(val *doneable) {
	defer i.wait.Done()

	if val.err != nil {
		select {
		case <-i.stop:
			// Already stopping; keep any earlier error.
		default:
			i.fail <- val.err
			close(i.stop)
		}
		return
	}

	if val.res != nil {
		i.res = append(i.res, val.res)
	}

	// Without a LIMIT, or when GROUP BY / ORDER BY must see every record,
	// all data has to be loaded before stopping.
	if i.limit < 0 || len(i.group) > 0 || len(i.order) > 0 {
		return
	}

	need := i.limit
	if i.start > 0 {
		need += i.start
	}

	select {
	case <-i.stop:
	default:
		// Use >= rather than ==: with a target of 0 (LIMIT 0) the length
		// can step past the target on an append without ever equalling
		// it, which would scan the whole source.
		if len(i.res) >= need {
			close(i.stop)
		}
	}
}
// processPerms checks that the statement may run against table tbv in
// namespace nsv and database dbv. Any failure is stored in i.err and
// i.stop is closed, so callers see it through checkState.
//
// An empty tbv (as when processing subquery results) skips all checks.
// Below scope (cnf.AuthSC) authentication no permission expressions are
// evaluated. For non-SELECT statements the table is fetched with AddTB
// (presumably defining it if missing — confirm), so that writes to a
// locked (view) table are refused. At scope level the namespace, database
// and table must already exist, writes to locked tables are refused, and
// the table's permission expression for the statement type is evaluated.
func (i *iterator) processPerms(ctx context.Context, nsv, dbv, tbv string) {
	// Subqueries have no table, so there is nothing to check.
	if len(tbv) == 0 {
		return
	}

	var tb *sql.DefineTableStatement

	// Authenticated at DB, NS or KV level: permissions are ignored, but the
	// table must exist and must not be a locked view for writes.
	if perm(ctx) < cnf.AuthSC {
		// A SELECT cannot alter the table, so there is no need to fetch it.
		if _, ok := i.stm.(*sql.SelectStatement); ok {
			return
		}
		if tb, i.err = i.e.dbo.AddTB(ctx, nsv, dbv, tbv); i.err == nil {
			i.err = i.lockedTableError(tb)
		}
		if i.err != nil {
			close(i.stop)
		}
		return
	}

	// Scope authentication: the NS, DB and TB must all exist, otherwise
	// the request can not do anything.
	if _, i.err = i.e.dbo.GetNS(ctx, nsv); i.err != nil {
		close(i.stop)
		return
	}
	if _, i.err = i.e.dbo.GetDB(ctx, nsv, dbv); i.err != nil {
		close(i.stop)
		return
	}
	if tb, i.err = i.e.dbo.GetTB(ctx, nsv, dbv, tbv); i.err != nil {
		close(i.stop)
		return
	}

	if i.err = i.lockedTableError(tb); i.err != nil {
		close(i.stop)
		return
	}

	// Evaluate the table permission for this statement type. RELATE and
	// INSERT are checked against CREATE, and UPSERT against UPDATE. Any
	// other kind of permission definition is rejected with a PermsError.
	switch p := tb.Perms.(type) {
	case *sql.PermExpression:
		switch i.stm.(type) {
		case *sql.SelectStatement:
			i.err = i.e.fetchPerms(ctx, p.Select, tb.Name)
		case *sql.CreateStatement, *sql.RelateStatement, *sql.InsertStatement:
			i.err = i.e.fetchPerms(ctx, p.Create, tb.Name)
		case *sql.UpdateStatement, *sql.UpsertStatement:
			i.err = i.e.fetchPerms(ctx, p.Update, tb.Name)
		case *sql.DeleteStatement:
			i.err = i.e.fetchPerms(ctx, p.Delete, tb.Name)
		}
	default:
		i.err = &PermsError{table: tb.Name.VA}
	}

	if i.err != nil {
		close(i.stop)
	}
}

// lockedTableError returns a *TableError when tb is locked (defined as a
// view) and the statement would modify it, unless i.vir is set. It returns
// nil otherwise.
func (i *iterator) lockedTableError(tb *sql.DefineTableStatement) error {
	if !tb.Lock || i.vir {
		return nil
	}
	switch i.stm.(type) {
	case *sql.CreateStatement, *sql.UpdateStatement, *sql.DeleteStatement,
		*sql.RelateStatement, *sql.InsertStatement, *sql.UpsertStatement:
		return &TableError{table: tb.Name.VA}
	}
	return nil
}
// processThing checks permissions for the record's table and, if the
// iterator is still running, schedules that single record.
func (i *iterator) processThing(ctx context.Context, key *keys.Thing) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)
	if !i.checkState(ctx) {
		return
	}
	i.deliver(key, nil, nil)
}
// processTable checks permissions for the table and then scans all of its
// records. It reads them in batches of up to 10000 key-value pairs at
// version i.versn and delivers each pair to the workers. The scan ends when
// a batch comes back empty, the iterator is stopped, or ctx is done. A
// storage error is stored in i.err and stops the iterator.
func (i *iterator) processTable(ctx context.Context, key *keys.Table) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)

	// TODO use indexes to speed up queries
	// We need to make use of indexes here
	// so that the query speed is improved.
	// If an index exists with the correct
	// ORDER BY fields then iterate over
	// the IDs from the index.

	beg := &keys.Thing{KV: key.KV, NS: key.NS, DB: key.DB, TB: key.TB, ID: keys.Ignore}
	end := &keys.Thing{KV: key.KV, NS: key.NS, DB: key.DB, TB: key.TB, ID: keys.Suffix}
	from, to := beg.Encode(), end.Encode()

	for i.checkState(ctx) {
		var vals []kvs.KV
		vals, i.err = i.e.dbo.GetR(ctx, i.versn, from, to, 10000)
		if i.err != nil {
			close(i.stop)
			return
		}

		// Each follow-up range starts strictly after the last key already
		// seen (the 0x00 suffix below), so every returned pair is new and
		// only an empty batch means the table is exhausted. A batch
		// holding a single trailing record must still be processed.
		if len(vals) == 0 {
			return
		}

		for _, val := range vals {
			if !i.checkState(ctx) {
				return
			}
			i.deliver(nil, val, nil)
		}

		// Continue from just after the last key retrieved.
		beg.Decode(vals[len(vals)-1].Key())
		from = append(beg.Encode(), byte(0))
	}
}
// processBatch checks permissions for the batch's table and then schedules
// every record listed in qry.BA. Each one uses a copy of key with its TB
// and ID replaced by the entry's. Scheduling stops as soon as the iterator
// is stopped or ctx is done.
func (i *iterator) processBatch(ctx context.Context, key *keys.Thing, qry *sql.Batch) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)
	for n := 0; n < len(qry.BA) && i.checkState(ctx); n++ {
		rec := key.Copy()
		rec.TB, rec.ID = qry.BA[n].TB, qry.BA[n].ID
		i.deliver(rec, nil, nil)
	}
}
// processModel checks permissions for the model's table and schedules
// generated records, each using a copy of key with a new ID:
//
//   - INC == 0:  int(MAX) records, each with a fresh guid ID;
//   - MIN < MAX: numeric IDs from MIN up to MAX (inclusive) in steps of INC;
//   - MIN > MAX: numeric IDs from MIN down to MAX (inclusive) in steps of INC.
//
// The direction comes from comparing MIN and MAX, so only the magnitude of
// INC is used as the step. Each next ID is passed through nums.FormatPlaces
// with INC's decimal places (presumably to stop floating-point error from
// accumulating). Generation stops as soon as the iterator is stopped or
// ctx is done.
func (i *iterator) processModel(ctx context.Context, key *keys.Thing, qry *sql.Model) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)

	// A negative INC would otherwise walk away from MAX and never end.
	step := qry.INC
	if step < 0 {
		step = -step
	}

	switch {
	case step == 0:
		// No incrementing pattern: generate unique ids for each record.
		for j := 1; j <= int(qry.MAX); j++ {
			if !i.checkState(ctx) {
				return
			}
			rec := key.Copy()
			rec.ID = guid.New().String()
			i.deliver(rec, nil, nil)
		}
	case qry.MIN < qry.MAX:
		// Ascend through the steps sequentially.
		dec := nums.CountPlaces(step)
		for num := qry.MIN; num <= qry.MAX; num = nums.FormatPlaces(num+step, dec) {
			if !i.checkState(ctx) {
				return
			}
			rec := key.Copy()
			rec.ID = num
			i.deliver(rec, nil, nil)
		}
	case qry.MIN > qry.MAX:
		// Descend through the steps sequentially.
		dec := nums.CountPlaces(step)
		for num := qry.MIN; num >= qry.MAX; num = nums.FormatPlaces(num-step, dec) {
			if !i.checkState(ctx) {
				return
			}
			rec := key.Copy()
			rec.ID = num
			i.deliver(rec, nil, nil)
		}
	}
}
// processOther schedules the *sql.Thing values in val. Each one uses a copy
// of key with its TB and ID replaced by the Thing's. Any other value type is
// rejected: for CREATE, UPDATE, DELETE and RELATE an error naming val is
// recorded on i.err, and in every case the iterator is stopped. Processing
// also ends as soon as the iterator is stopped or ctx is done.
func (i *iterator) processOther(ctx context.Context, key *keys.Thing, val []interface{}) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)
	for _, v := range val {
		// Checking first also guarantees i.stop is still open before it is
		// closed below; closing it twice would panic.
		if !i.checkState(ctx) {
			return
		}
		if thg, ok := v.(*sql.Thing); ok {
			// A subquery which projected the ID only; load the record.
			rec := key.Copy()
			rec.TB, rec.ID = thg.TB, thg.ID
			i.deliver(rec, nil, nil)
			continue
		}
		// Record the error on i.err, like the other process* methods.
		// Sending on i.fail (capacity 1) could block forever if a worker
		// error were already queued there.
		switch i.stm.(type) {
		case *sql.CreateStatement:
			i.err = fmt.Errorf("Can not execute CREATE query using value '%v'", val)
		case *sql.UpdateStatement:
			i.err = fmt.Errorf("Can not execute UPDATE query using value '%v'", val)
		case *sql.DeleteStatement:
			i.err = fmt.Errorf("Can not execute DELETE query using value '%v'", val)
		case *sql.RelateStatement:
			i.err = fmt.Errorf("Can not execute RELATE query using value '%v'", val)
		}
		close(i.stop)
		return
	}
}
// processQuery schedules the results of a subquery. A *sql.Thing (the
// subquery projected only record IDs) is scheduled as that record, using a
// copy of key with its TB and ID replaced by the Thing's. Any other value is
// wrapped with data.Consume and scheduled as an already-loaded document.
// Scheduling stops as soon as the iterator is stopped or ctx is done.
func (i *iterator) processQuery(ctx context.Context, key *keys.Thing, val []interface{}) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)
	for _, item := range val {
		if !i.checkState(ctx) {
			break
		}
		if thg, isThing := item.(*sql.Thing); isThing {
			rec := key.Copy()
			rec.TB, rec.ID = thg.TB, thg.ID
			i.deliver(rec, nil, nil)
		} else {
			i.deliver(nil, nil, data.Consume(item))
		}
	}
}
// processArray schedules the items of a data array, each as a copy of key
// with a new ID:
//
//   - *sql.Thing: the Thing's ID (key.TB is kept) and no document;
//   - map with an "id" that is a *sql.Thing: that Thing's ID;
//   - map with any other "id": the id value itself;
//   - map without "id": a freshly generated guid.
//
// Maps are scheduled together with their data (via data.Consume).
// Processing stops as soon as the iterator is stopped or ctx is done.
func (i *iterator) processArray(ctx context.Context, key *keys.Thing, val []interface{}) {
	i.processPerms(ctx, key.NS, key.DB, key.TB)
	for _, v := range val {
		if !i.checkState(ctx) {
			return
		}
		switch v := v.(type) {
		case *sql.Thing:
			rec := key.Copy()
			rec.ID = v.ID
			i.deliver(rec, nil, nil)
		case map[string]interface{}:
			rec := key.Copy()
			if fld, ok := v["id"]; !ok {
				rec.ID = guid.New().String()
			} else if thg, isThing := fld.(*sql.Thing); isThing {
				rec.ID = thg.ID
			} else {
				rec.ID = fld
			}
			i.deliver(rec, nil, data.Consume(v))
		default:
			// NOTE(review): any other item type silently ends processing of
			// the remaining items without an error; confirm this is intended.
			return
		}
	}
}
// Yield waits for every delivered job to be received and shuts down the
// worker and receiver goroutines. It then returns the collected results
// with SPLIT, GROUP BY, ORDER BY, START and LIMIT applied in that order.
// An error recorded on i.err takes precedence over a worker error queued
// on i.fail; either one is returned with nil results. The iterator is
// returned to the pool via Close before Yield returns.
func (i *iterator) Yield(ctx context.Context) (out []interface{}, err error) {
	defer i.Close()

	i.wait.Wait()
	close(i.jobs)
	close(i.vals)

	if i.err == nil {
		select {
		case i.err = <-i.fail:
		default:
		}
	}
	if i.err != nil {
		return nil, i.err
	}

	res := i.res
	if len(i.split) > 0 {
		res = i.Split(ctx, res)
	}
	if len(i.group) > 0 {
		res = i.Group(ctx, res)
	}
	if len(i.order) > 0 {
		res = i.Order(ctx, res)
	}
	if i.start >= 0 {
		res = res[ints.Min(i.start, len(res)):]
	}
	if i.limit >= 0 {
		res = res[:ints.Min(i.limit, len(res))]
	}
	return res, nil
}
// Split expands records on each field in i.split, one field after another.
// For every record, data.Walk visits the field, or each of its elements
// (via docKeyAll) when it holds an array. For each visited value, a copy of
// the record with the field set to that value is emitted. When i.split is
// empty, nil is returned.
func (i *iterator) Split(ctx context.Context, arr []interface{}) (out []interface{}) {
	for _, fld := range i.split {
		expanded := make([]interface{}, 0)
		for _, item := range arr {
			src := data.Consume(item)
			path := []string{fld.VA}
			if _, isArr := src.Get(fld.VA).Data().([]interface{}); isArr {
				path = append(path, docKeyAll)
			}
			src.Walk(func(_ string, v interface{}, _ bool) error {
				cp := src.Copy()
				cp.Set(v, fld.VA)
				expanded = append(expanded, cp.Data())
				return nil
			}, path...)
		}
		arr, out = expanded, expanded
	}
	return out
}
// Group implements GROUP BY. Each record is keyed by the "%v" rendering of
// its i.group values, and one output record is produced per distinct key.
// Groups are emitted in the order their first record appears in arr, so the
// output is deterministic.
//
// For each field in i.expr:
//   - an aggregate function (FuncExpression with Aggr set) is run with the
//     group's values stored under the field named f.String() as its first
//     argument, and its remaining arguments evaluated without a document;
//   - any other field takes its value from the group's first record.
//
// Errors from evaluating group expressions, aggregate arguments and
// aggregate functions are discarded (best-effort).
func (i *iterator) Group(ctx context.Context, arr []interface{}) (out []interface{}) {
	var grp []*groupable

	// Evaluate the GROUP BY attributes for every record.
	for _, a := range arr {
		g := &groupable{
			doc: data.Consume(a),
			ats: make([]interface{}, len(i.group)),
		}
		for k, e := range i.group {
			g.ats[k], _ = i.e.fetch(ctx, e.Expr, g.doc)
		}
		grp = append(grp, g)
	}

	// Bucket records by the string form of their group attributes. Keep the
	// first-seen order of keys, because ranging over the map directly would
	// emit groups in random order.
	// NOTE(review): values of different types that print identically (e.g.
	// the string "1" and the number 1) share a group.
	var seen []string
	col := make(map[string][]interface{})
	for _, s := range grp {
		k := fmt.Sprintf("%v", s.ats)
		if _, ok := col[k]; !ok {
			seen = append(seen, k)
		}
		col[k] = append(col[k], s.doc.Data())
	}

	for _, k := range seen {
		doc, all := data.New(), data.Consume(col[k])
		for _, e := range i.expr {
			if f, ok := e.Expr.(*sql.FuncExpression); ok && f.Aggr {
				args := make([]interface{}, len(f.Args))
				for x := range f.Args {
					if x == 0 {
						args[x] = all.Get("*", f.String()).Data()
					} else {
						args[x], _ = i.e.fetch(ctx, f.Args[x], nil)
					}
				}
				val, _ := fncs.Run(ctx, f.Name, args...)
				doc.Set(val, e.Field)
				continue
			}
			// A plain field, also listed in GROUP BY: use the first value.
			doc.Set(all.Get("0", e.Field).Data(), e.Field)
		}
		out = append(out, doc.Data())
	}
	return
}
// Order implements ORDER BY. Records are compared clause by clause with
// comp.Comp, ascending when the clause's Dir is true and descending
// otherwise. The sort is stable, so records that compare equal on every
// clause keep their input order. Each ORDER BY value is evaluated lazily,
// at most once per record, and cached on its *orderable. Evaluation errors
// are discarded.
func (i *iterator) Order(ctx context.Context, arr []interface{}) (out []interface{}) {
	ord := make([]*orderable, 0, len(arr))
	for _, a := range arr {
		ord = append(ord, &orderable{
			doc: data.Consume(a),
			ats: make([]interface{}, 0, len(i.order)),
		})
	}

	// attr returns the x-th ORDER BY value for o, evaluating it on first
	// use. Clauses are visited in order, so ats already holds 0..x-1.
	attr := func(o *orderable, x int) interface{} {
		if len(o.ats) <= x {
			a, _ := i.e.fetch(ctx, i.order[x].Expr, o.doc)
			o.ats = append(o.ats, a)
		}
		return o.ats[x]
	}

	sort.SliceStable(ord, func(k, j int) bool {
		for x, e := range i.order {
			if c := comp.Comp(attr(ord[k], x), attr(ord[j], x), e); c != 0 {
				return (c < 0 && e.Dir) || (c > 0 && !e.Dir)
			}
		}
		return false
	})

	for _, s := range ord {
		out = append(out, s.doc.Data())
	}
	return
}