diff --git a/sql/create.go b/sql/create.go index 666b14ad..27578776 100644 --- a/sql/create.go +++ b/sql/create.go @@ -18,11 +18,9 @@ func (p *parser) parseCreateStatement() (stmt *CreateStatement, err error) { stmt = &CreateStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) - - _, _, _ = p.mightBe(INTO) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } if stmt.What, err = p.parseWhat(); err != nil { return nil, err diff --git a/sql/delete.go b/sql/delete.go index 67045f08..66add728 100644 --- a/sql/delete.go +++ b/sql/delete.go @@ -18,11 +18,9 @@ func (p *parser) parseDeleteStatement() (stmt *DeleteStatement, err error) { stmt = &DeleteStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) - - _, _, _ = p.mightBe(AND) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } _, _, stmt.Hard = p.mightBe(EXPUNGE) diff --git a/sql/error.go b/sql/error.go index 2e8ab7e0..0206bc98 100644 --- a/sql/error.go +++ b/sql/error.go @@ -27,6 +27,30 @@ func (e *EmptyError) Error() string { return fmt.Sprint("Your SQL query is empty") } +// EntryError represents an error that occured when switching access. +type EntryError struct{} + +// Error returns the string representation of the error. +func (e *EntryError) Error() string { + return fmt.Sprint("You don't have permission to access this resource") +} + +// QueryError represents an error that occured when switching access. +type QueryError struct{} + +// Error returns the string representation of the error. +func (e *QueryError) Error() string { + return fmt.Sprint("You don't have permission to perform this query type") +} + +// BlankError represents an error that occured when switching access. +type BlankError struct{} + +// Error returns the string representation of the error. +func (e *BlankError) Error() string { + return fmt.Sprint("You need to specify a namespace and a database to use") +} + // ParseError represents an error that occurred during parsing. type ParseError struct { Found string diff --git a/sql/field.go b/sql/field.go index 912e23a7..730b2fc0 100644 --- a/sql/field.go +++ b/sql/field.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineFieldStatement() (stmt *DefineFieldStatement, err er stmt = &DefineFieldStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err @@ -137,9 +137,9 @@ func (p *parser) parseRemoveFieldStatement() (stmt *RemoveFieldStatement, err er stmt = &RemoveFieldStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err diff --git a/sql/index.go b/sql/index.go index 8757d840..9714bbc4 100644 --- a/sql/index.go +++ b/sql/index.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineIndexStatement() (stmt *DefineIndexStatement, err er stmt = &DefineIndexStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err @@ -56,9 +56,9 @@ func (p *parser) parseRemoveIndexStatement() (stmt *RemoveIndexStatement, err er stmt = &RemoveIndexStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err diff --git a/sql/let.go b/sql/let.go index 5d720ff0..14c07003 100644 --- a/sql/let.go +++ b/sql/let.go @@ -18,9 +18,10 @@ func (p *parser) parseLetStatement() (stmt *LetStatement, err error) { stmt = &LetStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } + // The first part of a LET expression must // always be an identifier, specifying a // variable name to set. diff --git a/sql/options.go b/sql/options.go new file mode 100644 index 00000000..714d5af4 --- /dev/null +++ b/sql/options.go @@ -0,0 +1,124 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "github.com/abcum/fibre" + "github.com/abcum/surreal/cnf" +) + +const ( + // Root access + AuthKV int = iota + // Namespace access + AuthNS + // Database access + AuthDB + // Scoped user access + AuthTB + // No access + AuthNO +) + +// options represents context runtime config. +type options struct { + kind int + auth map[string]string + conf map[string]string +} + +func newOptions(c *fibre.Context) *options { + return &options{ + kind: c.Get("kind").(int), + auth: c.Get("auth").(map[string]string), + conf: c.Get("conf").(map[string]string), + } +} + +func (o *options) get(kind int) (kv, ns, db string, err error) { + + kv = cnf.Settings.DB.Base + ns = o.conf["NS"] + db = o.conf["DB"] + + if o.kind > kind { + err = &QueryError{} + return + } + + if ns == "" || db == "" { + err = &BlankError{} + return + } + + return + +} + +func (o *options) ns(ns string) (err error) { + + // Check to see that the current user has + // the necessary access level to perform + // namespace switching / choosing. + + if o.kind > AuthKV { + return &EntryError{} + } + + // Check to see that the current user has + // the necessary authentcation privileges + // to be able to specify this namespace. + + if o.auth["NS"] != "*" && o.auth["NS"] != ns { + return &EntryError{} + } + + // Specify the NS on the context session, so + // that it is remembered across requests on + // any persistent connections. + + o.conf["NS"] = ns + + return + +} + +func (o *options) db(db string) (err error) { + + // Check to see that the current user has + // the necessary access level to perform + // database switching / choosing. + + if o.kind > AuthNS { + return &EntryError{} + } + + // Check to see that the current user has + // the necessary authentcation privileges + // to be able to specify this namespace. + + if o.auth["DB"] != "*" && o.auth["DB"] != db { + return &EntryError{} + } + + // Specify the DB on the context session, so + // that it is remembered across requests on + // any persistent connections. + + o.conf["DB"] = db + + return + +} diff --git a/sql/parser.go b/sql/parser.go index a9d8aa14..7a829ceb 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -25,6 +25,7 @@ import ( // parser represents a parser. type parser struct { s *scanner + o *options c *fibre.Context v map[string]interface{} buf struct { @@ -37,7 +38,7 @@ type parser struct { // newParser returns a new instance of Parser. func newParser(c *fibre.Context, v map[string]interface{}) *parser { - return &parser{c: c, v: v} + return &parser{c: c, v: v, o: newOptions(c)} } // Parse parses sql from a []byte, string, or io.Reader. diff --git a/sql/relate.go b/sql/relate.go index 10f14e07..d1788eb6 100644 --- a/sql/relate.go +++ b/sql/relate.go @@ -18,9 +18,9 @@ func (p *parser) parseRelateStatement() (stmt *RelateStatement, err error) { stmt = &RelateStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } if stmt.Type, err = p.parseTable(); err != nil { return nil, err diff --git a/sql/return.go b/sql/return.go index f5477565..e3d8715b 100644 --- a/sql/return.go +++ b/sql/return.go @@ -18,7 +18,7 @@ func (p *parser) parseReturnStatement() (stmt *ReturnStatement, err error) { stmt = &ReturnStatement{} - if _, _, _, err = p.o.get(AuthTB); err != nil { + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { return nil, err } diff --git a/sql/rules.go b/sql/rules.go index 6e689b60..02f02d3e 100644 --- a/sql/rules.go +++ b/sql/rules.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineRulesStatement() (stmt *DefineRulesStatement, err er stmt = &DefineRulesStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if _, _, err = p.shouldBe(ON); err != nil { return nil, err @@ -73,9 +73,9 @@ func (p *parser) parseRemoveRulesStatement() (stmt *RemoveRulesStatement, err er stmt = &RemoveRulesStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if _, _, err = p.shouldBe(ON); err != nil { return nil, err diff --git a/sql/scope.go b/sql/scope.go index 2a5e63d1..bda113f1 100644 --- a/sql/scope.go +++ b/sql/scope.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineScopeStatement() (stmt *DefineScopeStatement, err er stmt = &DefineScopeStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err @@ -46,9 +46,9 @@ func (p *parser) parseRemoveScopeStatement() (stmt *RemoveScopeStatement, err er stmt = &RemoveScopeStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err diff --git a/sql/select.go b/sql/select.go index e260b4b6..81616049 100644 --- a/sql/select.go +++ b/sql/select.go @@ -18,9 +18,9 @@ func (p *parser) parseSelectStatement() (stmt *SelectStatement, err error) { stmt = &SelectStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } if stmt.Expr, err = p.parseField(); err != nil { return nil, err diff --git a/sql/table.go b/sql/table.go index b1b3b616..bedcc4eb 100644 --- a/sql/table.go +++ b/sql/table.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineTableStatement() (stmt *DefineTableStatement, err er stmt = &DefineTableStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.What, err = p.parseNames(); err != nil { return nil, err @@ -42,9 +42,9 @@ func (p *parser) parseRemoveTableStatement() (stmt *RemoveTableStatement, err er stmt = &RemoveTableStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.What, err = p.parseNames(); err != nil { return nil, err diff --git a/sql/trans.go b/sql/trans.go index 121f2b48..6621dc2a 100644 --- a/sql/trans.go +++ b/sql/trans.go @@ -18,6 +18,10 @@ func (p *parser) parseBeginStatement() (stmt *BeginStatement, err error) { stmt = &BeginStatement{} + if _, _, _, err = p.o.get(AuthTB); err != nil { + return nil, err + } + _, _, _ = p.mightBe(TRANSACTION) if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil { @@ -32,6 +36,10 @@ func (p *parser) parseCancelStatement() (stmt *CancelStatement, err error) { stmt = &CancelStatement{} + if _, _, _, err = p.o.get(AuthTB); err != nil { + return nil, err + } + _, _, _ = p.mightBe(TRANSACTION) if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil { @@ -46,6 +54,10 @@ func (p *parser) parseCommitStatement() (stmt *CommitStatement, err error) { stmt = &CommitStatement{} + if _, _, _, err = p.o.get(AuthTB); err != nil { + return nil, err + } + _, _, _ = p.mightBe(TRANSACTION) if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil { diff --git a/sql/update.go b/sql/update.go index a8cc320c..5d4fd0a4 100644 --- a/sql/update.go +++ b/sql/update.go @@ -18,9 +18,9 @@ func (p *parser) parseUpdateStatement() (stmt *UpdateStatement, err error) { stmt = &UpdateStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil { + return nil, err + } _, _, _ = p.mightBe(INTO) diff --git a/sql/use.go b/sql/use.go index fa6bf237..384ae2ad 100644 --- a/sql/use.go +++ b/sql/use.go @@ -29,21 +29,29 @@ func (p *parser) parseUseStatement() (stmt *UseStatement, err error) { for { if p.is(tok, NAMESPACE) { + _, stmt.NS, err = p.shouldBe(IDENT, STRING) if err != nil { return nil, &ParseError{Found: stmt.NS, Expected: []string{"namespace name"}} } - // TODO: need to make sure this user can access this NS - p.c.Set("NS", stmt.NS) + + if err = p.o.ns(stmt.NS); err != nil { + return nil, err + } + } if p.is(tok, DATABASE) { + _, stmt.DB, err = p.shouldBe(IDENT, DATE, TIME, STRING, NUMBER, DOUBLE) if err != nil { return nil, &ParseError{Found: stmt.DB, Expected: []string{"database name"}} } - // TODO: need to make sure this user can access this DB - p.c.Set("DB", stmt.DB) + + if err = p.o.db(stmt.DB); err != nil { + return nil, err + } + } tok, _, exi = p.mightBe(NAMESPACE, DATABASE) diff --git a/sql/view.go b/sql/view.go index 77cacc28..0b83a930 100644 --- a/sql/view.go +++ b/sql/view.go @@ -18,9 +18,9 @@ func (p *parser) parseDefineViewStatement() (stmt *DefineViewStatement, err erro stmt = &DefineViewStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err @@ -69,9 +69,9 @@ func (p *parser) parseRemoveViewStatement() (stmt *RemoveViewStatement, err erro stmt = &RemoveViewStatement{} - stmt.KV = p.c.Get("KV").(string) - stmt.NS = p.c.Get("NS").(string) - stmt.DB = p.c.Get("DB").(string) + if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil { + return nil, err + } if stmt.Name, err = p.parseName(); err != nil { return nil, err