diff --git a/test/co_test.go b/test/co_test.go index d4e5dfc3..a8c17e97 100644 --- a/test/co_test.go +++ b/test/co_test.go @@ -16,10 +16,13 @@ limitations under the License. package test import ( + "fmt" + "reflect" "testing" csi "github.com/container-storage-interface/spec/lib/go/csi/v0" gomock "github.com/golang/mock/gomock" + "github.com/golang/protobuf/proto" mock_driver "github.com/kubernetes-csi/csi-test/driver" mock_utils "github.com/kubernetes-csi/csi-test/utils" "golang.org/x/net/context" @@ -58,6 +61,24 @@ func TestPluginInfoResponse(t *testing.T) { } } +type pbMatcher struct { + x proto.Message +} + +func (p pbMatcher) Matches(x interface{}) bool { + y := x.(proto.Message) + return proto.Equal(p.x, y) +} + +func (p pbMatcher) String() string { + return fmt.Sprintf("pb equal to %v", p.x) +} + +func pbMatch(x interface{}) gomock.Matcher { + v := x.(proto.Message) + return &pbMatcher{v} +} + func TestGRPCGetPluginInfoReponse(t *testing.T) { // Setup mock @@ -79,7 +100,7 @@ func TestGRPCGetPluginInfoReponse(t *testing.T) { // Setup expectation // !IMPORTANT!: Must set context expected value to gomock.Any() to match any value - driver.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, nil).Times(1) + driver.EXPECT().GetPluginInfo(gomock.Any(), pbMatch(in)).Return(out, nil).Times(1) // Create a new RPC server := mock_driver.NewMockCSIDriver(&mock_driver.MockCSIDriverServers{ @@ -103,3 +124,65 @@ func TestGRPCGetPluginInfoReponse(t *testing.T) { t.Errorf("Unknown name: %s\n", name) } } + +func TestGRPCAttach(t *testing.T) { + + // Setup mock + m := gomock.NewController(&mock_utils.SafeGoroutineTester{}) + defer m.Finish() + driver := mock_driver.NewMockControllerServer(m) + + // Setup input + defaultVolumeID := "myname" + defaultNodeID := "MyNodeID" + defaultCaps := &csi.VolumeCapability{ + AccessType: &csi.VolumeCapability_Mount{ + Mount: &csi.VolumeCapability_MountVolume{}, + }, + AccessMode: &csi.VolumeCapability_AccessMode{ + Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER, + }, + } + publishVolumeInfo := map[string]string{ + "first": "foo", + "second": "bar", + "third": "baz", + } + defaultRequest := &csi.ControllerPublishVolumeRequest{ + VolumeId: defaultVolumeID, + NodeId: defaultNodeID, + VolumeCapability: defaultCaps, + Readonly: false, + } + + // Setup mock outout + out := &csi.ControllerPublishVolumeResponse{ + PublishInfo: publishVolumeInfo, + } + + // Setup expectation + // !IMPORTANT!: Must set context expected value to gomock.Any() to match any value + driver.EXPECT().ControllerPublishVolume(gomock.Any(), pbMatch(defaultRequest)).Return(out, nil).Times(1) + + // Create a new RPC + server := mock_driver.NewMockCSIDriver(&mock_driver.MockCSIDriverServers{ + Controller: driver, + }) + conn, err := server.Nexus() + if err != nil { + t.Errorf("Error: %s", err.Error()) + } + defer server.Close() + + // Make call + c := csi.NewControllerClient(conn) + r, err := c.ControllerPublishVolume(context.Background(), defaultRequest) + if err != nil { + t.Errorf("Error: %s", err.Error()) + } + + info := r.GetPublishInfo() + if !reflect.DeepEqual(info, publishVolumeInfo) { + t.Errorf("Invalid publish info: %v", info) + } +}