Perform authentication access checks in SQL layer
This commit is contained in:
parent
df7ee71cf6
commit
b720213bd4
18 changed files with 230 additions and 64 deletions
|
@ -18,11 +18,9 @@ func (p *parser) parseCreateStatement() (stmt *CreateStatement, err error) {
|
|||
|
||||
stmt = &CreateStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
|
||||
_, _, _ = p.mightBe(INTO)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.What, err = p.parseWhat(); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -18,11 +18,9 @@ func (p *parser) parseDeleteStatement() (stmt *DeleteStatement, err error) {
|
|||
|
||||
stmt = &DeleteStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
|
||||
_, _, _ = p.mightBe(AND)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, stmt.Hard = p.mightBe(EXPUNGE)
|
||||
|
||||
|
|
24
sql/error.go
24
sql/error.go
|
@ -27,6 +27,30 @@ func (e *EmptyError) Error() string {
|
|||
return fmt.Sprint("Your SQL query is empty")
|
||||
}
|
||||
|
||||
// EntryError represents an error that occured when switching access.
|
||||
type EntryError struct{}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
func (e *EntryError) Error() string {
|
||||
return fmt.Sprint("You don't have permission to access this resource")
|
||||
}
|
||||
|
||||
// QueryError represents an error that occured when switching access.
|
||||
type QueryError struct{}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
func (e *QueryError) Error() string {
|
||||
return fmt.Sprint("You don't have permission to perform this query type")
|
||||
}
|
||||
|
||||
// BlankError represents an error that occured when switching access.
|
||||
type BlankError struct{}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
func (e *BlankError) Error() string {
|
||||
return fmt.Sprint("You need to specify a namespace and a database to use")
|
||||
}
|
||||
|
||||
// ParseError represents an error that occurred during parsing.
|
||||
type ParseError struct {
|
||||
Found string
|
||||
|
|
12
sql/field.go
12
sql/field.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineFieldStatement() (stmt *DefineFieldStatement, err er
|
|||
|
||||
stmt = &DefineFieldStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
@ -137,9 +137,9 @@ func (p *parser) parseRemoveFieldStatement() (stmt *RemoveFieldStatement, err er
|
|||
|
||||
stmt = &RemoveFieldStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
|
12
sql/index.go
12
sql/index.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineIndexStatement() (stmt *DefineIndexStatement, err er
|
|||
|
||||
stmt = &DefineIndexStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
@ -56,9 +56,9 @@ func (p *parser) parseRemoveIndexStatement() (stmt *RemoveIndexStatement, err er
|
|||
|
||||
stmt = &RemoveIndexStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -18,9 +18,10 @@ func (p *parser) parseLetStatement() (stmt *LetStatement, err error) {
|
|||
|
||||
stmt = &LetStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The first part of a LET expression must
|
||||
// always be an identifier, specifying a
|
||||
// variable name to set.
|
||||
|
|
124
sql/options.go
Normal file
124
sql/options.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
// Copyright © 2016 Abcum Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"github.com/abcum/fibre"
|
||||
"github.com/abcum/surreal/cnf"
|
||||
)
|
||||
|
||||
const (
|
||||
// Root access
|
||||
AuthKV int = iota
|
||||
// Namespace access
|
||||
AuthNS
|
||||
// Database access
|
||||
AuthDB
|
||||
// Scoped user access
|
||||
AuthTB
|
||||
// No access
|
||||
AuthNO
|
||||
)
|
||||
|
||||
// options represents context runtime config.
|
||||
type options struct {
|
||||
kind int
|
||||
auth map[string]string
|
||||
conf map[string]string
|
||||
}
|
||||
|
||||
func newOptions(c *fibre.Context) *options {
|
||||
return &options{
|
||||
kind: c.Get("kind").(int),
|
||||
auth: c.Get("auth").(map[string]string),
|
||||
conf: c.Get("conf").(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *options) get(kind int) (kv, ns, db string, err error) {
|
||||
|
||||
kv = cnf.Settings.DB.Base
|
||||
ns = o.conf["NS"]
|
||||
db = o.conf["DB"]
|
||||
|
||||
if o.kind > kind {
|
||||
err = &QueryError{}
|
||||
return
|
||||
}
|
||||
|
||||
if ns == "" || db == "" {
|
||||
err = &BlankError{}
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (o *options) ns(ns string) (err error) {
|
||||
|
||||
// Check to see that the current user has
|
||||
// the necessary access level to perform
|
||||
// namespace switching / choosing.
|
||||
|
||||
if o.kind > AuthKV {
|
||||
return &EntryError{}
|
||||
}
|
||||
|
||||
// Check to see that the current user has
|
||||
// the necessary authentcation privileges
|
||||
// to be able to specify this namespace.
|
||||
|
||||
if o.auth["NS"] != "*" && o.auth["NS"] != ns {
|
||||
return &EntryError{}
|
||||
}
|
||||
|
||||
// Specify the NS on the context session, so
|
||||
// that it is remembered across requests on
|
||||
// any persistent connections.
|
||||
|
||||
o.conf["NS"] = ns
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (o *options) db(db string) (err error) {
|
||||
|
||||
// Check to see that the current user has
|
||||
// the necessary access level to perform
|
||||
// database switching / choosing.
|
||||
|
||||
if o.kind > AuthNS {
|
||||
return &EntryError{}
|
||||
}
|
||||
|
||||
// Check to see that the current user has
|
||||
// the necessary authentcation privileges
|
||||
// to be able to specify this namespace.
|
||||
|
||||
if o.auth["DB"] != "*" && o.auth["DB"] != db {
|
||||
return &EntryError{}
|
||||
}
|
||||
|
||||
// Specify the DB on the context session, so
|
||||
// that it is remembered across requests on
|
||||
// any persistent connections.
|
||||
|
||||
o.conf["DB"] = db
|
||||
|
||||
return
|
||||
|
||||
}
|
|
@ -25,6 +25,7 @@ import (
|
|||
// parser represents a parser.
|
||||
type parser struct {
|
||||
s *scanner
|
||||
o *options
|
||||
c *fibre.Context
|
||||
v map[string]interface{}
|
||||
buf struct {
|
||||
|
@ -37,7 +38,7 @@ type parser struct {
|
|||
|
||||
// newParser returns a new instance of Parser.
|
||||
func newParser(c *fibre.Context, v map[string]interface{}) *parser {
|
||||
return &parser{c: c, v: v}
|
||||
return &parser{c: c, v: v, o: newOptions(c)}
|
||||
}
|
||||
|
||||
// Parse parses sql from a []byte, string, or io.Reader.
|
||||
|
|
|
@ -18,9 +18,9 @@ func (p *parser) parseRelateStatement() (stmt *RelateStatement, err error) {
|
|||
|
||||
stmt = &RelateStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Type, err = p.parseTable(); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -18,7 +18,7 @@ func (p *parser) parseReturnStatement() (stmt *ReturnStatement, err error) {
|
|||
|
||||
stmt = &ReturnStatement{}
|
||||
|
||||
if _, _, _, err = p.o.get(AuthTB); err != nil {
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
12
sql/rules.go
12
sql/rules.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineRulesStatement() (stmt *DefineRulesStatement, err er
|
|||
|
||||
stmt = &DefineRulesStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, _, err = p.shouldBe(ON); err != nil {
|
||||
return nil, err
|
||||
|
@ -73,9 +73,9 @@ func (p *parser) parseRemoveRulesStatement() (stmt *RemoveRulesStatement, err er
|
|||
|
||||
stmt = &RemoveRulesStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, _, err = p.shouldBe(ON); err != nil {
|
||||
return nil, err
|
||||
|
|
12
sql/scope.go
12
sql/scope.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineScopeStatement() (stmt *DefineScopeStatement, err er
|
|||
|
||||
stmt = &DefineScopeStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
@ -46,9 +46,9 @@ func (p *parser) parseRemoveScopeStatement() (stmt *RemoveScopeStatement, err er
|
|||
|
||||
stmt = &RemoveScopeStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -18,9 +18,9 @@ func (p *parser) parseSelectStatement() (stmt *SelectStatement, err error) {
|
|||
|
||||
stmt = &SelectStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Expr, err = p.parseField(); err != nil {
|
||||
return nil, err
|
||||
|
|
12
sql/table.go
12
sql/table.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineTableStatement() (stmt *DefineTableStatement, err er
|
|||
|
||||
stmt = &DefineTableStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.What, err = p.parseNames(); err != nil {
|
||||
return nil, err
|
||||
|
@ -42,9 +42,9 @@ func (p *parser) parseRemoveTableStatement() (stmt *RemoveTableStatement, err er
|
|||
|
||||
stmt = &RemoveTableStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.What, err = p.parseNames(); err != nil {
|
||||
return nil, err
|
||||
|
|
12
sql/trans.go
12
sql/trans.go
|
@ -18,6 +18,10 @@ func (p *parser) parseBeginStatement() (stmt *BeginStatement, err error) {
|
|||
|
||||
stmt = &BeginStatement{}
|
||||
|
||||
if _, _, _, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _ = p.mightBe(TRANSACTION)
|
||||
|
||||
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {
|
||||
|
@ -32,6 +36,10 @@ func (p *parser) parseCancelStatement() (stmt *CancelStatement, err error) {
|
|||
|
||||
stmt = &CancelStatement{}
|
||||
|
||||
if _, _, _, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _ = p.mightBe(TRANSACTION)
|
||||
|
||||
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {
|
||||
|
@ -46,6 +54,10 @@ func (p *parser) parseCommitStatement() (stmt *CommitStatement, err error) {
|
|||
|
||||
stmt = &CommitStatement{}
|
||||
|
||||
if _, _, _, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _ = p.mightBe(TRANSACTION)
|
||||
|
||||
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {
|
||||
|
|
|
@ -18,9 +18,9 @@ func (p *parser) parseUpdateStatement() (stmt *UpdateStatement, err error) {
|
|||
|
||||
stmt = &UpdateStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _ = p.mightBe(INTO)
|
||||
|
||||
|
|
16
sql/use.go
16
sql/use.go
|
@ -29,21 +29,29 @@ func (p *parser) parseUseStatement() (stmt *UseStatement, err error) {
|
|||
for {
|
||||
|
||||
if p.is(tok, NAMESPACE) {
|
||||
|
||||
_, stmt.NS, err = p.shouldBe(IDENT, STRING)
|
||||
if err != nil {
|
||||
return nil, &ParseError{Found: stmt.NS, Expected: []string{"namespace name"}}
|
||||
}
|
||||
// TODO: need to make sure this user can access this NS
|
||||
p.c.Set("NS", stmt.NS)
|
||||
|
||||
if err = p.o.ns(stmt.NS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if p.is(tok, DATABASE) {
|
||||
|
||||
_, stmt.DB, err = p.shouldBe(IDENT, DATE, TIME, STRING, NUMBER, DOUBLE)
|
||||
if err != nil {
|
||||
return nil, &ParseError{Found: stmt.DB, Expected: []string{"database name"}}
|
||||
}
|
||||
// TODO: need to make sure this user can access this DB
|
||||
p.c.Set("DB", stmt.DB)
|
||||
|
||||
if err = p.o.db(stmt.DB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
tok, _, exi = p.mightBe(NAMESPACE, DATABASE)
|
||||
|
|
12
sql/view.go
12
sql/view.go
|
@ -18,9 +18,9 @@ func (p *parser) parseDefineViewStatement() (stmt *DefineViewStatement, err erro
|
|||
|
||||
stmt = &DefineViewStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
@ -69,9 +69,9 @@ func (p *parser) parseRemoveViewStatement() (stmt *RemoveViewStatement, err erro
|
|||
|
||||
stmt = &RemoveViewStatement{}
|
||||
|
||||
stmt.KV = p.c.Get("KV").(string)
|
||||
stmt.NS = p.c.Get("NS").(string)
|
||||
stmt.DB = p.c.Get("DB").(string)
|
||||
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stmt.Name, err = p.parseName(); err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Reference in a new issue