From dc6a357e26c30a880572b663e9610632f835fb46 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Tue, 6 Feb 2018 17:07:42 +0000 Subject: [PATCH] Enable passing context when beginning a transaction --- db/db.go | 2 +- db/executor.go | 8 ++++---- db/socket.go | 2 +- kvs/db.go | 7 +++++-- kvs/ds.go | 5 +++-- kvs/rixxdb/db.go | 4 +++- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/db/db.go b/db/db.go index 18f75d0b..83e1121b 100644 --- a/db/db.go +++ b/db/db.go @@ -83,7 +83,7 @@ func Export(w io.Writer) (err error) { // with the underlying database, and returns // the transaction, or any error which occured. func Begin(rw bool) (txn kvs.TX, err error) { - return db.Begin(rw) + return db.Begin(context.Background(), rw) } // Socket registers a websocket for live queries diff --git a/db/executor.go b/db/executor.go index 9c554779..c2bb6966 100644 --- a/db/executor.go +++ b/db/executor.go @@ -145,7 +145,7 @@ func (e *executor) execute(ctx context.Context, ast *sql.Query) { switch stm.(type) { case *sql.BeginStatement: - err = e.begin(true) + err = e.begin(ctx, true) trc.Finish() continue case *sql.CancelStatement: @@ -243,7 +243,7 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf trw = false } - err = e.begin(trw) + err = e.begin(ctx, trw) if err != nil { return } @@ -420,9 +420,9 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf } -func (e *executor) begin(rw bool) (err error) { +func (e *executor) begin(ctx context.Context, rw bool) (err error) { if e.dbo.TX == nil { - e.dbo.TX, err = db.Begin(rw) + e.dbo.TX, err = db.Begin(ctx, rw) } return } diff --git a/db/socket.go b/db/socket.go index 2f57b829..cd9a4a12 100644 --- a/db/socket.go +++ b/db/socket.go @@ -244,7 +244,7 @@ func (s *socket) deregister(id string) { delete(sockets, id) - txn, _ := db.Begin(true) + txn, _ := db.Begin(context.Background(), true) defer txn.Commit() diff --git a/kvs/db.go b/kvs/db.go index e1bd9f8b..d3562077 100644 --- a/kvs/db.go +++ b/kvs/db.go @@ -14,11 +14,14 @@ package kvs -import "io" +import ( + "context" + "io" +) // DB represents a database implementation type DB interface { - Begin(bool) (TX, error) + Begin(context.Context, bool) (TX, error) Import(io.Reader) error Export(io.Writer) error Close() error diff --git a/kvs/ds.go b/kvs/ds.go index 3cee84e6..81a8923f 100644 --- a/kvs/ds.go +++ b/kvs/ds.go @@ -15,6 +15,7 @@ package kvs import ( + "context" "io" "strings" @@ -63,8 +64,8 @@ func New(opts *cnf.Options) (ds *DS, err error) { // Begin begins a new read / write transaction // with the underlying database, and returns // the transaction, or any error which occured. -func (ds *DS) Begin(writable bool) (txn TX, err error) { - return ds.db.Begin(writable) +func (ds *DS) Begin(ctx context.Context, writable bool) (txn TX, err error) { + return ds.db.Begin(ctx, writable) } // Import loads database operations from a reader. diff --git a/kvs/rixxdb/db.go b/kvs/rixxdb/db.go index 2d7e5ef0..af27eddf 100644 --- a/kvs/rixxdb/db.go +++ b/kvs/rixxdb/db.go @@ -17,6 +17,8 @@ package rixxdb import ( "io" + "context" + "github.com/abcum/rixxdb" "github.com/abcum/surreal/kvs" ) @@ -25,7 +27,7 @@ type DB struct { pntr *rixxdb.DB } -func (db *DB) Begin(writable bool) (txn kvs.TX, err error) { +func (db *DB) Begin(ctx context.Context, writable bool) (txn kvs.TX, err error) { var pntr *rixxdb.TX if pntr, err = db.pntr.Begin(writable); err != nil { err = &kvs.DBError{Err: err}