diff --git a/cli/start.go b/cli/start.go index f038b7af..e639d57a 100644 --- a/cli/start.go +++ b/cli/start.go @@ -82,6 +82,8 @@ func init() { startCmd.PersistentFlags().IntVar(&opts.Port, "port", 8000, "The port on which to serve the web server") startCmd.PersistentFlags().StringVarP(&opts.Bind, "bind", "b", "0.0.0.0", "The hostname or ip address to listen for connections on") + startCmd.PersistentFlags().DurationVar(&opts.Query.Timeout, "timeout", 0, "") + startCmd.PersistentFlags().StringVarP(&opts.DB.Code, "key", "k", "", "Encryption key to use for on-disk encryption") startCmd.PersistentFlags().DurationVar(&opts.DB.Proc.Flush, "db-flush", 0, "A time duration to use when syncing data to persistent storage") diff --git a/cnf/cnf.go b/cnf/cnf.go index 0087caac..107a7437 100644 --- a/cnf/cnf.go +++ b/cnf/cnf.go @@ -59,6 +59,10 @@ type Options struct { Nets []*net.IPNet // Allowed cidr ranges for authentication } + Query struct { + Timeout time.Duration // Fixed query timeout + } + Logging struct { Level string // Stores the configured logging level Output string // Stores the configured logging output diff --git a/db/executor.go b/db/executor.go index f4cb6beb..4ba10f26 100644 --- a/db/executor.go +++ b/db/executor.go @@ -22,6 +22,7 @@ import ( "runtime/debug" + "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/kvs" "github.com/abcum/surreal/log" "github.com/abcum/surreal/sql" @@ -267,6 +268,22 @@ func (e *executor) operate(ctx context.Context, stm sql.Statement) (res []interf // can monitor the running time, and ensure // it runs no longer than specified. + if cnf.Settings.Query.Timeout > 0 { + if perm(ctx) != cnf.AuthKV { + ctx, canc = context.WithTimeout(ctx, cnf.Settings.Query.Timeout) + defer func() { + if tim := ctx.Err(); err == nil && tim != nil { + res, err = nil, &TimerError{timer: cnf.Settings.Query.Timeout} + } + canc() + }() + } + } + + // Mark the beginning of this statement so we + // can monitor the running time, and ensure + // it runs no longer than specified. + if stm, ok := stm.(sql.KillableStatement); ok { if dur := stm.Duration(); dur > 0 { ctx, canc = context.WithTimeout(ctx, dur)