Skip to content

Commit

Permalink
pr feedback #1: update function name mocks and Mount calls in tests
Browse files Browse the repository at this point in the history
Signed-off-by: Maksim An <[email protected]>
  • Loading branch information
anmaxvl committed Oct 7, 2021
1 parent b229f22 commit 49e4ed6
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 84 deletions.
18 changes: 9 additions & 9 deletions internal/guest/storage/pmem/pmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import (

// Test dependencies
var (
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createLinearTarget = dm.CreateZeroSectorLinearTarget
veritySetup = dm.CreateVerityTarget
removeDevice = dm.RemoveDevice
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget
createVerityTargetCalled = dm.CreateVerityTarget
removeDevice = dm.RemoveDevice
)

const (
Expand Down Expand Up @@ -93,7 +93,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.
// device instead of the original VPMem.
if mappingInfo != nil {
dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
if devicePath, err = createLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
return err
}
defer func() {
Expand All @@ -107,7 +107,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest)
if devicePath, err = veritySetup(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
if devicePath, err = createVerityTargetCalled(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
Expand Down
108 changes: 69 additions & 39 deletions internal/guest/storage/pmem/pmem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func clearTestDependencies() {
osMkdirAll = nil
osRemoveAll = nil
unixMount = nil
createLinearTarget = nil
veritySetup = nil
createZeroSectorLinearTarget = nil
createVerityTargetCalled = nil
removeDevice = nil
mountInternal = mount
}
Expand Down Expand Up @@ -323,7 +323,7 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
createLTCalled := false
createZSLTCalled := false

osMkdirAll = func(_ string, _ os.FileMode) error {
return nil
Expand All @@ -339,28 +339,33 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing
return nil
}

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createLTCalled = true
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZSLTCalled = true
if source != expectedSource {
t.Errorf("expected createLinearTarget source %s, got %s", expectedSource, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
}
if name != expectedLinearName {
t.Errorf("expected createLinearTarget name %s, got %s", expectedLinearName, name)
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(), 0, expectedTarget, mappingInfo, nil, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
expectedTarget,
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount: %s", err)
}
if !createLTCalled {
t.Fatalf("createLinearTarget not called")
if !createZSLTCalled {
t.Fatalf("createZeroSectorLinearTarget not called")
}
}

func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
Expand All @@ -370,7 +375,7 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
veritySetupCalled := false
createVerityTargetCalledCalled := false

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
Expand All @@ -381,28 +386,33 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
}
return nil
}
veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
veritySetupCalled = true
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalledCalled = true
if source != expectedSource {
t.Errorf("expected veritySetup source %s, got %s", expectedSource, source)
t.Errorf("expected createVerityTargetCalled source %s, got %s", expectedSource, source)
}
if name != expectedVerityName {
t.Errorf("expected veritySetup name %s, got %s", expectedVerityName, name)
t.Errorf("expected createVerityTargetCalled name %s, got %s", expectedVerityName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(), 0, expectedTarget, nil, verityInfo, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
expectedTarget,
nil,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected Mount failure: %s", err)
}
if !veritySetupCalled {
t.Fatal("veritySetup not called")
if !createVerityTargetCalledCalled {
t.Fatal("createVerityTargetCalled not called")
}
}

func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
Expand All @@ -421,23 +431,23 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
dmVerityCalled := false
mountCalled := false

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
dmLinearCalled = true
if source != expectedPMemDevice {
t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
if name != expectedLinearTarget {
t.Errorf("expected createLineartarget name %s, got %s", expectedLinearTarget, name)
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
}
return mapperLinearPath, nil
}
veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
dmVerityCalled = true
if source != mapperLinearPath {
t.Errorf("expected veritySetup source %s, got %s", mapperLinearPath, source)
t.Errorf("expected createVerityTargetCalled source %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected veritySetup target name %s, got %s", expectedVerityTarget, name)
t.Errorf("expected createVerityTargetCalled target name %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
Expand All @@ -450,15 +460,20 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
}

if err := Mount(
context.Background(), 0, "/foo", mapping, verityInfo, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mapping,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount call: %s", err)
}
if !dmLinearCalled {
t.Fatal("expected createLinearTarget call")
t.Fatal("expected createZeroSectorLinearTarget call")
}
if !dmVerityCalled {
t.Fatal("expected veritySetup call")
t.Fatal("expected createVerityTargetCalled call")
}
if !mountCalled {
t.Fatal("expected mountInternal call")
Expand All @@ -477,7 +492,7 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
removeDeviceCalled := false

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
Expand All @@ -492,7 +507,12 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", mappingInfo, nil, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand All @@ -512,7 +532,7 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
removeDeviceCalled := false

veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, _, _ string) error {
Expand All @@ -527,7 +547,12 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", nil, verity, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
nil,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand Down Expand Up @@ -555,18 +580,18 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin
rmLinearCalled := false
rmVerityCalled := false

createLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
if source != expectedPMemDevice {
t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
return mapperLinearPath, nil
}
veritySetup = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
if source != mapperLinearPath {
t.Errorf("expected veritySetup to be called with %s, got %s", mapperLinearPath, source)
t.Errorf("expected createVerityTargetCalled to be called with %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected veritySetup target %s, got %s", expectedVerityTarget, name)
t.Errorf("expected createVerityTargetCalled target %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
Expand All @@ -587,7 +612,12 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", mapping, verity, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mapping,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/guest/storage/scsi/scsi.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ var (

// controllerLunToName is stubbed to make testing `Mount` easier.
controllerLunToName = ControllerLunToName
// veritySetup is stubbed for unit testing `Mount`
veritySetup = dm.CreateVerityTarget
// createVerityTarget is stubbed for unit testing `Mount`
createVerityTarget = dm.CreateVerityTarget
// removeDevice is stubbed for unit testing `Mount`
removeDevice = dm.RemoveDevice
)
Expand Down Expand Up @@ -77,7 +77,7 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash)
if source, err = veritySetup(ctx, source, dmVerityName, verityInfo); err != nil {
if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
Expand Down
Loading

0 comments on commit 49e4ed6

Please sign in to comment.