diff --git a/pkg/eval/options.go b/pkg/eval/options.go index 8e89bcb1..09ab52c9 100644 --- a/pkg/eval/options.go +++ b/pkg/eval/options.go @@ -1,12 +1,10 @@ package eval import ( - "fmt" "reflect" "src.elv.sh/pkg/eval/vals" "src.elv.sh/pkg/parse" - "src.elv.sh/pkg/strutil" ) // UnknownOption is thrown by a native function when called with an unknown option. @@ -28,31 +26,18 @@ type RawOptions map[string]interface{} // with options. A field named FieldName corresponds to the option named // field-name. Options that don't have corresponding fields in the struct causes // an error. +// +// Similar to vals.ScanMapToGo, but requires rawOpts to contain a subset of keys +// supported by the struct. func scanOptions(rawOpts RawOptions, ptr interface{}) error { - ptrValue := reflect.ValueOf(ptr) - if ptrValue.Kind() != reflect.Ptr || ptrValue.Elem().Kind() != reflect.Struct { - return fmt.Errorf( - "internal bug: need struct ptr to scan options, got %T", ptr) - } - - // fieldIdxForOpt maps option name to the index of field in `struc`. - fieldIdxForOpt := make(map[string]int) - struc := ptrValue.Elem() - for i := 0; i < struc.Type().NumField(); i++ { - if !struc.Field(i).CanSet() { - continue // ignore unexported fields - } - f := struc.Type().Field(i) - optName := strutil.CamelToDashed(f.Name) - fieldIdxForOpt[optName] = i - } - + _, keyIdx := vals.StructFieldsInfo(reflect.TypeOf(ptr).Elem()) + structValue := reflect.ValueOf(ptr).Elem() for k, v := range rawOpts { - fieldIdx, ok := fieldIdxForOpt[k] + fieldIdx, ok := keyIdx[k] if !ok { return UnknownOption{k} } - err := vals.ScanToGo(v, struc.Field(fieldIdx).Addr().Interface()) + err := vals.ScanToGo(v, structValue.Field(fieldIdx).Addr().Interface()) if err != nil { return err } diff --git a/pkg/eval/options_test.go b/pkg/eval/options_test.go index 7be0a462..6f5081a9 100644 --- a/pkg/eval/options_test.go +++ b/pkg/eval/options_test.go @@ -1,45 +1,30 @@ package eval import ( + "reflect" "testing" + + . "src.elv.sh/pkg/tt" ) type opts struct { - FooBar string - Min int - ignore bool // this should be ignored since it isn't exported -} - -var scanOptionsTests = []struct { - rawOpts RawOptions - preScan opts - postScan opts - err error -}{ - {RawOptions{"foo-bar": "lorem ipsum"}, - opts{}, opts{FooBar: "lorem ipsum"}, nil}, - // Since "ignore" is not exported it will result in an error when used. - {RawOptions{"ignore": true}, - opts{}, opts{ignore: false}, UnknownOption{"ignore"}}, + Foo string + bar int } func TestScanOptions(t *testing.T) { - // scanOptions requires a pointer to struct. - err := scanOptions(RawOptions{}, opts{}) - if err == nil { - t.Errorf("Scan should have reported invalid options arg error") + // A wrapper of ScanOptions, to make it easier to test + wrapper := func(src RawOptions, dstInit interface{}) (interface{}, error) { + ptr := reflect.New(reflect.TypeOf(dstInit)) + ptr.Elem().Set(reflect.ValueOf(dstInit)) + err := scanOptions(src, ptr.Interface()) + return ptr.Elem().Interface(), err } - for _, test := range scanOptionsTests { - opts := test.preScan - err := scanOptions(test.rawOpts, &opts) - - if ((err == nil) != (test.err == nil)) || - (err != nil && test.err != nil && err.Error() != test.err.Error()) { - t.Errorf("Scan error mismatch %v: want %q, got %q", test.rawOpts, test.err, err) - } - if opts != test.postScan { - t.Errorf("Scan %v => %v, want %v", test.rawOpts, opts, test.postScan) - } - } + Test(t, Fn("scanOptions", wrapper), Table{ + Args(RawOptions{"foo": "lorem ipsum"}, opts{}). + Rets(opts{Foo: "lorem ipsum"}, nil), + Args(RawOptions{"bar": 20}, opts{bar: 10}). + Rets(opts{bar: 10}, UnknownOption{"bar"}), + }) } diff --git a/pkg/eval/vals/conversion.go b/pkg/eval/vals/conversion.go index c95d5182..92acaee3 100644 --- a/pkg/eval/vals/conversion.go +++ b/pkg/eval/vals/conversion.go @@ -6,9 +6,11 @@ import ( "math/big" "reflect" "strconv" + "sync" "unicode/utf8" "src.elv.sh/pkg/eval/errs" + "src.elv.sh/pkg/strutil" ) // Conversion between "Go values" (those expected by native Go functions) and @@ -215,6 +217,68 @@ func ScanListElementsToGo(src List, ptrs ...interface{}) error { return nil } +// ScanMapToGo scans map elements into ptr, which must be a pointer to a struct. +// Struct field names are converted to map keys with CamelToDashed. +// +// The map may contains keys that don't correspond to struct fields, and it +// doesn't have to contain all keys that correspond to struct fields. +func ScanMapToGo(src Map, ptr interface{}) error { + // Iterate over the struct keys instead of the map: since extra keys are + // allowed, the map may be very big, while the size of the struct is bound. + keys, _ := StructFieldsInfo(reflect.TypeOf(ptr).Elem()) + structValue := reflect.ValueOf(ptr).Elem() + for i, key := range keys { + if key == "" { + continue + } + val, ok := src.Index(key) + if !ok { + continue + } + err := ScanToGo(val, structValue.Field(i).Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + +// StructFieldsInfo takes a type for a struct, and returns a slice for each +// field name, converted with CamelToDashed, and a reverse index. Unexported +// fields result in an empty string in the slice, and is omitted from the +// reverse index. +func StructFieldsInfo(t reflect.Type) ([]string, map[string]int) { + if info, ok := structFieldsInfoCache.Load(t); ok { + info := info.(structFieldsInfo) + return info.keys, info.keyIdx + } + info := makeStructFieldsInfo(t) + structFieldsInfoCache.Store(t, info) + return info.keys, info.keyIdx +} + +var structFieldsInfoCache sync.Map + +type structFieldsInfo struct { + keys []string + keyIdx map[string]int +} + +func makeStructFieldsInfo(t reflect.Type) structFieldsInfo { + keys := make([]string, t.NumField()) + keyIdx := make(map[string]int) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.PkgPath != "" { + continue + } + key := strutil.CamelToDashed(field.Name) + keyIdx[key] = i + keys[i] = key + } + return structFieldsInfo{keys, keyIdx} +} + // FromGo converts a Go value to an Elvish value. // // Exact numbers are normalized to the smallest types that can hold them, and diff --git a/pkg/eval/vals/conversion_test.go b/pkg/eval/vals/conversion_test.go index b237f1b7..aa74c021 100644 --- a/pkg/eval/vals/conversion_test.go +++ b/pkg/eval/vals/conversion_test.go @@ -159,6 +159,35 @@ func TestScanListElementsToGo(t *testing.T) { }) } +type aStruct struct { + Foo int + bar interface{} +} + +func TestScanMapToGo(t *testing.T) { + // A wrapper around ScanMapToGo, to make it easier to test. + scanMapToGo := func(src Map, dstInit interface{}) (interface{}, error) { + ptr := reflect.New(TypeOf(dstInit)) + ptr.Elem().Set(reflect.ValueOf(dstInit)) + err := ScanMapToGo(src, ptr.Interface()) + return ptr.Elem().Interface(), err + } + + Test(t, Fn("ScanListToGo", scanMapToGo), Table{ + Args(MakeMap("foo", "1"), aStruct{}).Rets(aStruct{Foo: 1}), + // More fields is OK + Args(MakeMap("foo", "1", "bar", "x"), aStruct{}).Rets(aStruct{Foo: 1}), + // Fewer fields is OK + Args(MakeMap(), aStruct{}).Rets(aStruct{}), + // Unexported fields are ignored + Args(MakeMap("bar", 20), aStruct{bar: 10}).Rets(aStruct{bar: 10}), + + // Conversion error + Args(MakeMap("foo", "a"), aStruct{}). + Rets(aStruct{}, cannotParseAs{"integer", "a"}), + }) +} + func TestFromGo(t *testing.T) { Test(t, Fn("FromGo", FromGo), Table{ // BigInt -> int, when in range