Prevent race conditions in when fetching document fields

This commit is contained in:
Tobie Morgan Hitchcock 2018-04-24 13:53:15 +01:00
parent b94a944299
commit 11a9fa05e6
2 changed files with 34 additions and 42 deletions

View file

@ -84,7 +84,7 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o
return val, queryIdentFailed
case doc != nil:
doc.Fetch(func(key string, val interface{}, path []string) interface{} {
fnc := func(key string, val interface{}, path []string) interface{} {
if len(path) > 0 {
switch res := val.(type) {
case []interface{}:
@ -96,9 +96,11 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o
}
}
return val
})
}
return e.fetch(ctx, doc.Get(val.ID).Data(), doc)
res := doc.Fetch(fnc, val.ID).Data()
return e.fetch(ctx, res, doc)
}
@ -108,7 +110,7 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o
if obj, ok := ctx.Value(s).(*data.Doc); ok {
obj.Fetch(func(key string, val interface{}, path []string) interface{} {
fnc := func(key string, val interface{}, path []string) interface{} {
if len(path) > 0 {
switch res := val.(type) {
case []interface{}:
@ -120,9 +122,11 @@ func (e *executor) fetch(ctx context.Context, val interface{}, doc *data.Doc) (o
}
}
return val
})
}
if res := obj.Get(val.ID).Data(); res != nil {
res := obj.Fetch(fnc, val.ID).Data()
if res != nil {
return e.fetch(ctx, res, doc)
}

View file

@ -39,7 +39,6 @@ const (
// Doc holds a reference to the core data object, or a selected path.
type Doc struct {
data interface{}
call Fetcher
}
// Fetcher is used when fetching values.
@ -58,11 +57,6 @@ func Consume(input interface{}) *Doc {
return &Doc{data: input}
}
// Consume converts a GO interface into a data object.
func ConsumeWithFetch(input interface{}, fetcher Fetcher) *Doc {
return &Doc{data: input, call: fetcher}
}
// Data returns the internal data object as an interface.
func (d *Doc) Data() interface{} {
return d.data
@ -73,11 +67,6 @@ func (d *Doc) Copy() *Doc {
return &Doc{data: deep.Copy(d.data)}
}
// Fetch adds a function to be used for retrieving and converting values.
func (d *Doc) Fetch(fetcher Fetcher) {
d.call = fetcher
}
// Encode encodes the data object to a byte slice.
func (d *Doc) Encode() (dst []byte) {
dst = pack.Encode(d.data)
@ -430,18 +419,12 @@ func (d *Doc) Exists(path ...string) bool {
}
if r == one {
if d.call != nil {
c[0] = d.call(p, c[0], path[k+1:])
}
return ConsumeWithFetch(c[0], d.call).Exists(path[k+1:]...)
return Consume(c[0]).Exists(path[k+1:]...)
}
if r == many {
for _, v := range c {
if d.call != nil {
v = d.call(p, v, path[k+1:])
}
if !ConsumeWithFetch(v, d.call).Exists(path[k+1:]...) {
if !Consume(v).Exists(path[k+1:]...) {
return false
}
}
@ -462,6 +445,11 @@ func (d *Doc) Exists(path ...string) bool {
// Get gets the value or values at a specified path.
func (d *Doc) Get(path ...string) *Doc {
return d.Fetch(nil, path...)
}
// Fetch gets the value or values at a specified path, allowing for a custom fetch method.
func (d *Doc) Fetch(call Fetcher, path ...string) *Doc {
path = d.path(path...)
@ -494,8 +482,8 @@ func (d *Doc) Get(path ...string) *Doc {
if m, ok := object.(map[string]interface{}); ok {
switch p {
default:
if d.call != nil {
object = d.call(p, m[p], path[k+1:])
if call != nil {
object = call(p, m[p], path[k+1:])
} else {
object = m[p]
}
@ -518,19 +506,19 @@ func (d *Doc) Get(path ...string) *Doc {
}
if r == one {
if d.call != nil {
c[0] = d.call(p, c[0], path[k+1:])
if call != nil {
c[0] = call(p, c[0], path[k+1:])
}
return ConsumeWithFetch(c[0], d.call).Get(path[k+1:]...)
return Consume(c[0]).Fetch(call, path[k+1:]...)
}
if r == many {
out := []interface{}{}
for _, v := range c {
if d.call != nil {
v = d.call(p, v, path[k+1:])
if call != nil {
v = call(p, v, path[k+1:])
}
res := ConsumeWithFetch(v, d.call).Get(path[k+1:]...)
res := Consume(v).Fetch(call, path[k+1:]...)
out = append(out, res.data)
}
return &Doc{data: out}
@ -612,7 +600,7 @@ func (d *Doc) Set(value interface{}, path ...string) (*Doc, error) {
object = a[i[0]]
continue
} else {
return ConsumeWithFetch(a[i[0]], d.call).Set(value, path[k+1:]...)
return Consume(a[i[0]]).Set(value, path[k+1:]...)
}
}
@ -623,7 +611,7 @@ func (d *Doc) Set(value interface{}, path ...string) (*Doc, error) {
a[i[j]] = value
out = append(out, value)
} else {
res, _ := ConsumeWithFetch(v, d.call).Set(value, path[k+1:]...)
res, _ := Consume(v).Set(value, path[k+1:]...)
if res.data != nil {
out = append(out, res.data)
}
@ -704,7 +692,7 @@ func (d *Doc) Del(path ...string) error {
continue
} else {
if len(c) != 0 {
return ConsumeWithFetch(c[0], d.call).Del(path[k+1:]...)
return Consume(c[0]).Del(path[k+1:]...)
}
}
}
@ -715,7 +703,7 @@ func (d *Doc) Del(path ...string) error {
continue
} else {
for _, v := range c {
ConsumeWithFetch(v, d.call).Del(path[k+1:]...)
Consume(v).Del(path[k+1:]...)
}
break
}
@ -772,7 +760,7 @@ func (d *Doc) ArrayAdd(value interface{}, path ...string) (*Doc, error) {
} else {
for _, v := range a {
if reflect.DeepEqual(v, value) {
return ConsumeWithFetch(a, d.call), nil
return Consume(a), nil
}
}
a = append(a, value)
@ -980,7 +968,7 @@ func (d *Doc) each(exec Iterator, prev []string) error {
var keep []string
keep = append(keep, prev...)
keep = append(keep, k)
ConsumeWithFetch(v, d.call).each(exec, keep)
Consume(v).each(exec, keep)
}
return nil
}
@ -995,7 +983,7 @@ func (d *Doc) each(exec Iterator, prev []string) error {
var keep []string
keep = append(keep, prev...)
keep = append(keep, fmt.Sprintf("[%d]", i))
ConsumeWithFetch(v, d.call).each(exec, keep)
Consume(v).each(exec, keep)
}
return nil
}
@ -1078,7 +1066,7 @@ func (d *Doc) walk(exec Iterator, prev []string, path ...string) error {
keep = append(keep, prev...)
keep = append(keep, path[:k]...)
keep = append(keep, fmt.Sprintf("[%d]", i[0]))
return ConsumeWithFetch(c[0], d.call).walk(exec, keep, path[k+1:]...)
return Consume(c[0]).walk(exec, keep, path[k+1:]...)
}
}
@ -1097,7 +1085,7 @@ func (d *Doc) walk(exec Iterator, prev []string, path ...string) error {
keep = append(keep, prev...)
keep = append(keep, path[:k]...)
keep = append(keep, fmt.Sprintf("[%d]", i[j]))
if err := ConsumeWithFetch(v, d.call).walk(exec, keep, path[k+1:]...); err != nil {
if err := Consume(v).walk(exec, keep, path[k+1:]...); err != nil {
return err
}
}