From 0af761d2db48c103c275c4642e0b3778d745c440 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Mon, 24 Oct 2016 17:03:28 +0100 Subject: [PATCH] Improve embedded field checking and defaults --- util/item/each.go | 367 ++++++++++++++++++++++++--------------------- util/item/merge.go | 45 ++++-- 2 files changed, 225 insertions(+), 187 deletions(-) diff --git a/util/item/each.go b/util/item/each.go index 61ff5338..4850bdf5 100644 --- a/util/item/each.go +++ b/util/item/each.go @@ -24,283 +24,315 @@ import ( "github.com/abcum/surreal/cnf" "github.com/abcum/surreal/sql" "github.com/abcum/surreal/util/conv" - "github.com/abcum/surreal/util/data" ) -func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) (err error) { +func (this *Doc) each(fld *sql.DefineFieldStatement) (err error) { - var e bool - var i interface{} - var c interface{} + return this.current.Walk(func(key string, val interface{}) error { - i = initial.Get(fld.Name).Data() + old := this.initial.Get(key).Data() - if fld.Readonly && i != nil { - current.Set(i, fld.Name) - return - } + if fld.Readonly && old != nil { + this.current.Set(old, key) + return nil + } - if fld.Code != "" { + if fld.Code != "" { - if cnf.Settings.DB.Lang == "js" { + if cnf.Settings.DB.Lang == "js" { - vm := otto.New() + vm := otto.New() - vm.Set("doc", current.Copy()) + vm.Set("doc", this.current.Copy()) + + ret, err := vm.Run("(function() { " + fld.Code + " })()") + if err != nil { + return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error()) + } + + if ret.IsUndefined() { + this.current.Del(key) + } else { + val, _ := ret.Export() + this.current.Set(val, key) + } - ret, err := vm.Run("(function() { " + fld.Code + " })()") - if err != nil { - return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error()) } - if ret.IsUndefined() { - current.Del(fld.Name) - } else { - val, _ := ret.Export() - current.Set(val, fld.Name) + if cnf.Settings.DB.Lang == "lua" { + + vm := lua.NewState() + defer vm.Close() + + vm.SetGlobal("doc", toLUA(vm, this.current.Copy())) + + if err := vm.DoString(fld.Code); err != nil { + return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error()) + } + + ret := vm.Get(-1) + + if ret == lua.LNil { + this.current.Del(key) + } else { + this.current.Set(frLUA(ret), key) + } + } } - if cnf.Settings.DB.Lang == "lua" { - - vm := lua.NewState() - defer vm.Close() - - vm.SetGlobal("doc", toLUA(vm, current.Copy())) - - if err := vm.DoString(fld.Code); err != nil { - return fmt.Errorf("Problem executing code: %v %v", fld.Code, err.Error()) - } - - ret := vm.Get(-1) - - if ret == lua.LNil { - current.Del(fld.Name) - } else { - current.Set(frLUA(ret), fld.Name) - } + // Ensure that any defined fields are correctly + // formatted according to their defined type. + if err = this.chck(fld, key, old, val); err != nil { + return err } - } + // Ensure that any default fields are correctly + // formatted according to their defined type. - c = current.Get(fld.Name).Data() - e = current.Exists(fld.Name) + if err = this.chck(fld, key, old, val); err != nil { + return err + } - if fld.Default != nil && (c == nil || e == false) { - switch val := fld.Default.(type) { + // Otherwise this is all good. + + return nil + + }, fld.Name) + +} + +func (this *Doc) chck(fld *sql.DefineFieldStatement, key string, old, val interface{}) (err error) { + + var exi bool + + // Ensure that any fields which have been set to + // null are reset back to default if specified. + + exi = this.current.Exists(key) + val = this.current.Get(key).Data() + + if fld.Default != nil && (exi == false || val == nil && fld.Notnull) { + switch def := fld.Default.(type) { case sql.Null, *sql.Null: - current.Set(nil, fld.Name) + this.current.Set(nil, key) default: - current.Set(fld.Default, fld.Name) + this.current.Set(fld.Default, key) case sql.Ident: - current.Set(current.Get(val.ID).Data(), fld.Name) + this.current.Set(this.current.Get(def.ID).Data(), key) case *sql.Ident: - current.Set(current.Get(val.ID).Data(), fld.Name) + this.current.Set(this.current.Get(def.ID).Data(), key) } } - c = current.Get(fld.Name).Data() - e = current.Exists(fld.Name) + // Check to see if any field which has been set + // to defaults now satisfies the field constraints. - if fld.Notnull && e == true && c == nil { - return fmt.Errorf("Field '%v' can't be null", fld.Name) + exi = this.current.Exists(key) + val = this.current.Get(key).Data() + + if fld.Notnull && exi == true && val == nil { + return fmt.Errorf("Field '%v' can't be null", key) } - if fld.Mandatory && e == false { - return fmt.Errorf("Need to set field '%v'", fld.Name) + if fld.Mandatory && exi == false { + return fmt.Errorf("Need to set field '%v'", key) } - if c != nil && fld.Type != "" { + // Ensure that any defined fields are correctly + // formatted according to their defined type. + + if val != nil { switch fld.Type { case "url": - if val, err := conv.ConvertToUrl(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToUrl(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a URL", fld.Name) + return fmt.Errorf("Field '%v' needs to be a URL, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "uuid": - if val, err := conv.ConvertToUuid(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToUuid(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a UUID", fld.Name) + return fmt.Errorf("Field '%v' needs to be a UUID, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "color": - if val, err := conv.ConvertToColor(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToColor(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a HEX or RGB color", fld.Name) + return fmt.Errorf("Field '%v' needs to be a HEX or RGB color, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "email": - if val, err := conv.ConvertToEmail(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToEmail(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be an email address", fld.Name) + return fmt.Errorf("Field '%v' needs to be an email address, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "phone": - if val, err := conv.ConvertToPhone(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToPhone(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a phone number", fld.Name) + return fmt.Errorf("Field '%v' needs to be a phone number, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "array": - if val, err := conv.ConvertToArray(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToArray(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be an array", fld.Name) + return fmt.Errorf("Field '%v' needs to be an array, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "object": - if val, err := conv.ConvertToObject(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToObject(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be an object", fld.Name) + return fmt.Errorf("Field '%v' needs to be an object, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "domain": - if val, err := conv.ConvertToDomain(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToDomain(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a domain name", fld.Name) + return fmt.Errorf("Field '%v' needs to be a domain name, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "base64": - if val, err := conv.ConvertToBase64(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToBase64(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be base64 data", fld.Name) + return fmt.Errorf("Field '%v' needs to be base64 data, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "string": - if val, err := conv.ConvertToString(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToString(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a string", fld.Name) + return fmt.Errorf("Field '%v' needs to be a string, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "number": - if val, err := conv.ConvertToNumber(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToNumber(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a number", fld.Name) + return fmt.Errorf("Field '%v' needs to be a number, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "double": - if val, err := conv.ConvertToDouble(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToDouble(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a double", fld.Name) + return fmt.Errorf("Field '%v' needs to be a double, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "boolean": - if val, err := conv.ConvertToBoolean(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToBoolean(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a boolean", fld.Name) + return fmt.Errorf("Field '%v' needs to be a boolean, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "datetime": - if val, err := conv.ConvertToDatetime(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToDatetime(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a datetime", fld.Name) + return fmt.Errorf("Field '%v' needs to be a datetime, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "latitude": - if val, err := conv.ConvertToLatitude(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToLatitude(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a latitude value", fld.Name) + return fmt.Errorf("Field '%v' needs to be a latitude value, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "longitude": - if val, err := conv.ConvertToLongitude(c); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToLongitude(val); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be a longitude value", fld.Name) + return fmt.Errorf("Field '%v' needs to be a longitude value, but found '%v'", key, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case "custom": - if val, err := conv.ConvertToOneOf(c, fld.Enum...); err == nil { - current.Set(val, fld.Name) + if cnv, err := conv.ConvertToOneOf(val, fld.Enum...); err == nil { + this.current.Set(cnv, key) } else { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be one of %v", fld.Name, fld.Enum) + return fmt.Errorf("Field '%v' needs to be one of %v, but found '%v'", key, fld.Enum, val) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } @@ -313,11 +345,11 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) ( if reg, err := regexp.Compile(fld.Match); err != nil { return fmt.Errorf("Regular expression /%v/ is invalid", fld.Match) } else { - if !reg.MatchString(fmt.Sprintf("%v", c)) { + if !reg.MatchString(fmt.Sprintf("%v", val)) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to match the regular expression /%v/", fld.Name, fld.Match) + return fmt.Errorf("Field '%v' needs to match the regular expression /%v/", key, fld.Match) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } } @@ -326,34 +358,43 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) ( if fld.Min != 0 { - if c = current.Get(fld.Name).Data(); c != nil { + if val = this.current.Get(key).Data(); val != nil { - switch now := c.(type) { + switch now := val.(type) { case []interface{}: if len(now) < int(fld.Min) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to have at least %v items", fld.Name, fld.Min) + return fmt.Errorf("Field '%v' needs to have at least %v items", key, fld.Min) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case string: if len(now) < int(fld.Min) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to have at least %v characters", fld.Name, fld.Min) + return fmt.Errorf("Field '%v' needs to have at least %v characters", key, fld.Min) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) + } + } + + case int64: + if now < int64(fld.Min) { + if fld.Validate { + return fmt.Errorf("Field '%v' needs to be >= %v", key, fld.Min) + } else { + this.current.Iff(old, key) } } case float64: - if now < fld.Min { + if now < float64(fld.Min) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be >= %v", fld.Name, fld.Min) + return fmt.Errorf("Field '%v' needs to be >= %v", key, fld.Min) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } @@ -365,34 +406,43 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) ( if fld.Max != 0 { - if c = current.Get(fld.Name).Data(); c != nil { + if val = this.current.Get(key).Data(); val != nil { - switch now := c.(type) { + switch now := val.(type) { case []interface{}: if len(now) > int(fld.Max) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to have %v or fewer items", fld.Name, fld.Max) + return fmt.Errorf("Field '%v' needs to have %v or fewer items", key, fld.Max) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } case string: if len(now) > int(fld.Max) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to have %v or fewer characters", fld.Name, fld.Max) + return fmt.Errorf("Field '%v' needs to have %v or fewer characters", key, fld.Max) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) + } + } + + case int64: + if now > int64(fld.Max) { + if fld.Validate { + return fmt.Errorf("Field '%v' needs to be <= %v", key, fld.Max) + } else { + this.current.Iff(old, key) } } case float64: - if now > fld.Max { + if now > float64(fld.Max) { if fld.Validate { - return fmt.Errorf("Field '%v' needs to be <= %v", fld.Name, fld.Max) + return fmt.Errorf("Field '%v' needs to be <= %v", key, fld.Max) } else { - current.Iff(i, fld.Name) + this.current.Iff(old, key) } } @@ -402,33 +452,6 @@ func each(fld *sql.DefineFieldStatement, initial *data.Doc, current *data.Doc) ( } - c = current.Get(fld.Name).Data() - e = current.Exists(fld.Name) - - if fld.Default != nil && (c == nil || e == false) { - switch val := fld.Default.(type) { - case sql.Null, *sql.Null: - current.Set(nil, fld.Name) - default: - current.Set(fld.Default, fld.Name) - case sql.Ident: - current.Set(current.Get(val.ID).Data(), fld.Name) - case *sql.Ident: - current.Set(current.Get(val.ID).Data(), fld.Name) - } - } - - c = current.Get(fld.Name).Data() - e = current.Exists(fld.Name) - - if fld.Notnull && e == true && c == nil { - return fmt.Errorf("Field '%v' can't be null", fld.Name) - } - - if fld.Mandatory && e == false { - return fmt.Errorf("Need to set field '%v'", fld.Name) - } - return } diff --git a/util/item/merge.go b/util/item/merge.go index 5c23240a..2ed7b149 100644 --- a/util/item/merge.go +++ b/util/item/merge.go @@ -28,10 +28,18 @@ func (this *Doc) Merge(data []sql.Expr) (err error) { this.getFields() this.getIndexs() + if err = this.setFld(); err != nil { + return + } + if err = this.defFld(); err != nil { return } + if err = this.setFld(); err != nil { + return + } + for _, part := range data { switch expr := part.(type) { case *sql.DiffExpression: @@ -74,22 +82,29 @@ func (this *Doc) defFld() (err error) { for _, fld := range this.fields { - e := this.current.Exists(fld.Name) + this.current.Walk(func(key string, val interface{}) error { - if fld.Default != nil && e == false { - switch val := fld.Default.(type) { - case sql.Void, *sql.Void: - this.current.Del(fld.Name) - case sql.Null, *sql.Null: - this.current.Set(nil, fld.Name) - default: - this.current.Set(fld.Default, fld.Name) - case sql.Ident: - this.current.Set(this.current.Get(val.ID).Data(), fld.Name) - case *sql.Ident: - this.current.Set(this.current.Get(val.ID).Data(), fld.Name) + v := this.current.Valid(key) + e := this.current.Exists(key) + + if fld.Default != nil && (!e || !v && fld.Notnull) { + switch val := fld.Default.(type) { + case sql.Void, *sql.Void: + this.current.Del(key) + case sql.Null, *sql.Null: + this.current.Set(nil, key) + default: + this.current.Set(fld.Default, key) + case sql.Ident: + this.current.Set(this.current.Get(val.ID).Data(), key) + case *sql.Ident: + this.current.Set(this.current.Get(val.ID).Data(), key) + } } - } + + return nil + + }, fld.Name) } @@ -100,7 +115,7 @@ func (this *Doc) defFld() (err error) { func (this *Doc) mrgFld() (err error) { for _, fld := range this.fields { - if err = each(fld, this.initial, this.current); err != nil { + if err = this.each(fld); err != nil { return } }