From 7556a77df0af2211498af45b17c4ae0ae68d14ec Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Tue, 18 Oct 2016 13:49:46 +0100 Subject: [PATCH] Improve transactions --- db/create.go | 17 +--------- db/db.go | 87 ++++++++++++++++++++++++++++++++---------------- db/define.go | 60 --------------------------------- db/delete.go | 15 --------- db/info.go | 8 ----- db/modify.go | 15 --------- db/remove.go | 60 --------------------------------- db/select.go | 10 +----- db/update.go | 15 --------- kvs/boltdb/tx.go | 9 ++++- kvs/mysql/tx.go | 4 +++ kvs/pgsql/tx.go | 4 +++ kvs/tx.go | 1 + 13 files changed, 78 insertions(+), 227 deletions(-) diff --git a/db/create.go b/db/create.go index a1ddb23e..ed0b0fd4 100644 --- a/db/create.go +++ b/db/create.go @@ -24,17 +24,6 @@ import ( func executeCreateStatement(txn kvs.TX, ast *sql.CreateStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { @@ -61,10 +50,6 @@ func executeCreateStatement(txn kvs.TX, ast *sql.CreateStatement) (out []interfa } - if local { - txn.Commit() - } - return } @@ -76,7 +61,7 @@ func create(doc *item.Doc, ast *sql.CreateStatement) (out interface{}, err error } if !doc.Allow("CREATE") { - return nil, nil + return } if err = doc.StoreIndex(); err != nil { diff --git a/db/db.go b/db/db.go index efff8df3..b98ffd78 100644 --- a/db/db.go +++ b/db/db.go @@ -111,13 +111,27 @@ func detail(e error) interface{} { } } +func writable(cur kvs.TX, tmp bool) (txn kvs.TX, err error, loc bool) { + if cur == nil { + cur, err = db.Txn(true) + } + return cur, err, tmp +} + +func readable(cur kvs.TX, tmp bool) (txn kvs.TX, err error, loc bool) { + if cur == nil { + cur, err = db.Txn(false) + } + return cur, err, tmp +} + func execute(ctx *fibre.Context, ast *sql.Query, chn chan<- interface{}) { var txn kvs.TX defer func() { if txn != nil { - txn.Rollback() + txn.Cancel() } if r := recover(); r != nil { if err, ok := r.(error); ok { @@ -130,15 +144,47 @@ func execute(ctx *fibre.Context, ast *sql.Query, chn chan<- interface{}) { for _, s := range ast.Statements { - var res []interface{} + var loc bool var err error + var res []interface{} now := time.Now() - switch stm := s.(type) { + switch s.(type) { case *sql.UseStatement: continue + case *sql.BeginStatement: + break + case *sql.CancelStatement: + break + case *sql.CommitStatement: + break + case *sql.InfoStatement: + txn, err, loc = readable(txn, txn == nil) + default: + txn, err, loc = writable(txn, txn == nil) + } + + if err != nil { + chn <- err + } + + switch stm := s.(type) { + + case *sql.CommitStatement: + txn.Commit() + txn = nil + continue + + case *sql.CancelStatement: + txn.Cancel() + txn = nil + continue + + case *sql.BeginStatement: + txn, err, loc = writable(txn, false) + continue case *sql.InfoStatement: res, err = executeInfoStatement(txn, stm) @@ -176,36 +222,21 @@ func execute(ctx *fibre.Context, ast *sql.Query, chn chan<- interface{}) { case *sql.RemoveIndexStatement: res, err = executeRemoveIndexStatement(txn, stm) - case *sql.BeginStatement: - if txn != nil { - chn <- fibre.NewHTTPError(400, "Transaction already running") - return - } else if txn, err = db.Txn(true); err != nil { - chn <- err - return - } - - case *sql.CommitStatement: - if txn != nil { - txn.Commit() - txn = nil - continue - } - - case *sql.CancelStatement: - if txn != nil { - txn.Rollback() - txn = nil - continue - } - } - if err != nil && txn != nil { - txn.Rollback() + if err != nil { chn <- err } + if loc { + if err != nil { + txn.Cancel() + } else { + txn.Commit() + } + txn = nil + } + chn <- &Response{ Time: time.Since(now).String(), Status: status(err), diff --git a/db/define.go b/db/define.go index 197fa59a..6a274530 100644 --- a/db/define.go +++ b/db/define.go @@ -24,17 +24,6 @@ import ( func executeDefineTableStatement(txn kvs.TX, ast *sql.DefineTableStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Set the namespace definition @@ -57,27 +46,12 @@ func executeDefineTableStatement(txn kvs.TX, ast *sql.DefineTableStatement) (out } - if local { - txn.Commit() - } - return } func executeDefineRulesStatement(txn kvs.TX, ast *sql.DefineRulesStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { for _, RU := range ast.When { @@ -110,27 +84,12 @@ func executeDefineRulesStatement(txn kvs.TX, ast *sql.DefineRulesStatement) (out } - if local { - txn.Commit() - } - return } func executeDefineFieldStatement(txn kvs.TX, ast *sql.DefineFieldStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Set the namespace definition @@ -159,27 +118,12 @@ func executeDefineFieldStatement(txn kvs.TX, ast *sql.DefineFieldStatement) (out } - if local { - txn.Commit() - } - return } func executeDefineIndexStatement(txn kvs.TX, ast *sql.DefineIndexStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Set the namespace definition @@ -226,10 +170,6 @@ func executeDefineIndexStatement(txn kvs.TX, ast *sql.DefineIndexStatement) (out } - if local { - txn.Commit() - } - return } diff --git a/db/delete.go b/db/delete.go index e5847135..a4ce3911 100644 --- a/db/delete.go +++ b/db/delete.go @@ -23,17 +23,6 @@ import ( func executeDeleteStatement(txn kvs.TX, ast *sql.DeleteStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { @@ -63,10 +52,6 @@ func executeDeleteStatement(txn kvs.TX, ast *sql.DeleteStatement) (out []interfa } - if local { - txn.Commit() - } - return } diff --git a/db/info.go b/db/info.go index 099ce0d1..d3a19371 100644 --- a/db/info.go +++ b/db/info.go @@ -25,14 +25,6 @@ import ( func executeInfoStatement(txn kvs.TX, ast *sql.InfoStatement) (out []interface{}, err error) { - if txn == nil { - txn, err = db.Txn(false) - if err != nil { - return - } - defer txn.Rollback() - } - if ast.What == "" { res := data.New() diff --git a/db/modify.go b/db/modify.go index 8f6323ae..5280dad5 100644 --- a/db/modify.go +++ b/db/modify.go @@ -23,17 +23,6 @@ import ( func executeModifyStatement(txn kvs.TX, ast *sql.ModifyStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { @@ -63,10 +52,6 @@ func executeModifyStatement(txn kvs.TX, ast *sql.ModifyStatement) (out []interfa } - if local { - txn.Commit() - } - return } diff --git a/db/remove.go b/db/remove.go index 2f9f38a7..4cf098ae 100644 --- a/db/remove.go +++ b/db/remove.go @@ -22,17 +22,6 @@ import ( func executeRemoveTableStatement(txn kvs.TX, ast *sql.RemoveTableStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Remove the table config @@ -67,27 +56,12 @@ func executeRemoveTableStatement(txn kvs.TX, ast *sql.RemoveTableStatement) (out } - if local { - txn.Commit() - } - return } func executeRemoveRulesStatement(txn kvs.TX, ast *sql.RemoveRulesStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { for _, RU := range ast.When { @@ -102,27 +76,12 @@ func executeRemoveRulesStatement(txn kvs.TX, ast *sql.RemoveRulesStatement) (out } - if local { - txn.Commit() - } - return } func executeRemoveFieldStatement(txn kvs.TX, ast *sql.RemoveFieldStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Remove the field config @@ -133,27 +92,12 @@ func executeRemoveFieldStatement(txn kvs.TX, ast *sql.RemoveFieldStatement) (out } - if local { - txn.Commit() - } - return } func executeRemoveIndexStatement(txn kvs.TX, ast *sql.RemoveIndexStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, TB := range ast.What { // Remove the index config @@ -170,10 +114,6 @@ func executeRemoveIndexStatement(txn kvs.TX, ast *sql.RemoveIndexStatement) (out } - if local { - txn.Commit() - } - return } diff --git a/db/select.go b/db/select.go index 3543fb2c..7a678ec5 100644 --- a/db/select.go +++ b/db/select.go @@ -23,14 +23,6 @@ import ( func executeSelectStatement(txn kvs.TX, ast *sql.SelectStatement) (out []interface{}, err error) { - if txn == nil { - txn, err = db.Txn(false) - if err != nil { - return - } - defer txn.Close() - } - for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { @@ -71,7 +63,7 @@ func detect(doc *item.Doc, ast *sql.SelectStatement) (out interface{}, err error } if !doc.Allow("SELECT") { - return nil, nil + return } out = doc.Blaze(ast) diff --git a/db/update.go b/db/update.go index e062d71e..e64a534b 100644 --- a/db/update.go +++ b/db/update.go @@ -23,17 +23,6 @@ import ( func executeUpdateStatement(txn kvs.TX, ast *sql.UpdateStatement) (out []interface{}, err error) { - var local bool - - if txn == nil { - local = true - txn, err = db.Txn(true) - if err != nil { - return - } - defer txn.Rollback() - } - for _, w := range ast.What { if what, ok := w.(*sql.Thing); ok { @@ -63,10 +52,6 @@ func executeUpdateStatement(txn kvs.TX, ast *sql.UpdateStatement) (out []interfa } - if local { - txn.Commit() - } - return } diff --git a/kvs/boltdb/tx.go b/kvs/boltdb/tx.go index 13b08eb2..e1b03f68 100644 --- a/kvs/boltdb/tx.go +++ b/kvs/boltdb/tx.go @@ -333,8 +333,15 @@ func (tx *TX) Close() (err error) { return tx.Rollback() } +func (tx *TX) Cancel() (err error) { + return tx.Rollback() +} + func (tx *TX) Commit() (err error) { - return tx.tx.Commit() + if tx.tx.Writable() { + return tx.tx.Commit() + } + return tx.tx.Rollback() } func (tx *TX) Rollback() (err error) { diff --git a/kvs/mysql/tx.go b/kvs/mysql/tx.go index 2f4ffdd3..d3583add 100644 --- a/kvs/mysql/tx.go +++ b/kvs/mysql/tx.go @@ -314,6 +314,10 @@ func (tx *TX) Close() (err error) { return tx.Rollback() } +func (tx *TX) Cancel() (err error) { + return tx.Rollback() +} + func (tx *TX) Commit() (err error) { return tx.tx.Commit() } diff --git a/kvs/pgsql/tx.go b/kvs/pgsql/tx.go index 3ed08871..623a1e4e 100644 --- a/kvs/pgsql/tx.go +++ b/kvs/pgsql/tx.go @@ -314,6 +314,10 @@ func (tx *TX) Close() (err error) { return tx.Rollback() } +func (tx *TX) Cancel() (err error) { + return tx.Rollback() +} + func (tx *TX) Commit() (err error) { return tx.tx.Commit() } diff --git a/kvs/tx.go b/kvs/tx.go index b2bf1250..d003e325 100644 --- a/kvs/tx.go +++ b/kvs/tx.go @@ -29,6 +29,7 @@ type TX interface { PDel([]byte) error RDel([]byte, []byte, uint64) error Close() error + Cancel() error Commit() error Rollback() error }