Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend integrity protection of LCOW layers to SCSI devices #1170

Merged
merged 4 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions internal/guest/storage/pmem/pmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ import (

// Test dependencies
var (
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget
createVerityTarget = dm.CreateVerityTarget
removeDevice = dm.RemoveDevice
Comment on lines +28 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, what's the reasoning behind doing this? i found it somewhat confusing and harder to follow what was going on because of this. What is gained by doing this? Testability?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I added some tests to make sure that the device mapper targets are cleaned up on failure.

)

const (
Expand All @@ -32,8 +36,8 @@ const (
verityDeviceFmt = "dm-verity-pmem%d-%s"
)

// mountInternal mounts source to target via unix.Mount
func mountInternal(ctx context.Context, source, target string) (err error) {
// mount mounts source to target via unix.Mount
func mount(ctx context.Context, source, target string) (err error) {
if err := osMkdirAll(target, 0700); err != nil {
return err
}
Expand Down Expand Up @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.
// device instead of the original VPMem.
if mappingInfo != nil {
dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
return err
}
defer func() {
if err != nil {
if err := dm.RemoveDevice(dmLinearName); err != nil {
if err := removeDevice(dmLinearName); err != nil {
log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName)
}
}
Expand All @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest)
if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
if devicePath, err = createVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
if err != nil {
if err := dm.RemoveDevice(dmVerityName); err != nil {
if err := removeDevice(dmVerityName); err != nil {
log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName)
}
}
Expand Down
323 changes: 323 additions & 0 deletions internal/guest/storage/pmem/pmem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package pmem
import (
"context"
"fmt"
"github.com/Microsoft/hcsshim/internal/guest/prot"
"os"
"testing"

Expand All @@ -18,6 +19,10 @@ func clearTestDependencies() {
osMkdirAll = nil
osRemoveAll = nil
unixMount = nil
createZeroSectorLinearTarget = nil
createVerityTarget = nil
removeDevice = nil
mountInternal = mount
}

func Test_Mount_Mkdir_Fails_Error(t *testing.T) {
Expand Down Expand Up @@ -305,3 +310,321 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer {
func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer {
return &policy.MountMonitoringSecurityPolicyEnforcer{}
}

// device mapper tests
func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

mappingInfo := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
createZSLTCalled := false

osMkdirAll = func(_ string, _ os.FileMode) error {
return nil
}

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
}
if target != expectedTarget {
t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source)
}
return nil
}

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZSLTCalled = true
if source != expectedSource {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
}
if name != expectedLinearName {
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(),
0,
expectedTarget,
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount: %s", err)
}
if !createZSLTCalled {
t.Fatalf("createZeroSectorLinearTarget not called")
}
}

func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
createVerityTargetCalled := false

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
}
if target != expectedTarget {
t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target)
}
return nil
}
createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = true
if source != expectedSource {
t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source)
}
if name != expectedVerityName {
t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(),
0,
expectedTarget,
nil,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected Mount failure: %s", err)
}
if !createVerityTargetCalled {
t.Fatal("createVerityTarget not called")
}
}

func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
mapping := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
expectedPMemDevice := "/dev/pmem0"
mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
dmLinearCalled := false
dmVerityCalled := false
mountCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
dmLinearCalled = true
if source != expectedPMemDevice {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
if name != expectedLinearTarget {
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
}
return mapperLinearPath, nil
}
createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
dmVerityCalled = true
if source != mapperLinearPath {
t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
mountCalled = true
if source != mapperVerityPath {
t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
mapping,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount call: %s", err)
}
if !dmLinearCalled {
t.Fatal("expected createZeroSectorLinearTarget call")
}
if !dmVerityCalled {
t.Fatal("expected createVerityTarget call")
}
if !mountCalled {
t.Fatal("expected mountInternal call")
}
}

func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

mappingInfo := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedError := errors.New("mountInternal error")
expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
removeDeviceCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
return expectedError
}
removeDevice = func(name string) error {
removeDeviceCalled = true
if name != expectedTarget {
t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !removeDeviceCalled {
t.Fatal("expected removeDevice to be callled")
}
}

func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

verity := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
expectedError := errors.New("mountInternal error")
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
removeDeviceCalled := false

createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, _, _ string) error {
return expectedError
}
removeDevice = func(name string) error {
removeDeviceCalled = true
if name != expectedVerityTarget {
t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
nil,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !removeDeviceCalled {
t.Fatal("expected removeDevice to be called")
}
}

func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

mapping := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
verity := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedError := errors.New("mountInternal error")
expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
expectedPMemDevice := "/dev/pmem0"
mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
rmLinearCalled := false
rmVerityCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
if source != expectedPMemDevice {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
return mapperLinearPath, nil
}
createVerityTarget = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
if source != mapperLinearPath {
t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
removeDevice = func(name string) error {
if name != expectedLinearTarget && name != expectedVerityTarget {
t.Errorf("unexpected removeDevice target name %s", name)
}
if name == expectedLinearTarget {
rmLinearCalled = true
}
if name == expectedVerityTarget {
rmVerityCalled = true
}
return nil
}
mountInternal = func(_ context.Context, _, _ string) error {
return expectedError
}

if err := Mount(
context.Background(),
0,
"/foo",
mapping,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !rmLinearCalled {
t.Fatal("expected removeDevice for linear target to be called")
}
if !rmVerityCalled {
t.Fatal("expected removeDevice for verity target to be called")
}
}
Loading