From fb256df42bcebbf4ce84587418c7154878653934 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Fri, 20 Apr 2018 23:40:52 +0100 Subject: [PATCH] Prevent infinite loops with nested subqueries --- db/context.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ db/create.go | 2 ++ db/delete.go | 2 ++ db/insert.go | 2 ++ db/relate.go | 2 ++ db/select.go | 2 ++ db/update.go | 2 ++ db/upsert.go | 2 ++ db/vars.go | 5 +++++ 9 files changed, 67 insertions(+) create mode 100644 db/context.go diff --git a/db/context.go b/db/context.go new file mode 100644 index 00000000..9f2da37c --- /dev/null +++ b/db/context.go @@ -0,0 +1,48 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" +) + +func vers(ctx context.Context) int { + + v := ctx.Value(ctxKeyDive) + + switch v { + case nil: + return 0 + default: + return v.(int) + } + +} + +func dive(ctx context.Context) context.Context { + + v := ctx.Value(ctxKeyDive) + + switch v { + case nil: + return context.WithValue(ctx, ctxKeyDive, 1) + default: + if v.(int) > maxRecursiveQueries { + panic(errRecursiveOverload) + } + return context.WithValue(ctx, ctxKeyDive, v.(int)+1) + } + +} diff --git a/db/create.go b/db/create.go index 7a171022..a27c34f5 100644 --- a/db/create.go +++ b/db/create.go @@ -76,6 +76,8 @@ func (e *executor) executeCreate(ctx context.Context, stm *sql.CreateStatement) func (e *executor) fetchCreate(ctx context.Context, stm *sql.CreateStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/delete.go b/db/delete.go index 38e98e7b..5a9be43c 100644 --- a/db/delete.go +++ b/db/delete.go @@ -75,6 +75,8 @@ func (e *executor) executeDelete(ctx context.Context, stm *sql.DeleteStatement) func (e *executor) fetchDelete(ctx context.Context, stm *sql.DeleteStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/insert.go b/db/insert.go index d566ed6d..e5064967 100644 --- a/db/insert.go +++ b/db/insert.go @@ -54,6 +54,8 @@ func (e *executor) executeInsert(ctx context.Context, stm *sql.InsertStatement) func (e *executor) fetchInsert(ctx context.Context, stm *sql.InsertStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/relate.go b/db/relate.go index 69e2574f..ef8442aa 100644 --- a/db/relate.go +++ b/db/relate.go @@ -68,6 +68,8 @@ func (e *executor) executeRelate(ctx context.Context, stm *sql.RelateStatement) func (e *executor) fetchRelate(ctx context.Context, stm *sql.RelateStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/select.go b/db/select.go index 7f23b5a6..231e2d2f 100644 --- a/db/select.go +++ b/db/select.go @@ -84,6 +84,8 @@ func (e *executor) executeSelect(ctx context.Context, stm *sql.SelectStatement) func (e *executor) fetchSelect(ctx context.Context, stm *sql.SelectStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/update.go b/db/update.go index e0dd7ab9..0e93d3b2 100644 --- a/db/update.go +++ b/db/update.go @@ -75,6 +75,8 @@ func (e *executor) executeUpdate(ctx context.Context, stm *sql.UpdateStatement) func (e *executor) fetchUpdate(ctx context.Context, stm *sql.UpdateStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/upsert.go b/db/upsert.go index 1564646c..2b9fc4dd 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -54,6 +54,8 @@ func (e *executor) executeUpsert(ctx context.Context, stm *sql.UpsertStatement) func (e *executor) fetchUpsert(ctx context.Context, stm *sql.UpsertStatement, doc *data.Doc) (interface{}, error) { + ctx = dive(ctx) + if doc != nil { vars := data.New() vars.Set(doc.Data(), varKeyParent) diff --git a/db/vars.go b/db/vars.go index efd33505..029e724c 100644 --- a/db/vars.go +++ b/db/vars.go @@ -38,6 +38,7 @@ const ( ctxKeyId = "id" ctxKeyNs = "ns" ctxKeyDb = "db" + ctxKeyDive = "dive" ctxKeyVars = "vars" ctxKeySubs = "subs" ctxKeySpec = "spec" @@ -71,6 +72,10 @@ var ( // to process each query statement concurrently. workerCount = runtime.NumCPU() * 2 + // maxRecursiveQueries specifies how many queries will be + // processed recursively before the query is cancelled. + maxRecursiveQueries = 50 + // queryIdentFailed occurs when a permission query asks // for a field, meaning a document has to be fetched. queryIdentFailed = errors.New("Found ident but no doc available")