Prevent infinite loops with nested subqueries

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-20 23:40:52 +01:00
parent d2a451345a
commit fb256df42b
9 changed files with 67 additions and 0 deletions

48
db/context.go Normal file
View file

@ -0,0 +1,48 @@
// Copyright © 2016 Abcum Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package db
import (
"context"
)
func vers(ctx context.Context) int {
v := ctx.Value(ctxKeyDive)
switch v {
case nil:
return 0
default:
return v.(int)
}
}
func dive(ctx context.Context) context.Context {
v := ctx.Value(ctxKeyDive)
switch v {
case nil:
return context.WithValue(ctx, ctxKeyDive, 1)
default:
if v.(int) > maxRecursiveQueries {
panic(errRecursiveOverload)
}
return context.WithValue(ctx, ctxKeyDive, v.(int)+1)
}
}

View file

@ -76,6 +76,8 @@ func (e *executor) executeCreate(ctx context.Context, stm *sql.CreateStatement)
func (e *executor) fetchCreate(ctx context.Context, stm *sql.CreateStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchCreate(ctx context.Context, stm *sql.CreateStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -75,6 +75,8 @@ func (e *executor) executeDelete(ctx context.Context, stm *sql.DeleteStatement)
func (e *executor) fetchDelete(ctx context.Context, stm *sql.DeleteStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchDelete(ctx context.Context, stm *sql.DeleteStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -54,6 +54,8 @@ func (e *executor) executeInsert(ctx context.Context, stm *sql.InsertStatement)
func (e *executor) fetchInsert(ctx context.Context, stm *sql.InsertStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchInsert(ctx context.Context, stm *sql.InsertStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -68,6 +68,8 @@ func (e *executor) executeRelate(ctx context.Context, stm *sql.RelateStatement)
func (e *executor) fetchRelate(ctx context.Context, stm *sql.RelateStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchRelate(ctx context.Context, stm *sql.RelateStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -84,6 +84,8 @@ func (e *executor) executeSelect(ctx context.Context, stm *sql.SelectStatement)
func (e *executor) fetchSelect(ctx context.Context, stm *sql.SelectStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchSelect(ctx context.Context, stm *sql.SelectStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -75,6 +75,8 @@ func (e *executor) executeUpdate(ctx context.Context, stm *sql.UpdateStatement)
func (e *executor) fetchUpdate(ctx context.Context, stm *sql.UpdateStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchUpdate(ctx context.Context, stm *sql.UpdateStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -54,6 +54,8 @@ func (e *executor) executeUpsert(ctx context.Context, stm *sql.UpsertStatement)
func (e *executor) fetchUpsert(ctx context.Context, stm *sql.UpsertStatement, doc *data.Doc) (interface{}, error) { func (e *executor) fetchUpsert(ctx context.Context, stm *sql.UpsertStatement, doc *data.Doc) (interface{}, error) {
ctx = dive(ctx)
if doc != nil { if doc != nil {
vars := data.New() vars := data.New()
vars.Set(doc.Data(), varKeyParent) vars.Set(doc.Data(), varKeyParent)

View file

@ -38,6 +38,7 @@ const (
ctxKeyId = "id" ctxKeyId = "id"
ctxKeyNs = "ns" ctxKeyNs = "ns"
ctxKeyDb = "db" ctxKeyDb = "db"
ctxKeyDive = "dive"
ctxKeyVars = "vars" ctxKeyVars = "vars"
ctxKeySubs = "subs" ctxKeySubs = "subs"
ctxKeySpec = "spec" ctxKeySpec = "spec"
@ -71,6 +72,10 @@ var (
// to process each query statement concurrently. // to process each query statement concurrently.
workerCount = runtime.NumCPU() * 2 workerCount = runtime.NumCPU() * 2
// maxRecursiveQueries specifies how many queries will be
// processed recursively before the query is cancelled.
maxRecursiveQueries = 50
// queryIdentFailed occurs when a permission query asks // queryIdentFailed occurs when a permission query asks
// for a field, meaning a document has to be fetched. // for a field, meaning a document has to be fetched.
queryIdentFailed = errors.New("Found ident but no doc available") queryIdentFailed = errors.New("Found ident but no doc available")