diff --git a/db/db.go b/db/db.go index 6f8dc6d9..318cd2d2 100644 --- a/db/db.go +++ b/db/db.go @@ -338,6 +338,14 @@ func (e *executor) operate(ast sql.Statement) (res []interface{}, err error) { } + // Mark the beginning of this statement so we + // can monitor the running time, and ensure + // it runs no longer than specified. + + if stm, ok := ast.(sql.KillableStatement); ok { + stm.Begin() + } + // Execute the defined statement, receiving the // result set, and any errors which occured // while processing the query. @@ -419,6 +427,14 @@ func (e *executor) operate(ast sql.Statement) (res []interface{}, err error) { } } + // The statement has successfully cancelled + // or committed, so stop all the transaction + // timeout timers if any were set. + + if stm, ok := ast.(sql.KillableStatement); ok { + stm.Cease() + } + return } diff --git a/sql/ast.go b/sql/ast.go index d7e4961d..523d77cf 100644 --- a/sql/ast.go +++ b/sql/ast.go @@ -34,6 +34,21 @@ type Statement interface{} // Statements represents multiple SQL ASTs type Statements []Statement +// -------------------------------------------------- +// Other +// -------------------------------------------------- + +type KillableStatement interface { + Begin() + Cease() + Timedout() <-chan struct{} +} + +type killable struct { + ticker *time.Timer + closer chan struct{} +} + // -------------------------------------------------- // Trans // -------------------------------------------------- @@ -108,63 +123,73 @@ type LiveStatement struct { // SelectStatement represents a SQL SELECT statement. type SelectStatement struct { - KV string `cork:"-" codec:"-"` - NS string `cork:"-" codec:"-"` - DB string `cork:"-" codec:"-"` - Expr Fields `cork:"expr" codec:"expr"` - What Exprs `cork:"what" codec:"what"` - Cond Expr `cork:"cond" codec:"cond"` - Group Groups `cork:"group" codec:"group"` - Order Orders `cork:"order" codec:"order"` - Limit Expr `cork:"limit" codec:"limit"` - Start Expr `cork:"start" codec:"start"` - Version Expr `cork:"version" codec:"version"` + killable + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + Expr Fields `cork:"expr" codec:"expr"` + What Exprs `cork:"what" codec:"what"` + Cond Expr `cork:"cond" codec:"cond"` + Group Groups `cork:"group" codec:"group"` + Order Orders `cork:"order" codec:"order"` + Limit Expr `cork:"limit" codec:"limit"` + Start Expr `cork:"start" codec:"start"` + Version Expr `cork:"version" codec:"version"` + Timeout time.Duration `cork:"timeout" codec:"timeout"` } // CreateStatement represents a SQL CREATE statement. type CreateStatement struct { - KV string `cork:"-" codec:"-"` - NS string `cork:"-" codec:"-"` - DB string `cork:"-" codec:"-"` - What Exprs `cork:"what" codec:"what"` - Data Expr `cork:"data" codec:"data"` - Echo Token `cork:"echo" codec:"echo"` + killable + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + What Exprs `cork:"what" codec:"what"` + Data Expr `cork:"data" codec:"data"` + Echo Token `cork:"echo" codec:"echo"` + Timeout time.Duration `cork:"timeout" codec:"timeout"` } // UpdateStatement represents a SQL UPDATE statement. type UpdateStatement struct { - KV string `cork:"-" codec:"-"` - NS string `cork:"-" codec:"-"` - DB string `cork:"-" codec:"-"` - Hard bool `cork:"hard" codec:"hard"` - What Exprs `cork:"what" codec:"what"` - Data Expr `cork:"data" codec:"data"` - Cond Expr `cork:"cond" codec:"cond"` - Echo Token `cork:"echo" codec:"echo"` + killable + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + Hard bool `cork:"hard" codec:"hard"` + What Exprs `cork:"what" codec:"what"` + Data Expr `cork:"data" codec:"data"` + Cond Expr `cork:"cond" codec:"cond"` + Echo Token `cork:"echo" codec:"echo"` + Timeout time.Duration `cork:"timeout" codec:"timeout"` } // DeleteStatement represents a SQL DELETE statement. type DeleteStatement struct { - KV string `cork:"-" codec:"-"` - NS string `cork:"-" codec:"-"` - DB string `cork:"-" codec:"-"` - Hard bool `cork:"hard" codec:"hard"` - What Exprs `cork:"what" codec:"what"` - Cond Expr `cork:"cond" codec:"cond"` - Echo Token `cork:"echo" codec:"echo"` + killable + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + Hard bool `cork:"hard" codec:"hard"` + What Exprs `cork:"what" codec:"what"` + Cond Expr `cork:"cond" codec:"cond"` + Echo Token `cork:"echo" codec:"echo"` + Timeout time.Duration `cork:"timeout" codec:"timeout"` } // RelateStatement represents a SQL RELATE statement. type RelateStatement struct { - KV string `cork:"-" codec:"-"` - NS string `cork:"-" codec:"-"` - DB string `cork:"-" codec:"-"` - Type Expr `cork:"type" codec:"type"` - From Exprs `cork:"from" codec:"from"` - With Exprs `cork:"with" codec:"with"` - Data Expr `cork:"data" codec:"data"` - Uniq bool `cork:"uniq" codec:"uniq"` - Echo Token `cork:"echo" codec:"echo"` + killable + KV string `cork:"-" codec:"-"` + NS string `cork:"-" codec:"-"` + DB string `cork:"-" codec:"-"` + Type Expr `cork:"type" codec:"type"` + From Exprs `cork:"from" codec:"from"` + With Exprs `cork:"with" codec:"with"` + Data Expr `cork:"data" codec:"data"` + Uniq bool `cork:"uniq" codec:"uniq"` + Echo Token `cork:"echo" codec:"echo"` + Timeout time.Duration `cork:"timeout" codec:"timeout"` } // -------------------------------------------------- diff --git a/sql/create.go b/sql/create.go index 2adc0c9a..1a100f64 100644 --- a/sql/create.go +++ b/sql/create.go @@ -34,6 +34,10 @@ func (p *parser) parseCreateStatement() (stmt *CreateStatement, err error) { return nil, err } + if stmt.Timeout, err = p.parseTimeout(); err != nil { + return nil, err + } + if _, _, err = p.shouldBe(EOF, RPAREN, SEMICOLON); err != nil { return nil, err } diff --git a/sql/delete.go b/sql/delete.go index 5923bb30..9c0676ef 100644 --- a/sql/delete.go +++ b/sql/delete.go @@ -43,6 +43,10 @@ func (p *parser) parseDeleteStatement() (stmt *DeleteStatement, err error) { return nil, err } + if stmt.Timeout, err = p.parseTimeout(); err != nil { + return nil, err + } + if _, _, err = p.shouldBe(EOF, RPAREN, SEMICOLON); err != nil { return nil, err } diff --git a/sql/exprs.go b/sql/exprs.go index 6f47a888..1f5c81bc 100644 --- a/sql/exprs.go +++ b/sql/exprs.go @@ -367,6 +367,16 @@ func (p *parser) parseDuration() (time.Duration, error) { } +func (p *parser) parseTimeout() (time.Duration, error) { + + if _, _, exi := p.mightBe(TIMEOUT); !exi { + return 0, nil + } + + return p.parseDuration() + +} + func (p *parser) parseBcrypt() ([]byte, error) { _, lit, err := p.shouldBe(STRING) diff --git a/sql/gen.go b/sql/gen.go index 90861472..c6ada97c 100644 --- a/sql/gen.go +++ b/sql/gen.go @@ -14,4 +14,7 @@ package sql +//go:generate go get -u github.com/abcum/tmpl +//go:generate tmpl -file=kill.gen.json kill.gen.go.tmpl + //go:generate codecgen -o ast.gen.go ast.go diff --git a/sql/kill.gen.go.tmpl b/sql/kill.gen.go.tmpl new file mode 100644 index 00000000..71a794ae --- /dev/null +++ b/sql/kill.gen.go.tmpl @@ -0,0 +1,59 @@ +// Copyright © 2016 Abcum Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "time" +) + +{{with $types := .}}{{range $k := $types}} + +func (s *{{$k.name}}Statement) Begin() { + if s.Timeout == 0 { + return + } + if s.killable.closer == nil { + s.killable.closer = make(chan struct{}) + } + s.killable.ticker = time.AfterFunc(s.Timeout, func() { + s.killable.ticker.Stop() + s.killable.ticker = nil + close(s.killable.closer) + }) +} + +func (s *{{$k.name}}Statement) Cease() { + if s.Timeout == 0 { + return + } + if s.killable.closer == nil { + s.killable.closer = make(chan struct{}) + } + if s.killable.ticker != nil { + s.killable.ticker.Stop() + } +} + +func (s *{{$k.name}}Statement) Timedout() <-chan struct{} { + if s.Timeout == 0 { + return nil + } + if s.killable.closer == nil { + s.killable.closer = make(chan struct{}) + } + return s.killable.closer +} + +{{end}}{{end}} diff --git a/sql/kill.gen.json b/sql/kill.gen.json new file mode 100644 index 00000000..c0c45a41 --- /dev/null +++ b/sql/kill.gen.json @@ -0,0 +1,7 @@ +[ + { "name": "Select" }, + { "name": "Create" }, + { "name": "Update" }, + { "name": "Delete" }, + { "name": "Relate" } +] diff --git a/sql/relate.go b/sql/relate.go index b5f0b17f..89fd4b56 100644 --- a/sql/relate.go +++ b/sql/relate.go @@ -52,6 +52,10 @@ func (p *parser) parseRelateStatement() (stmt *RelateStatement, err error) { return nil, err } + if stmt.Timeout, err = p.parseTimeout(); err != nil { + return nil, err + } + if _, _, err = p.shouldBe(EOF, RPAREN, SEMICOLON); err != nil { return nil, err } diff --git a/sql/scanner.go b/sql/scanner.go index aba79804..ab64ddce 100644 --- a/sql/scanner.go +++ b/sql/scanner.go @@ -591,6 +591,9 @@ func (s *scanner) scanNumber(chp ...rune) (tok Token, lit string, val interface{ if chn := s.next(); chn == 's' { tok = DURATION buf.WriteRune(chn) + } else if ch == 'm' { + tok = DURATION + s.undo() } else { s.undo() } diff --git a/sql/select.go b/sql/select.go index 45d8194b..4d597eb9 100644 --- a/sql/select.go +++ b/sql/select.go @@ -59,6 +59,10 @@ func (p *parser) parseSelectStatement() (stmt *SelectStatement, err error) { return nil, err } + if stmt.Timeout, err = p.parseTimeout(); err != nil { + return nil, err + } + if _, _, err = p.shouldBe(EOF, RPAREN, SEMICOLON); err != nil { return nil, err } diff --git a/sql/string.go b/sql/string.go index 89412f55..76bca71f 100644 --- a/sql/string.go +++ b/sql/string.go @@ -18,6 +18,7 @@ import ( "encoding/json" "fmt" "strings" + "time" ) func orNil(v interface{}) string { @@ -147,6 +148,15 @@ func stringFromInterface(v interface{}, y, n string) string { } } +func stringFromDuration(v time.Duration, y, n string) string { + switch v { + case 0: + return n + default: + return y + } +} + // --------------------------------------------- // Statements // --------------------------------------------- @@ -199,7 +209,7 @@ func (this ReturnStatement) String() string { } func (this SelectStatement) String() string { - return fmt.Sprintf("SELECT %v FROM %v%v%v%v%v%v%v", + return fmt.Sprintf("SELECT %v FROM %v%v%v%v%v%v%v%v", this.Expr, this.What, stringFromInterface(this.Cond, fmt.Sprintf(" WHERE %v", this.Cond), ""), @@ -208,43 +218,48 @@ func (this SelectStatement) String() string { stringFromInterface(this.Limit, fmt.Sprintf(" LIMIT %v", this.Limit), ""), stringFromInterface(this.Start, fmt.Sprintf(" START %v", this.Start), ""), stringFromInterface(this.Version, fmt.Sprintf(" VERSION %v", this.Version), ""), + stringFromDuration(this.Timeout, fmt.Sprintf(" TIMEOUT %v", this.Timeout.String()), ""), ) } func (this CreateStatement) String() string { - return fmt.Sprintf("CREATE %v%v RETURN %v", + return fmt.Sprintf("CREATE %v%v RETURN %v%v", this.What, this.Data, this.Echo, + stringFromDuration(this.Timeout, fmt.Sprintf(" TIMEOUT %v", this.Timeout.String()), ""), ) } func (this UpdateStatement) String() string { - return fmt.Sprintf("CREATE %v%v%v RETURN %v", + return fmt.Sprintf("CREATE %v%v%v RETURN %v%v", this.What, this.Data, this.Cond, this.Echo, + stringFromDuration(this.Timeout, fmt.Sprintf(" TIMEOUT %v", this.Timeout.String()), ""), ) } func (this DeleteStatement) String() string { - return fmt.Sprintf("DELETE %v%v%v RETURN %v", + return fmt.Sprintf("DELETE %v%v%v RETURN %v%v", stringFromBool(this.Hard, "AND EXPUNGE ", ""), this.What, this.Cond, this.Echo, + stringFromDuration(this.Timeout, fmt.Sprintf(" TIMEOUT %v", this.Timeout.String()), ""), ) } func (this RelateStatement) String() string { - return fmt.Sprintf("RELATE %v FROM %v WITH %v%v%v RETURN %v", + return fmt.Sprintf("RELATE %v FROM %v WITH %v%v%v RETURN %v%v", this.Type, this.From, this.With, this.Data, stringFromBool(this.Uniq, " UNIQUE", ""), this.Echo, + stringFromDuration(this.Timeout, fmt.Sprintf(" TIMEOUT %v", this.Timeout.String()), ""), ) } diff --git a/sql/tokens.go b/sql/tokens.go index 85ca4102..4a8d3974 100644 --- a/sql/tokens.go +++ b/sql/tokens.go @@ -180,6 +180,7 @@ const ( SOMECONTAINEDIN START TABLE + TIMEOUT TO TOKEN TRANSACTION @@ -346,6 +347,7 @@ var tokens = [...]string{ SOMECONTAINEDIN: "SOMECONTAINEDIN", START: "START", TABLE: "TABLE", + TIMEOUT: "TIMEOUT", TO: "TO", TOKEN: "TOKEN", TRANSACTION: "TRANSACTION", diff --git a/sql/update.go b/sql/update.go index 5e421432..8cefdf21 100644 --- a/sql/update.go +++ b/sql/update.go @@ -45,6 +45,10 @@ func (p *parser) parseUpdateStatement() (stmt *UpdateStatement, err error) { return nil, err } + if stmt.Timeout, err = p.parseTimeout(); err != nil { + return nil, err + } + if _, _, err = p.shouldBe(EOF, RPAREN, SEMICOLON); err != nil { return nil, err }