Skip to content

Commit

Permalink
cdi: support custom and wildcard class for injection
Browse files Browse the repository at this point in the history
Signed-off-by: CrazyMax <[email protected]>
  • Loading branch information
crazy-max committed Feb 11, 2025
1 parent 88509a9 commit 10de940
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 64 deletions.
193 changes: 193 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ func testIntegration(t *testing.T, funcs ...func(t *testing.T, sb integration.Sa

integration.Run(t, integration.TestFuncs(
testCDI,
testCDIFirst,
testCDIWildcard,
testCDIClass,
), mirrors)
}

Expand Down Expand Up @@ -11054,3 +11057,193 @@ devices:
require.NoError(t, err)
require.Contains(t, strings.TrimSpace(string(dt2)), `BAR=injected`)
}

func testCDIFirst(t *testing.T, sb integration.Sandbox) {
if sb.Rootless() {
t.SkipNow()
}

integration.SkipOnPlatform(t, "windows")
c, err := New(sb.Context(), sb.Address())
require.NoError(t, err)
defer c.Close()

require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(`
cdiVersion: "0.3.0"
kind: "vendor1.com/device"
devices:
- name: foo
containerEdits:
env:
- FOO=injected
- name: bar
containerEdits:
env:
- BAR=injected
- name: baz
containerEdits:
env:
- BAZ=injected
- name: qux
containerEdits:
env:
- QUX=injected
`), 0600))

busybox := llb.Image("busybox:latest")
st := llb.Scratch()

run := func(cmd string, ro ...llb.RunOption) {
st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st)
}

run(`sh -c 'env|sort | tee first.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device")))

def, err := st.Marshal(sb.Context())
require.NoError(t, err)

destDir := t.TempDir()

_, err = c.Solve(sb.Context(), def, SolveOpt{
Exports: []ExportEntry{
{
Type: ExporterLocal,
OutputDir: destDir,
},
},
}, nil)
require.NoError(t, err)

dt, err := os.ReadFile(filepath.Join(destDir, "first.env"))
require.NoError(t, err)
require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`)
require.NotContains(t, strings.TrimSpace(string(dt)), `FOO=injected`)
require.NotContains(t, strings.TrimSpace(string(dt)), `BAZ=injected`)
require.NotContains(t, strings.TrimSpace(string(dt)), `QUX=injected`)
}

func testCDIWildcard(t *testing.T, sb integration.Sandbox) {
if sb.Rootless() {
t.SkipNow()
}

integration.SkipOnPlatform(t, "windows")
c, err := New(sb.Context(), sb.Address())
require.NoError(t, err)
defer c.Close()

require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(`
cdiVersion: "0.3.0"
kind: "vendor1.com/device"
devices:
- name: foo
containerEdits:
env:
- FOO=injected
- name: bar
containerEdits:
env:
- BAR=injected
`), 0600))

busybox := llb.Image("busybox:latest")
st := llb.Scratch()

run := func(cmd string, ro ...llb.RunOption) {
st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st)
}

run(`sh -c 'env|sort | tee all.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device=*")))

def, err := st.Marshal(sb.Context())
require.NoError(t, err)

destDir := t.TempDir()

_, err = c.Solve(sb.Context(), def, SolveOpt{
Exports: []ExportEntry{
{
Type: ExporterLocal,
OutputDir: destDir,
},
},
}, nil)
require.NoError(t, err)

dt, err := os.ReadFile(filepath.Join(destDir, "all.env"))
require.NoError(t, err)
require.Contains(t, strings.TrimSpace(string(dt)), `FOO=injected`)
require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`)
}

func testCDIClass(t *testing.T, sb integration.Sandbox) {
if sb.Rootless() {
t.SkipNow()
}

integration.SkipOnPlatform(t, "windows")
c, err := New(sb.Context(), sb.Address())
require.NoError(t, err)
defer c.Close()

require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(`
cdiVersion: "0.6.0"
kind: "vendor1.com/device"
annotations:
foo.bar.baz: FOO
devices:
- name: foo
annotations:
org.mobyproject.buildkit.device.class: class1
containerEdits:
env:
- FOO=injected
- name: bar
annotations:
org.mobyproject.buildkit.device.class: class1
containerEdits:
env:
- BAR=injected
- name: baz
annotations:
org.mobyproject.buildkit.device.class: class2
containerEdits:
env:
- BAZ=injected
- name: qux
containerEdits:
env:
- QUX=injected
`), 0600))

busybox := llb.Image("busybox:latest")
st := llb.Scratch()

run := func(cmd string, ro ...llb.RunOption) {
st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st)
}

run(`sh -c 'env|sort | tee class.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device=class1")))

def, err := st.Marshal(sb.Context())
require.NoError(t, err)

destDir := t.TempDir()

_, err = c.Solve(sb.Context(), def, SolveOpt{
Exports: []ExportEntry{
{
Type: ExporterLocal,
OutputDir: destDir,
},
},
}, nil)
require.NoError(t, err)

dt, err := os.ReadFile(filepath.Join(destDir, "class.env"))
require.NoError(t, err)
require.Contains(t, strings.TrimSpace(string(dt)), `FOO=injected`)
require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`)
require.NotContains(t, strings.TrimSpace(string(dt)), `BAZ=injected`)
require.NotContains(t, strings.TrimSpace(string(dt)), `QUX=injected`)
}
39 changes: 5 additions & 34 deletions executor/oci/spec_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/opencontainers/selinux/go-selinux/label"
"github.com/pkg/errors"
"golang.org/x/sys/unix"
"tags.cncf.io/container-device-interface/pkg/parser"
)

