Perform authentication access checks in SQL layer

This commit is contained in:
Tobie Morgan Hitchcock 2016-11-04 11:25:53 +00:00
parent df7ee71cf6
commit b720213bd4
18 changed files with 230 additions and 64 deletions

View file

@ -18,11 +18,9 @@ func (p *parser) parseCreateStatement() (stmt *CreateStatement, err error) {
stmt = &CreateStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
_, _, _ = p.mightBe(INTO)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
if stmt.What, err = p.parseWhat(); err != nil {
return nil, err

View file

@ -18,11 +18,9 @@ func (p *parser) parseDeleteStatement() (stmt *DeleteStatement, err error) {
stmt = &DeleteStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
_, _, _ = p.mightBe(AND)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
_, _, stmt.Hard = p.mightBe(EXPUNGE)

View file

@ -27,6 +27,30 @@ func (e *EmptyError) Error() string {
return fmt.Sprint("Your SQL query is empty")
}
// EntryError represents an error that occured when switching access.
type EntryError struct{}
// Error returns the string representation of the error.
func (e *EntryError) Error() string {
return fmt.Sprint("You don't have permission to access this resource")
}
// QueryError represents an error that occured when switching access.
type QueryError struct{}
// Error returns the string representation of the error.
func (e *QueryError) Error() string {
return fmt.Sprint("You don't have permission to perform this query type")
}
// BlankError represents an error that occured when switching access.
type BlankError struct{}
// Error returns the string representation of the error.
func (e *BlankError) Error() string {
return fmt.Sprint("You need to specify a namespace and a database to use")
}
// ParseError represents an error that occurred during parsing.
type ParseError struct {
Found string

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineFieldStatement() (stmt *DefineFieldStatement, err er
stmt = &DefineFieldStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err
@ -137,9 +137,9 @@ func (p *parser) parseRemoveFieldStatement() (stmt *RemoveFieldStatement, err er
stmt = &RemoveFieldStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineIndexStatement() (stmt *DefineIndexStatement, err er
stmt = &DefineIndexStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err
@ -56,9 +56,9 @@ func (p *parser) parseRemoveIndexStatement() (stmt *RemoveIndexStatement, err er
stmt = &RemoveIndexStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err

View file

@ -18,9 +18,10 @@ func (p *parser) parseLetStatement() (stmt *LetStatement, err error) {
stmt = &LetStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
// The first part of a LET expression must
// always be an identifier, specifying a
// variable name to set.

124
sql/options.go Normal file
View file

@ -0,0 +1,124 @@
// Copyright © 2016 Abcum Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"github.com/abcum/fibre"
"github.com/abcum/surreal/cnf"
)
const (
// Root access
AuthKV int = iota
// Namespace access
AuthNS
// Database access
AuthDB
// Scoped user access
AuthTB
// No access
AuthNO
)
// options represents context runtime config.
type options struct {
kind int
auth map[string]string
conf map[string]string
}
func newOptions(c *fibre.Context) *options {
return &options{
kind: c.Get("kind").(int),
auth: c.Get("auth").(map[string]string),
conf: c.Get("conf").(map[string]string),
}
}
func (o *options) get(kind int) (kv, ns, db string, err error) {
kv = cnf.Settings.DB.Base
ns = o.conf["NS"]
db = o.conf["DB"]
if o.kind > kind {
err = &QueryError{}
return
}
if ns == "" || db == "" {
err = &BlankError{}
return
}
return
}
func (o *options) ns(ns string) (err error) {
// Check to see that the current user has
// the necessary access level to perform
// namespace switching / choosing.
if o.kind > AuthKV {
return &EntryError{}
}
// Check to see that the current user has
// the necessary authentcation privileges
// to be able to specify this namespace.
if o.auth["NS"] != "*" && o.auth["NS"] != ns {
return &EntryError{}
}
// Specify the NS on the context session, so
// that it is remembered across requests on
// any persistent connections.
o.conf["NS"] = ns
return
}
func (o *options) db(db string) (err error) {
// Check to see that the current user has
// the necessary access level to perform
// database switching / choosing.
if o.kind > AuthNS {
return &EntryError{}
}
// Check to see that the current user has
// the necessary authentcation privileges
// to be able to specify this namespace.
if o.auth["DB"] != "*" && o.auth["DB"] != db {
return &EntryError{}
}
// Specify the DB on the context session, so
// that it is remembered across requests on
// any persistent connections.
o.conf["DB"] = db
return
}

View file

@ -25,6 +25,7 @@ import (
// parser represents a parser.
type parser struct {
s *scanner
o *options
c *fibre.Context
v map[string]interface{}
buf struct {
@ -37,7 +38,7 @@ type parser struct {
// newParser returns a new instance of Parser.
func newParser(c *fibre.Context, v map[string]interface{}) *parser {
return &parser{c: c, v: v}
return &parser{c: c, v: v, o: newOptions(c)}
}
// Parse parses sql from a []byte, string, or io.Reader.

View file

@ -18,9 +18,9 @@ func (p *parser) parseRelateStatement() (stmt *RelateStatement, err error) {
stmt = &RelateStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
if stmt.Type, err = p.parseTable(); err != nil {
return nil, err

View file

@ -18,7 +18,7 @@ func (p *parser) parseReturnStatement() (stmt *ReturnStatement, err error) {
stmt = &ReturnStatement{}
if _, _, _, err = p.o.get(AuthTB); err != nil {
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineRulesStatement() (stmt *DefineRulesStatement, err er
stmt = &DefineRulesStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if _, _, err = p.shouldBe(ON); err != nil {
return nil, err
@ -73,9 +73,9 @@ func (p *parser) parseRemoveRulesStatement() (stmt *RemoveRulesStatement, err er
stmt = &RemoveRulesStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if _, _, err = p.shouldBe(ON); err != nil {
return nil, err

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineScopeStatement() (stmt *DefineScopeStatement, err er
stmt = &DefineScopeStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err
@ -46,9 +46,9 @@ func (p *parser) parseRemoveScopeStatement() (stmt *RemoveScopeStatement, err er
stmt = &RemoveScopeStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err

View file

@ -18,9 +18,9 @@ func (p *parser) parseSelectStatement() (stmt *SelectStatement, err error) {
stmt = &SelectStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
if stmt.Expr, err = p.parseField(); err != nil {
return nil, err

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineTableStatement() (stmt *DefineTableStatement, err er
stmt = &DefineTableStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.What, err = p.parseNames(); err != nil {
return nil, err
@ -42,9 +42,9 @@ func (p *parser) parseRemoveTableStatement() (stmt *RemoveTableStatement, err er
stmt = &RemoveTableStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.What, err = p.parseNames(); err != nil {
return nil, err

View file

@ -18,6 +18,10 @@ func (p *parser) parseBeginStatement() (stmt *BeginStatement, err error) {
stmt = &BeginStatement{}
if _, _, _, err = p.o.get(AuthTB); err != nil {
return nil, err
}
_, _, _ = p.mightBe(TRANSACTION)
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {
@ -32,6 +36,10 @@ func (p *parser) parseCancelStatement() (stmt *CancelStatement, err error) {
stmt = &CancelStatement{}
if _, _, _, err = p.o.get(AuthTB); err != nil {
return nil, err
}
_, _, _ = p.mightBe(TRANSACTION)
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {
@ -46,6 +54,10 @@ func (p *parser) parseCommitStatement() (stmt *CommitStatement, err error) {
stmt = &CommitStatement{}
if _, _, _, err = p.o.get(AuthTB); err != nil {
return nil, err
}
_, _, _ = p.mightBe(TRANSACTION)
if _, _, err = p.shouldBe(EOF, SEMICOLON); err != nil {

View file

@ -18,9 +18,9 @@ func (p *parser) parseUpdateStatement() (stmt *UpdateStatement, err error) {
stmt = &UpdateStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthTB); err != nil {
return nil, err
}
_, _, _ = p.mightBe(INTO)

View file

@ -29,21 +29,29 @@ func (p *parser) parseUseStatement() (stmt *UseStatement, err error) {
for {
if p.is(tok, NAMESPACE) {
_, stmt.NS, err = p.shouldBe(IDENT, STRING)
if err != nil {
return nil, &ParseError{Found: stmt.NS, Expected: []string{"namespace name"}}
}
// TODO: need to make sure this user can access this NS
p.c.Set("NS", stmt.NS)
if err = p.o.ns(stmt.NS); err != nil {
return nil, err
}
}
if p.is(tok, DATABASE) {
_, stmt.DB, err = p.shouldBe(IDENT, DATE, TIME, STRING, NUMBER, DOUBLE)
if err != nil {
return nil, &ParseError{Found: stmt.DB, Expected: []string{"database name"}}
}
// TODO: need to make sure this user can access this DB
p.c.Set("DB", stmt.DB)
if err = p.o.db(stmt.DB); err != nil {
return nil, err
}
}
tok, _, exi = p.mightBe(NAMESPACE, DATABASE)

View file

@ -18,9 +18,9 @@ func (p *parser) parseDefineViewStatement() (stmt *DefineViewStatement, err erro
stmt = &DefineViewStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err
@ -69,9 +69,9 @@ func (p *parser) parseRemoveViewStatement() (stmt *RemoveViewStatement, err erro
stmt = &RemoveViewStatement{}
stmt.KV = p.c.Get("KV").(string)
stmt.NS = p.c.Get("NS").(string)
stmt.DB = p.c.Get("DB").(string)
if stmt.KV, stmt.NS, stmt.DB, err = p.o.get(AuthDB); err != nil {
return nil, err
}
if stmt.Name, err = p.parseName(); err != nil {
return nil, err