Improve embedded field checking and defaults

This commit is contained in:
Tobie Morgan Hitchcock 2016-10-24 17:03:28 +01:00
parent 6bf826c466
commit 0af761d2db
2 changed files with 225 additions and 187 deletions

View file

@ -24,20 +24,17 @@ import (
"github.com/abcum/surreal/cnf" "github.com/abcum/surreal/cnf"
"github.com/abcum/surreal/sql" "github.com/abcum/surreal/sql"
"github.com/abcum/surreal/util/conv" "github.com/abcum/surreal/util/conv"
"github.com/abcum/surreal/util/data"
) )
func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (err error) { func (this *Doc) each(fld *sql.DefineFieldStatement) (err error) {
var e bool return this.current.Walk(func(key string, val interface{}) error {
var i interface{}
var c interface{}
i = initial.Get(fld.Name).Data() old := this.initial.Get(key).Data()
if fld.Readonly && i != nil { if fld.Readonly && old != nil {
current.Set(i, fld.Name) this.current.Set(old, key)
return return nil
} }
if fld.Code != "" { if fld.Code != "" {
@ -46,7 +43,7 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
vm := otto.New() vm := otto.New()
vm.Set("doc", current.Copy()) vm.Set("doc", this.current.Copy())
ret, err := vm.Run("(function() { " + fld.Code + " })()") ret, err := vm.Run("(function() { " + fld.Code + " })()")
if err != nil { if err != nil {
@ -54,10 +51,10 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
} }
if ret.IsUndefined() { if ret.IsUndefined() {
current.Del(fld.Name) this.current.Del(key)
} else { } else {
val, _ := ret.Export() val, _ := ret.Export()
current.Set(val, fld.Name) this.current.Set(val, key)
} }
} }
@ -67,7 +64,7 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
vm := lua.NewState() vm := lua.NewState()
defer vm.Close() defer vm.Close()
vm.SetGlobal("doc", toLUA(vm, current.Copy())) vm.SetGlobal("doc", toLUA(vm, this.current.Copy()))
if err := vm.DoString(fld.Code); err != nil { if err := vm.DoString(fld.Code); err != nil {
return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error()) return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error())
@ -76,231 +73,266 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
ret := vm.Get(-1) ret := vm.Get(-1)
if ret == lua.LNil { if ret == lua.LNil {
current.Del(fld.Name) this.current.Del(key)
} else { } else {
current.Set(frLUA(ret), fld.Name) this.current.Set(frLUA(ret), key)
} }
} }
} }
c = current.Get(fld.Name).Data() // Ensure that any defined fields are correctly
e = current.Exists(fld.Name) // formatted according to their defined type.
if fld.Default != nil && (c == nil || e == false) { if err = this.chck(fld, key, old, val); err != nil {
switch val := fld.Default.(type) { return err
}
// Ensure that any default fields are correctly
// formatted according to their defined type.
if err = this.chck(fld, key, old, val); err != nil {
return err
}
// Otherwise this is all good.
return nil
}, fld.Name)
}
func (this *Doc) chck(fld *sql.DefineFieldStatement, key string, old, val interface{}) (err error) {
var exi bool
// Ensure that any fields which have been set to
// null are reset back to default if specified.
exi = this.current.Exists(key)
val = this.current.Get(key).Data()
if fld.Default != nil && (exi == false || val == nil && fld.Notnull) {
switch def := fld.Default.(type) {
case sql.Null, *sql.Null: case sql.Null, *sql.Null:
current.Set(nil, fld.Name) this.current.Set(nil, key)
default: default:
current.Set(fld.Default, fld.Name) this.current.Set(fld.Default, key)
case sql.Ident: case sql.Ident:
current.Set(current.Get(val.ID).Data(), fld.Name) this.current.Set(this.current.Get(def.ID).Data(), key)
case *sql.Ident: case *sql.Ident:
current.Set(current.Get(val.ID).Data(), fld.Name) this.current.Set(this.current.Get(def.ID).Data(), key)
} }
} }
c = current.Get(fld.Name).Data() // Check to see if any field which has been set
e = current.Exists(fld.Name) // to defaults now satisfies the field constraints.
if fld.Notnull && e == true && c == nil { exi = this.current.Exists(key)
return fmt.Errorf("Field '%v' can't be null", fld.Name) val = this.current.Get(key).Data()
if fld.Notnull && exi == true && val == nil {
return fmt.Errorf("Field '%v' can't be null", key)
} }
if fld.Mandatory && e == false { if fld.Mandatory && exi == false {
return fmt.Errorf("Need to set field '%v'", fld.Name) return fmt.Errorf("Need to set field '%v'", key)
} }
if c != nil && fld.Type != "" { // Ensure that any defined fields are correctly
// formatted according to their defined type.
if val != nil {
switch fld.Type { switch fld.Type {
case "url": case "url":
if val, err := conv.ConvertToUrl(c); err == nil { if cnv, err := conv.ConvertToUrl(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a URL", fld.Name) return fmt.Errorf("Field '%v' needs to be a URL, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "uuid": case "uuid":
if val, err := conv.ConvertToUuid(c); err == nil { if cnv, err := conv.ConvertToUuid(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a UUID", fld.Name) return fmt.Errorf("Field '%v' needs to be a UUID, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "color": case "color":
if val, err := conv.ConvertToColor(c); err == nil { if cnv, err := conv.ConvertToColor(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a HEX or RGB color", fld.Name) return fmt.Errorf("Field '%v' needs to be a HEX or RGB color, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "email": case "email":
if val, err := conv.ConvertToEmail(c); err == nil { if cnv, err := conv.ConvertToEmail(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be an email address", fld.Name) return fmt.Errorf("Field '%v' needs to be an email address, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "phone": case "phone":
if val, err := conv.ConvertToPhone(c); err == nil { if cnv, err := conv.ConvertToPhone(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a phone number", fld.Name) return fmt.Errorf("Field '%v' needs to be a phone number, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "array": case "array":
if val, err := conv.ConvertToArray(c); err == nil { if cnv, err := conv.ConvertToArray(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be an array", fld.Name) return fmt.Errorf("Field '%v' needs to be an array, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "object": case "object":
if val, err := conv.ConvertToObject(c); err == nil { if cnv, err := conv.ConvertToObject(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be an object", fld.Name) return fmt.Errorf("Field '%v' needs to be an object, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "domain": case "domain":
if val, err := conv.ConvertToDomain(c); err == nil { if cnv, err := conv.ConvertToDomain(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a domain name", fld.Name) return fmt.Errorf("Field '%v' needs to be a domain name, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "base64": case "base64":
if val, err := conv.ConvertToBase64(c); err == nil { if cnv, err := conv.ConvertToBase64(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be base64 data", fld.Name) return fmt.Errorf("Field '%v' needs to be base64 data, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "string": case "string":
if val, err := conv.ConvertToString(c); err == nil { if cnv, err := conv.ConvertToString(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a string", fld.Name) return fmt.Errorf("Field '%v' needs to be a string, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "number": case "number":
if val, err := conv.ConvertToNumber(c); err == nil { if cnv, err := conv.ConvertToNumber(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a number", fld.Name) return fmt.Errorf("Field '%v' needs to be a number, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "double": case "double":
if val, err := conv.ConvertToDouble(c); err == nil { if cnv, err := conv.ConvertToDouble(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a double", fld.Name) return fmt.Errorf("Field '%v' needs to be a double, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "boolean": case "boolean":
if val, err := conv.ConvertToBoolean(c); err == nil { if cnv, err := conv.ConvertToBoolean(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a boolean", fld.Name) return fmt.Errorf("Field '%v' needs to be a boolean, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "datetime": case "datetime":
if val, err := conv.ConvertToDatetime(c); err == nil { if cnv, err := conv.ConvertToDatetime(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a datetime", fld.Name) return fmt.Errorf("Field '%v' needs to be a datetime, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "latitude": case "latitude":
if val, err := conv.ConvertToLatitude(c); err == nil { if cnv, err := conv.ConvertToLatitude(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a latitude value", fld.Name) return fmt.Errorf("Field '%v' needs to be a latitude value, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "longitude": case "longitude":
if val, err := conv.ConvertToLongitude(c); err == nil { if cnv, err := conv.ConvertToLongitude(val); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be a longitude value", fld.Name) return fmt.Errorf("Field '%v' needs to be a longitude value, but found '%v'", key, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case "custom": case "custom":
if val, err := conv.ConvertToOneOf(c, fld.Enum...); err == nil { if cnv, err := conv.ConvertToOneOf(val, fld.Enum...); err == nil {
current.Set(val, fld.Name) this.current.Set(cnv, key)
} else { } else {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be one of %v", fld.Name, fld.Enum) return fmt.Errorf("Field '%v' needs to be one of %v, but found '%v'", key, fld.Enum, val)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
@ -313,11 +345,11 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
if reg, err := regexp.Compile(fld.Match); err != nil { if reg, err := regexp.Compile(fld.Match); err != nil {
return fmt.Errorf("Regular expression /%v/ is invalid", fld.Match) return fmt.Errorf("Regular expression /%v/ is invalid", fld.Match)
} else { } else {
if !reg.MatchString(fmt.Sprintf("%v", c)) { if !reg.MatchString(fmt.Sprintf("%v", val)) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to match the regular expression /%v/", fld.Name, fld.Match) return fmt.Errorf("Field '%v' needs to match the regular expression /%v/", key, fld.Match)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
} }
@ -326,34 +358,43 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
if fld.Min != 0 { if fld.Min != 0 {
if c = current.Get(fld.Name).Data(); c != nil { if val = this.current.Get(key).Data(); val != nil {
switch now := c.(type) { switch now := val.(type) {
case []interface{}: case []interface{}:
if len(now) < int(fld.Min) { if len(now) < int(fld.Min) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to have at least %v items", fld.Name, fld.Min) return fmt.Errorf("Field '%v' needs to have at least %v items", key, fld.Min)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case string: case string:
if len(now) < int(fld.Min) { if len(now) < int(fld.Min) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to have at least %v characters", fld.Name, fld.Min) return fmt.Errorf("Field '%v' needs to have at least %v characters", key, fld.Min)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
}
}
case int64:
if now < int64(fld.Min) {
if fld.Validate {
return fmt.Errorf("Field '%v' needs to be >= %v", key, fld.Min)
} else {
this.current.Iff(old, key)
} }
} }
case float64: case float64:
if now < fld.Min { if now < float64(fld.Min) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be >= %v", fld.Name, fld.Min) return fmt.Errorf("Field '%v' needs to be >= %v", key, fld.Min)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
@ -365,34 +406,43 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
if fld.Max != 0 { if fld.Max != 0 {
if c = current.Get(fld.Name).Data(); c != nil { if val = this.current.Get(key).Data(); val != nil {
switch now := c.(type) { switch now := val.(type) {
case []interface{}: case []interface{}:
if len(now) > int(fld.Max) { if len(now) > int(fld.Max) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to have %v or fewer items", fld.Name, fld.Max) return fmt.Errorf("Field '%v' needs to have %v or fewer items", key, fld.Max)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
case string: case string:
if len(now) > int(fld.Max) { if len(now) > int(fld.Max) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to have %v or fewer characters", fld.Name, fld.Max) return fmt.Errorf("Field '%v' needs to have %v or fewer characters", key, fld.Max)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
}
}
case int64:
if now > int64(fld.Max) {
if fld.Validate {
return fmt.Errorf("Field '%v' needs to be <= %v", key, fld.Max)
} else {
this.current.Iff(old, key)
} }
} }
case float64: case float64:
if now > fld.Max { if now > float64(fld.Max) {
if fld.Validate { if fld.Validate {
return fmt.Errorf("Field '%v' needs to be <= %v", fld.Name, fld.Max) return fmt.Errorf("Field '%v' needs to be <= %v", key, fld.Max)
} else { } else {
current.Iff(i, fld.Name) this.current.Iff(old, key)
} }
} }
@ -402,33 +452,6 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (
} }
c = current.Get(fld.Name).Data()
e = current.Exists(fld.Name)
if fld.Default != nil && (c == nil || e == false) {
switch val := fld.Default.(type) {
case sql.Null, *sql.Null:
current.Set(nil, fld.Name)
default:
current.Set(fld.Default, fld.Name)
case sql.Ident:
current.Set(current.Get(val.ID).Data(), fld.Name)
case *sql.Ident:
current.Set(current.Get(val.ID).Data(), fld.Name)
}
}
c = current.Get(fld.Name).Data()
e = current.Exists(fld.Name)
if fld.Notnull && e == true && c == nil {
return fmt.Errorf("Field '%v' can't be null", fld.Name)
}
if fld.Mandatory && e == false {
return fmt.Errorf("Need to set field '%v'", fld.Name)
}
return return
} }

View file

@ -28,10 +28,18 @@ func (this *Doc) Merge(data []sql.Expr) (err error) {
this.getFields() this.getFields()
this.getIndexs() this.getIndexs()
if err = this.setFld(); err != nil {
return
}
if err = this.defFld(); err != nil { if err = this.defFld(); err != nil {
return return
} }
if err = this.setFld(); err != nil {
return
}
for _, part := range data { for _, part := range data {
switch expr := part.(type) { switch expr := part.(type) {
case *sql.DiffExpression: case *sql.DiffExpression:
@ -74,23 +82,30 @@ func (this *Doc) defFld() (err error) {
for _, fld := range this.fields { for _, fld := range this.fields {
e := this.current.Exists(fld.Name) this.current.Walk(func(key string, val interface{}) error {
if fld.Default != nil && e == false { v := this.current.Valid(key)
e := this.current.Exists(key)
if fld.Default != nil && (!e || !v && fld.Notnull) {
switch val := fld.Default.(type) { switch val := fld.Default.(type) {
case sql.Void, *sql.Void: case sql.Void, *sql.Void:
this.current.Del(fld.Name) this.current.Del(key)
case sql.Null, *sql.Null: case sql.Null, *sql.Null:
this.current.Set(nil, fld.Name) this.current.Set(nil, key)
default: default:
this.current.Set(fld.Default, fld.Name) this.current.Set(fld.Default, key)
case sql.Ident: case sql.Ident:
this.current.Set(this.current.Get(val.ID).Data(), fld.Name) this.current.Set(this.current.Get(val.ID).Data(), key)
case *sql.Ident: case *sql.Ident:
this.current.Set(this.current.Get(val.ID).Data(), fld.Name) this.current.Set(this.current.Get(val.ID).Data(), key)
} }
} }
return nil
}, fld.Name)
} }
return return
@ -100,7 +115,7 @@ func (this *Doc) defFld() (err error) {
func (this *Doc) mrgFld() (err error) { func (this *Doc) mrgFld() (err error) {
for _, fld := range this.fields { for _, fld := range this.fields {
if err = each(fld, this.initial, this.current); err != nil { if err = this.each(fld); err != nil {
return return
} }
} }