From 10de9408e3ab63c47a8577d948b36e50f00012ae Mon Sep 17 00:00:00 2001 From: CrazyMax <1951866+crazy-max@users.noreply.github.com> Date: Tue, 11 Feb 2025 14:22:27 +0100 Subject: [PATCH] cdi: support custom and wildcard class for injection Signed-off-by: CrazyMax <1951866+crazy-max@users.noreply.github.com> --- client/client_test.go | 193 ++++++++++++++++++ executor/oci/spec_linux.go | 39 +--- .../instructions/commands_rundevice.go | 27 +-- .../instructions/commands_rundevice_test.go | 15 +- solver/llbsolver/cdidevices/manager.go | 95 ++++++++- 5 files changed, 305 insertions(+), 64 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index b13ab645a82e..497e8ebd084e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -277,6 +277,9 @@ func testIntegration(t *testing.T, funcs ...func(t *testing.T, sb integration.Sa integration.Run(t, integration.TestFuncs( testCDI, + testCDIFirst, + testCDIWildcard, + testCDIClass, ), mirrors) } @@ -11054,3 +11057,193 @@ devices: require.NoError(t, err) require.Contains(t, strings.TrimSpace(string(dt2)), `BAR=injected`) } + +func testCDIFirst(t *testing.T, sb integration.Sandbox) { + if sb.Rootless() { + t.SkipNow() + } + + integration.SkipOnPlatform(t, "windows") + c, err := New(sb.Context(), sb.Address()) + require.NoError(t, err) + defer c.Close() + + require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(` +cdiVersion: "0.3.0" +kind: "vendor1.com/device" +devices: +- name: foo + containerEdits: + env: + - FOO=injected +- name: bar + containerEdits: + env: + - BAR=injected +- name: baz + containerEdits: + env: + - BAZ=injected +- name: qux + containerEdits: + env: + - QUX=injected +`), 0600)) + + busybox := llb.Image("busybox:latest") + st := llb.Scratch() + + run := func(cmd string, ro ...llb.RunOption) { + st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st) + } + + run(`sh -c 'env|sort | tee first.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device"))) + + def, err := st.Marshal(sb.Context()) + require.NoError(t, err) + + destDir := t.TempDir() + + _, err = c.Solve(sb.Context(), def, SolveOpt{ + Exports: []ExportEntry{ + { + Type: ExporterLocal, + OutputDir: destDir, + }, + }, + }, nil) + require.NoError(t, err) + + dt, err := os.ReadFile(filepath.Join(destDir, "first.env")) + require.NoError(t, err) + require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`) + require.NotContains(t, strings.TrimSpace(string(dt)), `FOO=injected`) + require.NotContains(t, strings.TrimSpace(string(dt)), `BAZ=injected`) + require.NotContains(t, strings.TrimSpace(string(dt)), `QUX=injected`) +} + +func testCDIWildcard(t *testing.T, sb integration.Sandbox) { + if sb.Rootless() { + t.SkipNow() + } + + integration.SkipOnPlatform(t, "windows") + c, err := New(sb.Context(), sb.Address()) + require.NoError(t, err) + defer c.Close() + + require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(` +cdiVersion: "0.3.0" +kind: "vendor1.com/device" +devices: +- name: foo + containerEdits: + env: + - FOO=injected +- name: bar + containerEdits: + env: + - BAR=injected +`), 0600)) + + busybox := llb.Image("busybox:latest") + st := llb.Scratch() + + run := func(cmd string, ro ...llb.RunOption) { + st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st) + } + + run(`sh -c 'env|sort | tee all.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device=*"))) + + def, err := st.Marshal(sb.Context()) + require.NoError(t, err) + + destDir := t.TempDir() + + _, err = c.Solve(sb.Context(), def, SolveOpt{ + Exports: []ExportEntry{ + { + Type: ExporterLocal, + OutputDir: destDir, + }, + }, + }, nil) + require.NoError(t, err) + + dt, err := os.ReadFile(filepath.Join(destDir, "all.env")) + require.NoError(t, err) + require.Contains(t, strings.TrimSpace(string(dt)), `FOO=injected`) + require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`) +} + +func testCDIClass(t *testing.T, sb integration.Sandbox) { + if sb.Rootless() { + t.SkipNow() + } + + integration.SkipOnPlatform(t, "windows") + c, err := New(sb.Context(), sb.Address()) + require.NoError(t, err) + defer c.Close() + + require.NoError(t, os.WriteFile(filepath.Join(sb.CDISpecDir(), "vendor1-device.yaml"), []byte(` +cdiVersion: "0.6.0" +kind: "vendor1.com/device" +annotations: + foo.bar.baz: FOO +devices: +- name: foo + annotations: + org.mobyproject.buildkit.device.class: class1 + containerEdits: + env: + - FOO=injected +- name: bar + annotations: + org.mobyproject.buildkit.device.class: class1 + containerEdits: + env: + - BAR=injected +- name: baz + annotations: + org.mobyproject.buildkit.device.class: class2 + containerEdits: + env: + - BAZ=injected +- name: qux + containerEdits: + env: + - QUX=injected +`), 0600)) + + busybox := llb.Image("busybox:latest") + st := llb.Scratch() + + run := func(cmd string, ro ...llb.RunOption) { + st = busybox.Run(append(ro, llb.Shlex(cmd), llb.Dir("/wd"))...).AddMount("/wd", st) + } + + run(`sh -c 'env|sort | tee class.env'`, llb.AddCDIDevice(llb.CDIDeviceName("vendor1.com/device=class1"))) + + def, err := st.Marshal(sb.Context()) + require.NoError(t, err) + + destDir := t.TempDir() + + _, err = c.Solve(sb.Context(), def, SolveOpt{ + Exports: []ExportEntry{ + { + Type: ExporterLocal, + OutputDir: destDir, + }, + }, + }, nil) + require.NoError(t, err) + + dt, err := os.ReadFile(filepath.Join(destDir, "class.env")) + require.NoError(t, err) + require.Contains(t, strings.TrimSpace(string(dt)), `FOO=injected`) + require.Contains(t, strings.TrimSpace(string(dt)), `BAR=injected`) + require.NotContains(t, strings.TrimSpace(string(dt)), `BAZ=injected`) + require.NotContains(t, strings.TrimSpace(string(dt)), `QUX=injected`) +} diff --git a/executor/oci/spec_linux.go b/executor/oci/spec_linux.go index 1a394ac04d6b..7b437c28b8b0 100644 --- a/executor/oci/spec_linux.go +++ b/executor/oci/spec_linux.go @@ -26,7 +26,6 @@ import ( "github.com/opencontainers/selinux/go-selinux/label" "github.com/pkg/errors" "golang.org/x/sys/unix" - "tags.cncf.io/container-device-interface/pkg/parser" ) var ( @@ -153,47 +152,19 @@ func generateRlimitOpts(ulimits []*pb.Ulimit) ([]oci.SpecOpts, error) { // genereateCDIOptions creates the OCI runtime spec options for injecting CDI // devices. -func generateCDIOpts(manager *cdidevices.Manager, devices []*pb.CDIDevice) ([]oci.SpecOpts, error) { - if len(devices) == 0 { +func generateCDIOpts(manager *cdidevices.Manager, devs []*pb.CDIDevice) ([]oci.SpecOpts, error) { + if len(devs) == 0 { return nil, nil } - withCDIDevices := func(devices []*pb.CDIDevice) oci.SpecOpts { + withCDIDevices := func(devs []*pb.CDIDevice) oci.SpecOpts { return func(ctx context.Context, _ oci.Client, c *containers.Container, s *specs.Spec) error { if err := manager.Refresh(); err != nil { bklog.G(ctx).Warnf("CDI registry refresh failed: %v", err) } - - registeredDevices := manager.ListDevices() - isDeviceRegistered := func(device *pb.CDIDevice) bool { - for _, d := range registeredDevices { - if device.Name == d.Name { - return true - } - } - return false - } - - var dd []string - for _, d := range devices { - if d == nil { - continue - } - if _, _, _, err := parser.ParseQualifiedName(d.Name); err != nil { - return errors.Wrapf(err, "invalid CDI device name %s", d.Name) - } - if !isDeviceRegistered(d) && d.Optional { - bklog.G(ctx).Warnf("Optional CDI device %q is not registered", d.Name) - continue - } - dd = append(dd, d.Name) - } - - bklog.G(ctx).Debugf("Injecting CDI devices %v", dd) - if err := manager.InjectDevices(s, dd...); err != nil { + if err := manager.InjectDevices(s, devs...); err != nil { return errors.Wrapf(err, "CDI device injection failed") } - // One crucial thing to keep in mind is that CDI device injection // might add OCI Spec environment variables, hooks, and mounts as // well. Therefore, it is important that none of the corresponding @@ -203,7 +174,7 @@ func generateCDIOpts(manager *cdidevices.Manager, devices []*pb.CDIDevice) ([]oc } return []oci.SpecOpts{ - withCDIDevices(devices), + withCDIDevices(devs), }, nil } diff --git a/frontend/dockerfile/instructions/commands_rundevice.go b/frontend/dockerfile/instructions/commands_rundevice.go index 975c95454978..693f28bdc19d 100644 --- a/frontend/dockerfile/instructions/commands_rundevice.go +++ b/frontend/dockerfile/instructions/commands_rundevice.go @@ -7,7 +7,6 @@ import ( "github.com/moby/buildkit/util/suggest" "github.com/pkg/errors" "github.com/tonistiigi/go-csvvalue" - "tags.cncf.io/container-device-interface/pkg/parser" ) var devicesKey = "dockerfile/run/devices" @@ -75,28 +74,20 @@ func ParseDevice(val string) (*Device, error) { d := &Device{} - for i, field := range fields { - // check if the first field is a valid device name - var firstFieldErr error - if i == 0 { - if _, _, _, firstFieldErr = parser.ParseQualifiedName(field); firstFieldErr == nil { - d.Name = field - continue - } - } - + for _, field := range fields { key, value, ok := strings.Cut(field, "=") key = strings.ToLower(key) if !ok { - if len(fields) == 1 && firstFieldErr != nil { - return nil, errors.Wrapf(firstFieldErr, "invalid device name %s", field) - } switch key { case "required": d.Required = true continue default: + if d.Name == "" { + d.Name = field + continue + } // any other option requires a value. return nil, errors.Errorf("invalid field '%s' must be a key=value pair", field) } @@ -114,14 +105,14 @@ func ParseDevice(val string) (*Device, error) { return nil, errors.Errorf("invalid value for %s: %s", key, value) } default: + if d.Name == "" { + d.Name = field + continue + } allKeys := []string{"name", "required"} return nil, suggest.WrapError(errors.Errorf("unexpected key '%s' in '%s'", key, field), key, allKeys, true) } } - if _, _, _, err := parser.ParseQualifiedName(d.Name); err != nil { - return nil, errors.Wrapf(err, "invalid device name %s", d.Name) - } - return d, nil } diff --git a/frontend/dockerfile/instructions/commands_rundevice_test.go b/frontend/dockerfile/instructions/commands_rundevice_test.go index 92d103cbd847..79b9be79ff19 100644 --- a/frontend/dockerfile/instructions/commands_rundevice_test.go +++ b/frontend/dockerfile/instructions/commands_rundevice_test.go @@ -18,6 +18,11 @@ func TestParseDevice(t *testing.T) { expected: &Device{Name: "vendor1.com/device=foo", Required: false}, expectedErr: nil, }, + { + input: "vendor1.com/device", + expected: &Device{Name: "vendor1.com/device", Required: false}, + expectedErr: nil, + }, { input: "vendor1.com/device=foo,required", expected: &Device{Name: "vendor1.com/device=foo", Required: true}, @@ -48,16 +53,6 @@ func TestParseDevice(t *testing.T) { expected: nil, expectedErr: errors.New("device name already set to vendor1.com/device=foo"), }, - { - input: "invalid-device-name", - expected: nil, - expectedErr: errors.New(`invalid device name invalid-device-name: unqualified device "invalid-device-name", missing vendor`), - }, - { - input: "name=invalid-device-name", - expected: nil, - expectedErr: errors.New(`invalid device name invalid-device-name: unqualified device "invalid-device-name", missing vendor`), - }, } for _, tt := range cases { t.Run(tt.input, func(t *testing.T) { diff --git a/solver/llbsolver/cdidevices/manager.go b/solver/llbsolver/cdidevices/manager.go index 168daa81bf45..27a7c0dbb08d 100644 --- a/solver/llbsolver/cdidevices/manager.go +++ b/solver/llbsolver/cdidevices/manager.go @@ -4,12 +4,17 @@ import ( "context" "strings" + "github.com/moby/buildkit/solver/pb" + "github.com/moby/buildkit/util/bklog" "github.com/moby/locker" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" "tags.cncf.io/container-device-interface/pkg/cdi" + "tags.cncf.io/container-device-interface/pkg/parser" ) +const deviceAnnotationClass = "org.mobyproject.buildkit.device.class" + var installers = map[string]Setup{} type Setup interface { @@ -76,11 +81,97 @@ func (m *Manager) Refresh() error { return m.cache.Refresh() } -func (m *Manager) InjectDevices(spec *specs.Spec, devs ...string) error { - _, err := m.cache.InjectDevices(spec, devs...) +func (m *Manager) InjectDevices(spec *specs.Spec, devs ...*pb.CDIDevice) error { + pdevs, err := m.ParseDevices(devs...) + if err != nil { + return err + } else if len(pdevs) == 0 { + return nil + } + bklog.G(context.TODO()).Debugf("Injecting devices %v", pdevs) + _, err = m.cache.InjectDevices(spec, pdevs...) return err } +func (m *Manager) ParseDevices(devs ...*pb.CDIDevice) ([]string, error) { + var out []string + for _, dev := range devs { + if dev == nil { + continue + } + pdev, err := m.parseDevice(dev) + if err != nil { + return nil, err + } else if len(pdev) == 0 { + continue + } + out = append(out, pdev...) + } + return out, nil +} + +func (m *Manager) parseDevice(dev *pb.CDIDevice) ([]string, error) { + var out []string + + kind, name, _ := strings.Cut(dev.Name, "=") + + // validate kind + if vendor, class := parser.ParseQualifier(kind); vendor == "" { + return nil, errors.Errorf("invalid device %q", dev.Name) + } else if err := parser.ValidateVendorName(vendor); err != nil { + return nil, errors.Wrapf(err, "invalid device %q", dev.Name) + } else if err := parser.ValidateClassName(class); err != nil { + return nil, errors.Wrapf(err, "invalid device %q", dev.Name) + } + + switch name { + case "": + // first device of kind if no name is specified + for _, d := range m.cache.ListDevices() { + if strings.HasPrefix(d, kind+"=") { + out = append(out, d) + break + } + } + case "*": + // all devices of kind if the name is a wildcard + for _, d := range m.cache.ListDevices() { + if strings.HasPrefix(d, kind+"=") { + out = append(out, d) + } + } + default: + // the specified device + for _, d := range m.cache.ListDevices() { + if d == dev.Name { + out = append(out, d) + break + } + } + if len(out) == 0 { + // check class annotation if name unknown + for _, d := range m.cache.ListDevices() { + if !strings.HasPrefix(d, kind+"=") { + continue + } + if dd := m.cache.GetDevice(d).Device; dd != nil { + if class, ok := dd.Annotations[deviceAnnotationClass]; ok && class == name { + out = append(out, d) + } + } + } + } + } + + if len(out) == 0 { + if !dev.Optional { + return nil, errors.Errorf("required device %q is not registered", dev.Name) + } + bklog.G(context.TODO()).Warnf("Optional device %q is not registered", dev.Name) + } + return out, nil +} + func (m *Manager) hasDevice(name string) bool { for _, d := range m.cache.ListDevices() { kind, _, _ := strings.Cut(d, "=")