Enable CIDR whitelisting for root authentication

This commit is contained in:
Tobie Morgan Hitchcock 2017-02-21 00:09:42 +00:00
parent b70ef4594d
commit 05501ba79e
4 changed files with 32 additions and 5 deletions

View file

@ -16,6 +16,7 @@ package cli
import (
"fmt"
"net"
"os"
"regexp"
"strings"
@ -142,6 +143,16 @@ func setup() {
opts.Auth.Pass = string(rand.New(20))
}
//
for _, cidr := range opts.Auth.Addr {
_, subn, err := net.ParseCIDR(cidr)
if err != nil {
log.Fatalf("Invalid cidr %s. Please specify a valid CIDR address for --auth-addr", cidr)
}
opts.Auth.Nets = append(opts.Auth.Nets, subn)
}
// --------------------------------------------------
// Nodes
// --------------------------------------------------

View file

@ -70,7 +70,7 @@ func init() {
startCmd.PersistentFlags().StringVarP(&opts.Auth.Auth, "auth", "a", "root:root", "Master database authentication details.")
startCmd.PersistentFlags().StringVar(&opts.Auth.User, "auth-user", "", "The master username for the database. Use this as an alternative to the --auth flag.")
startCmd.PersistentFlags().StringVar(&opts.Auth.Pass, "auth-pass", "", "The master password for the database. Use this as an alternative to the --auth flag.")
startCmd.PersistentFlags().StringSliceVar(&opts.Auth.Addr, "auth-addr", nil, "The IP address ranges from which master authentication is possible.")
startCmd.PersistentFlags().StringSliceVar(&opts.Auth.Addr, "auth-addr", []string{"0.0.0.0/0"}, "The IP address ranges from which master authentication is possible.")
startCmd.PersistentFlags().StringVar(&opts.Cert.Crt, "cert-crt", "", "Path to the server certificate. Needed when running in secure mode.")
startCmd.PersistentFlags().StringVar(&opts.Cert.Key, "cert-key", "", "Path to the server private key. Needed when running in secure mode.")

View file

@ -14,7 +14,10 @@
package cnf
import "time"
import (
"net"
"time"
)
var Settings *Options
@ -69,6 +72,7 @@ type Options struct {
User string // Master authentication username
Pass string // Master authentication password
Addr []string
Nets []*net.IPNet
}
Node struct {

View file

@ -16,6 +16,7 @@ package web
import (
"fmt"
"net"
"bytes"
"strings"
@ -32,7 +33,20 @@ import (
"github.com/gorilla/websocket"
)
func cidr(ip net.IP, networks []*net.IPNet) bool {
for _, network := range networks {
if network.Contains(ip) {
return true
}
}
return false
}
func auth() fibre.MiddlewareFunc {
user := []byte(cnf.Settings.Auth.User)
pass := []byte(cnf.Settings.Auth.Pass)
return func(h fibre.HandlerFunc) fibre.HandlerFunc {
return func(c *fibre.Context) (err error) {
@ -115,10 +129,8 @@ func auth() fibre.MiddlewareFunc {
base, err := base64.StdEncoding.DecodeString(head[6:])
if err == nil {
if err == nil && cidr(c.IP(), cnf.Settings.Auth.Nets) {
user := []byte(cnf.Settings.Auth.User)
pass := []byte(cnf.Settings.Auth.Pass)
cred := bytes.SplitN(base, []byte(":"), 2)
if len(cred) == 2 && bytes.Equal(cred[0], user) && bytes.Equal(cred[1], pass) {