From 14382558863386fdd853bc9a1e8328416dd8ba5a Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Tue, 28 Jan 2025 12:51:06 -0800 Subject: [PATCH] fix(firestore): Convert key before seeting map entry (#11506) * fix(firestore): Convert key before seeting map entry * add unit tests --- firestore/from_value.go | 36 ++++++----- firestore/from_value_test.go | 121 +++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 16 deletions(-) diff --git a/firestore/from_value.go b/firestore/from_value.go index 75c176fac8c6..2ef040bbe896 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -24,12 +24,12 @@ import ( "cloud.google.com/go/internal/fields" ) -func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error { - v := reflect.ValueOf(x) - if v.Kind() != reflect.Ptr || v.IsNil() { +func setFromProtoValue(dest interface{}, vprotoSrc *pb.Value, c *Client) error { + destV := reflect.ValueOf(dest) + if destV.Kind() != reflect.Ptr || destV.IsNil() { return errors.New("firestore: nil or not a pointer") } - return setReflectFromProtoValue(v.Elem(), vproto, c) + return setReflectFromProtoValue(destV.Elem(), vprotoSrc, c) } // setReflectFromProtoValue sets vDest from a Firestore Value. @@ -277,29 +277,33 @@ func populateRepeated(vr reflect.Value, vals []*pb.Value, n int, c *Client) erro return nil } -// populateMap sets the elements of vm, which must be a map, from the -// corresponding elements of pm. +// populateMap sets the elements of destValueMap, which must be a map, from the +// corresponding elements of srcPropMap. // // Since a map value is not settable, this function always creates a new -// element for each corresponding map key. Existing values of vm are +// element for each corresponding map key. Existing values of destValueMap are // overwritten. This happens even if the map value is something like a pointer // to a struct, where we could in theory populate the existing struct value // instead of discarding it. This behavior matches encoding/json. -func populateMap(vm reflect.Value, pm map[string]*pb.Value, c *Client) error { - t := vm.Type() - if t.Key().Kind() != reflect.String { +func populateMap(destValueMap reflect.Value, srcPropMap map[string]*pb.Value, c *Client) error { + destValueMapType := destValueMap.Type() + if destValueMapType.Key().Kind() != reflect.String { return errors.New("firestore: map key type is not string") } - if vm.IsNil() { - vm.Set(reflect.MakeMap(t)) + if destValueMap.IsNil() { + destValueMap.Set(reflect.MakeMap(destValueMapType)) } - et := t.Elem() - for k, vproto := range pm { + et := destValueMapType.Elem() + for srcKey, srcVProto := range srcPropMap { el := reflect.New(et).Elem() - if err := setReflectFromProtoValue(el, vproto, c); err != nil { + if err := setReflectFromProtoValue(el, srcVProto, c); err != nil { return err } - vm.SetMapIndex(reflect.ValueOf(k), el) + keyToSet := reflect.ValueOf(srcKey) + if reflect.ValueOf(srcKey).CanConvert(destValueMapType.Key()) { + keyToSet = reflect.ValueOf(srcKey).Convert(destValueMapType.Key()) + } + destValueMap.SetMapIndex(keyToSet, el) } return nil } diff --git a/firestore/from_value_test.go b/firestore/from_value_test.go index d8ca4ceeb154..4325a3be8f96 100644 --- a/firestore/from_value_test.go +++ b/firestore/from_value_test.go @@ -16,6 +16,7 @@ package firestore import ( "encoding/json" + "errors" "fmt" "io" "math" @@ -579,3 +580,123 @@ func TestTypeString(t *testing.T) { } } } + +func TestPopulateMap(t *testing.T) { + c := &Client{} // Client is not used in populateMap, but required as a parameter + + type myString string + + // Test cases + cases := []struct { + name string + dest interface{} + src map[string]*pb.Value + want interface{} + wantErr bool + expectedErr error + }{ + { + name: "Valid string map", + dest: map[string]string{}, + src: map[string]*pb.Value{ + "key1": {ValueType: &pb.Value_StringValue{StringValue: "value1"}}, + "key2": {ValueType: &pb.Value_StringValue{StringValue: "value2"}}, + }, + want: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + { + name: "Aliased string key map", + dest: map[myString]string{}, + src: map[string]*pb.Value{ + "key1": {ValueType: &pb.Value_StringValue{StringValue: "value1"}}, + "key2": {ValueType: &pb.Value_StringValue{StringValue: "value2"}}, + }, + want: map[myString]string{ + "key1": "value1", + "key2": "value2", + }, + }, + { + name: "Valid int map", + dest: map[string]int{}, + src: map[string]*pb.Value{ + "key1": {ValueType: &pb.Value_IntegerValue{IntegerValue: 1}}, + "key2": {ValueType: &pb.Value_IntegerValue{IntegerValue: 2}}, + }, + want: map[string]int{ + "key1": 1, + "key2": 2, + }, + }, + { + name: "Valid interface map", + dest: map[string]interface{}{}, + src: map[string]*pb.Value{ + "key1": {ValueType: &pb.Value_StringValue{StringValue: "value1"}}, + "key2": {ValueType: &pb.Value_IntegerValue{IntegerValue: 2}}, + }, + want: map[string]interface{}{ + "key1": "value1", + "key2": int64(2), + }, + }, + { + name: "Non-string key", + dest: map[int]string{}, + src: map[string]*pb.Value{}, + wantErr: true, + expectedErr: errors.New("firestore: map key type is not string"), + }, + { + name: "Invalid value type", + dest: map[string]int{}, + src: map[string]*pb.Value{ + "key1": {ValueType: &pb.Value_StringValue{StringValue: "value1"}}, + }, + wantErr: true, + }, + { + name: "Map with Special Vector Type", + dest: map[string]interface{}{}, + src: map[string]*pb.Value{ + "type": {ValueType: &pb.Value_StringValue{StringValue: "vector"}}, + "value": {ValueType: &pb.Value_StringValue{StringValue: "some vector value"}}, + }, + want: map[string]interface{}{ + "type": "vector", + "value": "some vector value", + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + vDest := reflect.ValueOf(tc.dest) + err := populateMap(vDest, tc.src, c) + + if tc.wantErr { + if err == nil { + t.Errorf("Expected error, got nil") + } + if tc.expectedErr != nil && tc.expectedErr.Error() != err.Error() { + t.Errorf("Mismatched Error: Expected '%v', got '%v'", tc.expectedErr, err) + + } + + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + got := vDest.Interface() + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("populateMap() = %v, want %v", got, tc.want) + } + } + + }) + } +}