diff --git a/plugins/inputs/nvidia_smi/nvidia_smi.go b/plugins/inputs/nvidia_smi/nvidia_smi.go index e4714b0ff37f8..38097a219a2f7 100644 --- a/plugins/inputs/nvidia_smi/nvidia_smi.go +++ b/plugins/inputs/nvidia_smi/nvidia_smi.go @@ -31,8 +31,9 @@ type NvidiaSMI struct { Timeout config.Duration `toml:"timeout"` Log telegraf.Logger `toml:"-"` - ignorePlugin bool - once sync.Once + nvidiaSMIArgs []string + ignorePlugin bool + once sync.Once } func (*NvidiaSMI) SampleConfig() string { @@ -53,6 +54,15 @@ func (smi *NvidiaSMI) Start(telegraf.Accumulator) error { func (*NvidiaSMI) Stop() {} +func (smi *NvidiaSMI) Probe() error { + // Construct and execute metrics query + _, err := internal.CombinedOutputTimeout(exec.Command(smi.BinPath, smi.nvidiaSMIArgs...), time.Duration(smi.Timeout)) + if err != nil { + return fmt.Errorf("calling %q failed: %w", smi.BinPath, err) + } + return nil +} + // Gather implements the telegraf interface func (smi *NvidiaSMI) Gather(acc telegraf.Accumulator) error { if smi.ignorePlugin { @@ -60,7 +70,7 @@ func (smi *NvidiaSMI) Gather(acc telegraf.Accumulator) error { } // Construct and execute metrics query - data, err := internal.CombinedOutputTimeout(exec.Command(smi.BinPath, "-q", "-x"), time.Duration(smi.Timeout)) + data, err := internal.CombinedOutputTimeout(exec.Command(smi.BinPath, smi.nvidiaSMIArgs...), time.Duration(smi.Timeout)) if err != nil { return fmt.Errorf("calling %q failed: %w", smi.BinPath, err) } @@ -119,8 +129,9 @@ func (smi *NvidiaSMI) parse(acc telegraf.Accumulator, data []byte) error { func init() { inputs.Add("nvidia_smi", func() telegraf.Input { return &NvidiaSMI{ - BinPath: "/usr/bin/nvidia-smi", - Timeout: config.Duration(5 * time.Second), + BinPath: "/usr/bin/nvidia-smi", + Timeout: config.Duration(5 * time.Second), + nvidiaSMIArgs: []string{"-q", "-x"}, } }) } diff --git a/plugins/inputs/nvidia_smi/nvidia_smi_test.go b/plugins/inputs/nvidia_smi/nvidia_smi_test.go index 23c57e5f6a3b6..6cd0ff7f771ba 100644 --- a/plugins/inputs/nvidia_smi/nvidia_smi_test.go +++ b/plugins/inputs/nvidia_smi/nvidia_smi_test.go @@ -4,16 +4,66 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "time" "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/internal" "github.com/influxdata/telegraf/models" "github.com/influxdata/telegraf/testutil" "github.com/stretchr/testify/require" ) +func TestProbe(t *testing.T) { + var binPath string + var nvidiaSMIArgsPrefix []string + if runtime.GOOS == "windows" { + binPath = `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` + nvidiaSMIArgsPrefix = []string{"-Command"} + } else { + binPath = "/bin/bash" + nvidiaSMIArgsPrefix = []string{"-c"} + } + + for _, tt := range []struct { + name string + args string + expectError bool + }{ + { + name: "probe success", + args: "exit 0", + expectError: false, + }, + { + name: "probe error", + args: "exit 1", + expectError: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + plugin := &NvidiaSMI{ + BinPath: binPath, + nvidiaSMIArgs: append(nvidiaSMIArgsPrefix, tt.args), + Log: &testutil.Logger{}, + Timeout: config.Duration(5 * time.Second), + } + model := models.NewRunningInput(plugin, &models.InputConfig{ + Name: "nvidia_smi", + StartupErrorBehavior: "probe", + }) + err := model.Probe() + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func TestErrorBehaviorDefault(t *testing.T) { // make sure we can't find nvidia-smi in $PATH somewhere os.Unsetenv("PATH")