diff --git a/cli/setup.go b/cli/setup.go index f8a762e3..2c545f6e 100644 --- a/cli/setup.go +++ b/cli/setup.go @@ -45,6 +45,55 @@ func setup() { log.Fatal("Specify a valid data store configuration path") } + if strings.HasPrefix(opts.DB.Cert.CA, "-----") { + var err error + var doc *os.File + if doc, err = os.Create("db.ca"); err != nil { + log.Fatal("Can not decode PEM encoded CA into db.ca") + } + doc.Write([]byte(opts.DB.Cert.CA)) + doc.Close() + opts.Cert.Crt = "db.ca" + } + + if strings.HasPrefix(opts.DB.Cert.Crt, "-----") { + var err error + var doc *os.File + if doc, err = os.Create("db.key"); err != nil { + log.Fatal("Can not decode PEM encoded certificate into db.crt") + } + doc.Write([]byte(opts.DB.Cert.Crt)) + doc.Close() + opts.Cert.Crt = "db.crt" + } + + if strings.HasPrefix(opts.DB.Cert.Key, "-----") { + var err error + var doc *os.File + if doc, err = os.Create("db.crt"); err != nil { + log.Fatal("Can not decode PEM encoded private key into db.key") + } + doc.Write([]byte(opts.DB.Cert.Key)) + doc.Close() + opts.Cert.Crt = "db.key" + } + + if opts.DB.Cert.CA != "" || opts.DB.Cert.Crt != "" || opts.DB.Cert.Key != "" { + opts.DB.Cert.SSL = true + } + + if opts.DB.Cert.CA == "" && opts.DB.Cert.SSL { + log.Fatal("Specify a valid PEM encoded CA file.") + } + + if opts.DB.Cert.Crt == "" && opts.DB.Cert.SSL { + log.Fatal("Specify a valid PEM encoded certificate file.") + } + + if opts.DB.Cert.Key == "" && opts.DB.Cert.SSL { + log.Fatal("Specify a valid PEM encoded private key file.") + } + // -------------------------------------------------- // Auth // -------------------------------------------------- diff --git a/cli/start.go b/cli/start.go index 70e373b0..8bc54930 100644 --- a/cli/start.go +++ b/cli/start.go @@ -72,6 +72,9 @@ func init() { startCmd.PersistentFlags().StringVar(&opts.DB.Base, "db-base", "", flag("db-base")) startCmd.PersistentFlags().StringVar(&opts.DB.Path, "db-path", "", flag("db-path")) + startCmd.PersistentFlags().StringVar(&opts.DB.Cert.CA, "db-ca", "", "Path to the CA file used to connect to the remote database.") + startCmd.PersistentFlags().StringVar(&opts.DB.Cert.Crt, "db-crt", "", "Path to the certificate file used to connect to the remote database.") + startCmd.PersistentFlags().StringVar(&opts.DB.Cert.Key, "db-key", "", "Path to the private key file used to connect to the remote database.") startCmd.PersistentFlags().IntVar(&opts.Port.Tcp, "port-tcp", 0, flag("port-tcp")) startCmd.PersistentFlags().IntVar(&opts.Port.Web, "port-web", 0, flag("port-web")) diff --git a/cnf/cnf.go b/cnf/cnf.go index 21446279..a81863e9 100644 --- a/cnf/cnf.go +++ b/cnf/cnf.go @@ -23,6 +23,12 @@ type Options struct { Host string // Surreal host to connect to Port string // Surreal port to connect to Base string // Base key to use in KV stores + Cert struct { + CA string + Crt string + Key string + SSL bool + } } Port struct { diff --git a/kvs/mysql/main.go b/kvs/mysql/main.go index 7e8e7b6c..ae0bb54c 100644 --- a/kvs/mysql/main.go +++ b/kvs/mysql/main.go @@ -15,10 +15,15 @@ package mysql import ( - "strings" + "fmt" + "regexp" + + "crypto/tls" + "crypto/x509" + "io/ioutil" "database/sql" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/kvs" @@ -32,9 +37,12 @@ func New(opts *cnf.Options) (ds kvs.DS, err error) { var db *sql.DB - path := strings.TrimLeft(opts.DB.Path, "mysql://") + opts.DB.Path, err = config(opts) + if err != nil { + return + } - db, err = sql.Open("mysql", path) + db, err = sql.Open("mysql", opts.DB.Path) if err != nil { return } @@ -42,3 +50,49 @@ func New(opts *cnf.Options) (ds kvs.DS, err error) { return &DS{db: db, ck: opts.DB.Key}, err } + +func config(opts *cnf.Options) (path string, err error) { + + re := regexp.MustCompile(`^mysql://` + + `((?:(?P.*?)(?::(?P.*))?@))?` + + `(?:(?:(?P[^\/]*))?)?` + + `\/(?P.*?)` + + `(?:\?(?P[^\?]*))?$`) + + ma := re.FindStringSubmatch(opts.DB.Path) + + if len(ma) == 0 || ma[4] == "" || ma[5] == "" { + err = fmt.Errorf("Specify a valid data store configuration path. Use the help command for further instructions.") + } + + if opts.DB.Cert.SSL { + pool := x509.NewCertPool() + pem, err := ioutil.ReadFile(opts.DB.Cert.CA) + if err != nil { + err = fmt.Errorf("Could not read file %s", opts.DB.Cert.CA) + } + if ok := pool.AppendCertsFromPEM(pem); !ok { + return "", fmt.Errorf("Could not read file %s", opts.DB.Cert.CA) + } + cert := make([]tls.Certificate, 0, 1) + pair, err := tls.LoadX509KeyPair(opts.DB.Cert.Crt, opts.DB.Cert.Key) + if err != nil { + return "", err + } + cert = append(cert, pair) + mysql.RegisterTLSConfig("custom", &tls.Config{ + RootCAs: pool, + Certificates: cert, + InsecureSkipVerify: true, + }) + } + + if opts.DB.Cert.SSL { + path += fmt.Sprintf("%stcp(%s)/%s?tls=custom", ma[1], ma[4], ma[5]) + } else { + path += fmt.Sprintf("%stcp(%s)/%s", ma[1], ma[4], ma[5]) + } + + return + +} diff --git a/kvs/pgsql/main.go b/kvs/pgsql/main.go index 2cf295d3..a17fc2e1 100644 --- a/kvs/pgsql/main.go +++ b/kvs/pgsql/main.go @@ -15,7 +15,8 @@ package pgsql import ( - "strings" + "fmt" + "regexp" "database/sql" _ "github.com/lib/pq" @@ -32,9 +33,12 @@ func New(opts *cnf.Options) (ds kvs.DS, err error) { var db *sql.DB - path := strings.TrimLeft(opts.DB.Path, "pgsql://") + opts.DB.Path, err = config(opts) + if err != nil { + return + } - db, err = sql.Open("postgres", path) + db, err = sql.Open("postgres", opts.DB.Path) if err != nil { return } @@ -42,3 +46,27 @@ func New(opts *cnf.Options) (ds kvs.DS, err error) { return &DS{db: db, ck: opts.DB.Key}, err } + +func config(opts *cnf.Options) (path string, err error) { + + re := regexp.MustCompile(`^mysql://` + + `((?:(?P.*?)(?::(?P.*))?@))?` + + `(?:(?:(?P[^\/]*))?)?` + + `\/(?P.*?)` + + `(?:\?(?P[^\?]*))?$`) + + ma := re.FindStringSubmatch(opts.DB.Path) + + if len(ma) == 0 || ma[4] == "" || ma[5] == "" { + err = fmt.Errorf("Specify a valid data store configuration path. Use the help command for further instructions.") + } + + if opts.DB.Cert.SSL { + path += fmt.Sprintf("postgres://%s%s/%s?sslmode=verify-ca&sslrootcert=%s&sslcert=%s&sslkey=%s", ma[1], ma[4], ma[5], opts.DB.Cert.CA, opts.DB.Cert.Crt, opts.DB.Cert.Key) + } else { + path += fmt.Sprintf("postgres://%s%s/%s", ma[1], ma[4], ma[5]) + } + + return + +}