var (
Expand Down Expand Up @@ -153,47 +152,19 @@ func generateRlimitOpts(ulimits []*pb.Ulimit) ([]oci.SpecOpts, error) {

// genereateCDIOptions creates the OCI runtime spec options for injecting CDI
// devices.
func generateCDIOpts(manager *cdidevices.Manager, devices []*pb.CDIDevice) ([]oci.SpecOpts, error) {
if len(devices) == 0 {
func generateCDIOpts(manager *cdidevices.Manager, devs []*pb.CDIDevice) ([]oci.SpecOpts, error) {
if len(devs) == 0 {
return nil, nil
}

withCDIDevices := func(devices []*pb.CDIDevice) oci.SpecOpts {
withCDIDevices := func(devs []*pb.CDIDevice) oci.SpecOpts {
return func(ctx context.Context, _ oci.Client, c *containers.Container, s *specs.Spec) error {
if err := manager.Refresh(); err != nil {
bklog.G(ctx).Warnf("CDI registry refresh failed: %v", err)
}

registeredDevices := manager.ListDevices()
isDeviceRegistered := func(device *pb.CDIDevice) bool {
for _, d := range registeredDevices {
if device.Name == d.Name {
return true
}
}
return false
}

var dd []string
for _, d := range devices {
if d == nil {
continue
}
if _, _, _, err := parser.ParseQualifiedName(d.Name); err != nil {
return errors.Wrapf(err, "invalid CDI device name %s", d.Name)
}
if !isDeviceRegistered(d) && d.Optional {
bklog.G(ctx).Warnf("Optional CDI device %q is not registered", d.Name)
continue
}
dd = append(dd, d.Name)
}

bklog.G(ctx).Debugf("Injecting CDI devices %v", dd)
if err := manager.InjectDevices(s, dd...); err != nil {
if err := manager.InjectDevices(s, devs...); err != nil {
return errors.Wrapf(err, "CDI device injection failed")
}

// One crucial thing to keep in mind is that CDI device injection
// might add OCI Spec environment variables, hooks, and mounts as
// well. Therefore, it is important that none of the corresponding
Expand All @@ -203,7 +174,7 @@ func generateCDIOpts(manager *cdidevices.Manager, devices []*pb.CDIDevice) ([]oc
}

return []oci.SpecOpts{
withCDIDevices(devices),
withCDIDevices(devs),
}, nil
}

Expand Down
27 changes: 9 additions & 18 deletions frontend/dockerfile/instructions/commands_rundevice.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/moby/buildkit/util/suggest"
"github.com/pkg/errors"
"github.com/tonistiigi/go-csvvalue"
"tags.cncf.io/container-device-interface/pkg/parser"
)

var devicesKey = "dockerfile/run/devices"
Expand Down Expand Up @@ -75,28 +74,20 @@ func ParseDevice(val string) (*Device, error) {

d := &Device{}

for i, field := range fields {
// check if the first field is a valid device name
var firstFieldErr error
if i == 0 {
if _, _, _, firstFieldErr = parser.ParseQualifiedName(field); firstFieldErr == nil {
d.Name = field
continue
}
}

for _, field := range fields {
key, value, ok := strings.Cut(field, "=")
key = strings.ToLower(key)

if !ok {
if len(fields) == 1 && firstFieldErr != nil {
return nil, errors.Wrapf(firstFieldErr, "invalid device name %s", field)
}
switch key {
case "required":
d.Required = true
continue
default:
if d.Name == "" {
d.Name = field
continue
}
// any other option requires a value.
return nil, errors.Errorf("invalid field '%s' must be a key=value pair", field)
}
Expand All @@ -114,14 +105,14 @@ func ParseDevice(val string) (*Device, error) {
return nil, errors.Errorf("invalid value for %s: %s", key, value)
}
default:
if d.Name == "" {
d.Name = field
continue
}
allKeys := []string{"name", "required"}
return nil, suggest.WrapError(errors.Errorf("unexpected key '%s' in '%s'", key, field), key, allKeys, true)
}
}

if _, _, _, err := parser.ParseQualifiedName(d.Name); err != nil {
return nil, errors.Wrapf(err, "invalid device name %s", d.Name)
}

return d, nil
}
15 changes: 5 additions & 10 deletions frontend/dockerfile/instructions/commands_rundevice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func TestParseDevice(t *testing.T) {
expected: &Device{Name: "vendor1.com/device=foo", Required: false},
expectedErr: nil,
},
{
input: "vendor1.com/device",
expected: &Device{Name: "vendor1.com/device", Required: false},
expectedErr: nil,
},
{
input: "vendor1.com/device=foo,required",
expected: &Device{Name: "vendor1.com/device=foo", Required: true},
Expand Down Expand Up @@ -48,16 +53,6 @@ func TestParseDevice(t *testing.T) {
expected: nil,
expectedErr: errors.New("device name already set to vendor1.com/device=foo"),
},
{
input: "invalid-device-name",
expected: nil,
expectedErr: errors.New(`invalid device name invalid-device-name: unqualified device "invalid-device-name", missing vendor`),
},
{
input: "name=invalid-device-name",
expected: nil,
expectedErr: errors.New(`invalid device name invalid-device-name: unqualified device "invalid-device-name", missing vendor`),
},
}
for _, tt := range cases {
t.Run(tt.input, func(t *testing.T) {
Expand Down
Loading

0 comments on commit 10de940

Please sign in to comment.