diff --git a/pkg/gce-cloud-provider/compute/cloud-disk.go b/pkg/gce-cloud-provider/compute/cloud-disk.go
index 23709cbe6..0fb2a8d03 100644
--- a/pkg/gce-cloud-provider/compute/cloud-disk.go
+++ b/pkg/gce-cloud-provider/compute/cloud-disk.go
@@ -92,6 +92,18 @@ func (d *CloudDisk) GetKind() string {
 	}
 }
 
+// GetStatus returns the status of the given disk.
+func (d *CloudDisk) GetStatus() string {
+	switch d.Type() {
+	case Zonal:
+		return d.ZonalDisk.Status
+	case Regional:
+		return d.RegionalDisk.Status
+	default:
+		return "Unknown"
+	}
+}
+
 // GetPDType returns the type of the PD as either 'pd-standard' or 'pd-ssd' The
 // "Type" field on the compute disk is stored as a url like
 // projects/project/zones/zone/diskTypes/pd-standard
diff --git a/pkg/gce-cloud-provider/compute/fake-gce.go b/pkg/gce-cloud-provider/compute/fake-gce.go
index c80a3d41c..bfbab953b 100644
--- a/pkg/gce-cloud-provider/compute/fake-gce.go
+++ b/pkg/gce-cloud-provider/compute/fake-gce.go
@@ -48,6 +48,9 @@ type FakeCloudProvider struct {
 	pageTokens map[string]sets.String
 	instances  map[string]*computev1.Instance
 	snapshots  map[string]*computev1.Snapshot
+
+	// marker to set disk status during InsertDisk operation.
+	mockDiskStatus string
 }
 
 var _ GCECompute = &FakeCloudProvider{}
@@ -60,6 +63,8 @@ func CreateFakeCloudProvider(project, zone string, cloudDisks []*CloudDisk) (*Fa
 		instances:  map[string]*computev1.Instance{},
 		snapshots:  map[string]*computev1.Snapshot{},
 		pageTokens: map[string]sets.String{},
+		// A newly created disk is marked READY by default.
+		mockDiskStatus: "READY",
 	}
 	for _, d := range cloudDisks {
 		fcp.disks[d.GetName()] = d
@@ -250,6 +255,7 @@ func (cloud *FakeCloudProvider) InsertDisk(ctx context.Context, volKey *meta.Key
 			Type:             cloud.GetDiskTypeURI(volKey, params.DiskType),
 			SelfLink:         fmt.Sprintf("projects/%s/zones/%s/disks/%s", cloud.project, volKey.Zone, volKey.Name),
 			SourceSnapshotId: snapshotID,
+			Status:           cloud.mockDiskStatus,
 		}
 		if params.DiskEncryptionKMSKey != "" {
 			diskToCreateGA.DiskEncryptionKey = &computev1.CustomerEncryptionKey{
@@ -265,6 +271,7 @@ func (cloud *FakeCloudProvider) InsertDisk(ctx context.Context, volKey *meta.Key
 			Type:             cloud.GetDiskTypeURI(volKey, params.DiskType),
 			SelfLink:         fmt.Sprintf("projects/%s/regions/%s/disks/%s", cloud.project, volKey.Region, volKey.Name),
 			SourceSnapshotId: snapshotID,
+			Status:           cloud.mockDiskStatus,
 		}
 		if params.DiskEncryptionKMSKey != "" {
 			diskToCreateV1.DiskEncryptionKey = &computev1.CustomerEncryptionKey{
@@ -466,6 +473,11 @@ func (cloud *FakeCloudProvider) getGlobalSnapshotURI(snapshotName string) string
 		snapshotName)
 }
 
+// UpdateDiskStatus sets the status reported by disks created in later InsertDisk calls.
+func (cloud *FakeCloudProvider) UpdateDiskStatus(s string) {
+	cloud.mockDiskStatus = s
+}
+
 type FakeBlockingCloudProvider struct {
 	*FakeCloudProvider
 	ReadyToExecute chan chan struct{}
diff --git a/pkg/gce-pd-csi-driver/controller.go b/pkg/gce-pd-csi-driver/controller.go
index 24fa2b32e..353b3f7fb 100644
--- a/pkg/gce-pd-csi-driver/controller.go
+++ b/pkg/gce-pd-csi-driver/controller.go
@@ -62,6 +62,30 @@ const (
 	replicationTypeRegionalPD = "regional-pd"
 )
 
+// isDiskReady returns true when the disk status is READY, an error when it is FAILED, and false otherwise.
+func isDiskReady(disk *gce.CloudDisk) (bool, error) {
+	status := disk.GetStatus()
+	switch status {
+	case "READY":
+		return true, nil
+	case "FAILED":
+		return false, fmt.Errorf("disk %s status is FAILED", disk.GetName())
+	case "CREATING":
+		klog.V(4).Infof("Disk %s status is CREATING", disk.GetName())
+		return false, nil
+	case "DELETING":
+		klog.V(4).Infof("Disk %s status is DELETING", disk.GetName())
+		return false, nil
+	case "RESTORING":
+		klog.V(4).Infof("Disk %s status is RESTORING", disk.GetName())
+		return false, nil
+	default:
+		klog.V(4).Infof("Disk %s status is: %s", disk.GetName(), status)
+	}
+
+	return false, nil
+}
+
 func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
 	var err error
 	// Validate arguments
@@ -143,6 +167,16 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre
 		if err != nil {
 			return nil, status.Error(codes.AlreadyExists, fmt.Sprintf("CreateVolume disk already exists with same name and is incompatible: %v", err))
 		}
+
+		ready, err := isDiskReady(existingDisk)
+		if err != nil {
+			return nil, status.Error(codes.Internal, fmt.Sprintf("CreateVolume disk %v had error checking ready status: %v", volKey, err))
+		}
+
+		if !ready {
+			return nil, status.Error(codes.Internal, fmt.Sprintf("CreateVolume disk %v is not ready", volKey))
+		}
+
 		// If there is no validation error, immediately return success
 		klog.V(4).Infof("CreateVolume succeeded for disk %v, it already exists and was compatible", volKey)
 		return generateCreateVolumeResponse(existingDisk, capBytes, zones), nil
@@ -187,6 +221,15 @@ func (gceCS *GCEControllerServer) CreateVolume(ctx context.Context, req *csi.Cre
 	default:
 		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("CreateVolume replication type '%s' is not supported", params.ReplicationType))
 	}
+
+	ready, err := isDiskReady(disk)
+	if err != nil {
+		return nil, status.Error(codes.Internal, fmt.Sprintf("CreateVolume disk %v had error checking ready status: %v", volKey, err))
+	}
+	if !ready {
+		return nil, status.Error(codes.Internal, fmt.Sprintf("CreateVolume disk %v is not ready", volKey))
+	}
+
 	klog.V(4).Infof("CreateVolume succeeded for disk %v", volKey)
 	return generateCreateVolumeResponse(disk, capBytes, zones), nil
 
diff --git a/pkg/gce-pd-csi-driver/controller_test.go b/pkg/gce-pd-csi-driver/controller_test.go
index e6c9ab0fd..c415144e1 100644
--- a/pkg/gce-pd-csi-driver/controller_test.go
+++ b/pkg/gce-pd-csi-driver/controller_test.go
@@ -1549,3 +1549,111 @@ func TestVolumeOperationConcurrency(t *testing.T) {
 		t.Errorf("Unexpected error: %v", err)
 	}
 }
+
+func TestCreateVolumeDiskReady(t *testing.T) {
+	// Define test cases
+	testCases := []struct {
+		name       string
+		diskStatus string
+		req        *csi.CreateVolumeRequest
+		expVol     *csi.Volume
+		expErrCode codes.Code
+	}{
+		{
+			name:       "disk status RESTORING",
+			diskStatus: "RESTORING",
+			req: &csi.CreateVolumeRequest{
+				Name:               "test-name",
+				CapacityRange:      stdCapRange,
+				VolumeCapabilities: stdVolCaps,
+				Parameters:         stdParams,
+			},
+			expErrCode: codes.Internal,
+		},
+		{
+			name:       "disk status CREATING",
+			diskStatus: "CREATING",
+			req: &csi.CreateVolumeRequest{
+				Name:               "test-name",
+				CapacityRange:      stdCapRange,
+				VolumeCapabilities: stdVolCaps,
+				Parameters:         stdParams,
+			},
+			expErrCode: codes.Internal,
+		},
+		{
+			name:       "disk status DELETING",
+			diskStatus: "DELETING",
+			req: &csi.CreateVolumeRequest{
+				Name:               "test-name",
+				CapacityRange:      stdCapRange,
+				VolumeCapabilities: stdVolCaps,
+				Parameters:         stdParams,
+			},
+			expErrCode: codes.Internal,
+		},
+		{
+			name:       "disk status FAILED",
+			diskStatus: "FAILED",
+			req: &csi.CreateVolumeRequest{
+				Name:               "test-name",
+				CapacityRange:      stdCapRange,
+				VolumeCapabilities: stdVolCaps,
+				Parameters:         stdParams,
+			},
+			expErrCode: codes.Internal,
+		},
+		{
+			name:       "success default",
+			diskStatus: "READY",
+			req: &csi.CreateVolumeRequest{
+				Name:               "test-name",
+				CapacityRange:      stdCapRange,
+				VolumeCapabilities: stdVolCaps,
+				Parameters:         stdParams,
+			},
+			expVol: &csi.Volume{
+				CapacityBytes:      common.GbToBytes(20),
+				VolumeId:           testVolumeID,
+				VolumeContext:      nil,
+				AccessibleTopology: stdTopology,
+			},
+		},
+	}
+
+	// Run test cases
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			fcp, err := gce.CreateFakeCloudProvider(project, zone, nil)
+			if err != nil {
+				t.Fatalf("Failed to create fake cloud provider: %v", err)
+			}
+
+			// Set up a hook so that new disks are created with the given status.
+			fcp.UpdateDiskStatus(tc.diskStatus)
+			// Set up a new driver each time so that tests do not interfere.
+			gceDriver := initGCEDriverWithCloudProvider(t, fcp)
+			// Start Test
+			resp, err := gceDriver.cs.CreateVolume(context.Background(), tc.req)
+			// Check the response.
+			if err != nil {
+				serverError, ok := status.FromError(err)
+				if !ok {
+					t.Fatalf("Could not get error status code from err: %v", serverError)
+				}
+				if serverError.Code() != tc.expErrCode {
+					t.Fatalf("Expected error code: %v, got: %v. err: %v", tc.expErrCode, serverError.Code(), err)
+				}
+				return
+			}
+			if tc.expErrCode != codes.OK {
+				t.Fatalf("Expected error: %v, got no error", tc.expErrCode)
+			}
+
+			vol := resp.GetVolume()
+			if !reflect.DeepEqual(vol, tc.expVol) {
+				t.Fatalf("Mismatch in expected vol %v, current volume: %v", tc.expVol, vol)
+			}
+		})
+	}
+}
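
For completeness, here is a minimal sketch, not part of this patch, of a direct table-driven unit test for the status mapping in isDiskReady. It assumes the test would live alongside controller_test.go (same package, so the unexported isDiskReady is visible) and that computev1 aliases "google.golang.org/api/compute/v1" as in fake-gce.go; the CloudDisk literals use the exported ZonalDisk field that GetStatus reads.

// Hypothetical sketch, not part of the patch: exercises isDiskReady directly
// rather than through CreateVolume. Assumes controller_test.go's package and
// imports, with computev1 aliased as in fake-gce.go.
func TestIsDiskReadyStatuses(t *testing.T) {
	for status, wantReady := range map[string]bool{
		"READY":     true,
		"CREATING":  false,
		"DELETING":  false,
		"RESTORING": false,
	} {
		disk := &gce.CloudDisk{ZonalDisk: &computev1.Disk{Name: "test-disk", Status: status}}
		ready, err := isDiskReady(disk)
		if err != nil {
			t.Errorf("status %q: unexpected error: %v", status, err)
		}
		if ready != wantReady {
			t.Errorf("status %q: got ready=%t, want %t", status, ready, wantReady)
		}
	}

	// FAILED is the only status isDiskReady surfaces as an error.
	failed := &gce.CloudDisk{ZonalDisk: &computev1.Disk{Name: "test-disk", Status: "FAILED"}}
	if _, err := isDiskReady(failed); err == nil {
		t.Error("status FAILED: expected an error, got none")
	}
}

Note that a non-READY disk surfaces to the CO as codes.Internal, which is retriable: an idempotent retry of CreateVolume re-runs the check and succeeds once the disk reaches READY.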