Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task API via Unix Domain Socket #15864

Merged
merged 18 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/15864.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
client: added http api access for tasks via unix socket
```
5 changes: 3 additions & 2 deletions client/allocrunner/interfaces/task_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ import (
+-----------+
*Kill
(forces terminal)

Link: http://stable.ascii-flow.appspot.com/#Draw4489375405966393064/1824429135
*/

// TaskHook is a lifecycle hook into the life cycle of a task runner.
Expand Down Expand Up @@ -186,6 +184,9 @@ type TaskStopRequest struct {
// ExistingState is previously set hook data and should only be
// read. Stop hooks cannot alter state.
ExistingState map[string]string

// TaskDir contains the task's directory tree on the host
TaskDir *allocdir.TaskDir
}

type TaskStopResponse struct{}
Expand Down
119 changes: 119 additions & 0 deletions client/allocrunner/taskrunner/api_hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package taskrunner

import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"sync"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/helper/users"
)

// apiHook exposes the Task API. The Task API allows tasks to access the Nomad
// HTTP API without having to discover and connect to an agent's address.
// Instead a unix socket is provided in a standard location. To prevent access
// by untrusted workloads the Task API always requires authentication even when
// ACLs are disabled.
//
// The Task API hook largely soft-fails as there are a number of ways creating
// the unix socket could fail (the most common one being path length
// restrictions), and it is assumed most tasks won't require access to the Task
// API anyway. Tasks that do require access are expected to crash and get
// rescheduled should they land on a client whose Task API hook soft-fails.
type apiHook struct {
// shutdownCtx is the context handed to srv.Serve by Prestart. It is used
// instead of the Prestart context, which does not outlive the prestart hooks.
shutdownCtx context.Context
// srv serves the HTTP API on the listener created by Prestart.
srv config.APIListenerRegistrar
// logger is the hook's logger, named after the hook in newAPIHook.
logger hclog.Logger

// lock guards ln, which is updated from multiple hooks.
lock sync.Mutex

// ln is the unix domain socket listener of the Task API for this task. It is
// nil until Prestart succeeds in creating the socket and again after Stop.
ln net.Listener
}

// newAPIHook returns an apiHook that serves the Task API through srv using
// shutdownCtx. The hook's logger is derived from logger and named after the
// hook itself.
func newAPIHook(shutdownCtx context.Context, srv config.APIListenerRegistrar, logger hclog.Logger) *apiHook {
	hook := new(apiHook)
	hook.shutdownCtx = shutdownCtx
	hook.srv = srv
	hook.logger = logger.Named(hook.Name())
	return hook
}

// Name returns the identifier of this hook.
func (*apiHook) Name() string {
	const hookName = "api"
	return hookName
}

// Prestart creates the Task API unix socket for the task and starts serving
// the HTTP API on it in a background goroutine. It never returns an error:
// failing to create the socket is logged as a warning and otherwise ignored.
func (h *apiHook) Prestart(_ context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
	h.lock.Lock()
	defer h.lock.Unlock()

	// A listener that is already set most likely means the task is
	// restarting, so there is nothing left to do.
	if h.ln != nil {
		return nil
	}

	sockPath := apiSocketPath(req.TaskDir)
	listener, err := users.SocketFileFor(h.logger, sockPath, req.Task.User)
	if err != nil {
		// Soft-fail: a task that needs the Task API will fail on its own.
		h.logger.Warn("error creating task api socket", "path", sockPath, "error", err)
		return nil
	}
	h.ln = listener

	go func() {
		// Prestart's own context cannot be used here because it is closed once
		// all prestart hooks have run; the shutdown context still lets us try
		// to clean up on shutdown.
		serveErr := h.srv.Serve(h.shutdownCtx, listener)
		switch {
		case serveErr == nil:
			// Served without error; nothing to report.
		case errors.Is(serveErr, http.ErrServerClosed), errors.Is(serveErr, net.ErrClosed):
			// The server or listener was closed; not worth reporting.
		default:
			h.logger.Error("error serving task api", "error", serveErr)
		}
	}()

	return nil
}

// Stop closes the Task API listener if Prestart created one and then removes
// the socket path for the task. It always returns nil: a failure to close the
// listener is only logged, and a failure to remove the socket is ignored.
func (h *apiHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
	h.lock.Lock()
	defer h.lock.Unlock()

	if h.ln != nil {
		// A listener that is already closed is not worth reporting.
		if err := h.ln.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			// hclog takes alternating key/value pairs rather than printf-style
			// verbs, so the error is attached under the "error" key.
			h.logger.Debug("error closing task listener", "error", err)
		}
		h.ln = nil
	}

	// Best-effort at cleaning things up. Alloc dir cleanup will remove it if
	// this fails for any reason.
	_ = os.RemoveAll(apiSocketPath(req.TaskDir))

	return nil
}

// apiSocketPath returns the location of the Task API socket: a file named
// "api.sock" inside the task's secrets directory.
//
// The path is kept as short as possible because the syscall used to create
// unix sockets imposes a low limit on the sun_path char array.
//
// See https://github.com/hashicorp/nomad/pull/13971 for an example of the
// sadness this causes.
func apiSocketPath(taskDir *allocdir.TaskDir) string {
	const socketName = "api.sock"
	secretsDir := taskDir.SecretsDir
	return filepath.Join(secretsDir, socketName)
}
169 changes: 169 additions & 0 deletions client/allocrunner/taskrunner/api_hook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package taskrunner

import (
"context"
"io/fs"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"testing"

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/users"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
)

// testAPIListenerRegistrar is a test double whose Serve behavior is supplied
// by an optional callback.
type testAPIListenerRegistrar struct {
	// cb, when non-nil, is invoked by Serve with the listener it was given.
	cb func(net.Listener) error
}

// Serve hands ln to the configured callback and returns its result. Without a
// callback it does nothing and returns nil. The context is ignored.
func (n testAPIListenerRegistrar) Serve(_ context.Context, ln net.Listener) error {
	if n.cb == nil {
		return nil
	}
	return n.cb(ln)
}

// TestAPIHook_SoftFail asserts that neither Prestart nor Stop of the Task API
// Hook returns an error when the socket cannot be created.
func TestAPIHook_SoftFail(t *testing.T) {
	ci.Parallel(t)

	// A SecretsDir this long always exceeds the Unix socket path length
	// limit (sun_path).
	longDir := filepath.Join(t.TempDir(), strings.Repeat("_NOMAD_TEST_", 100))
	taskDir := &allocdir.TaskDir{
		SecretsDir: longDir,
	}

	ctx := context.Background()
	hook := newAPIHook(ctx, testAPIListenerRegistrar{}, testlog.HCLogger(t))

	prestartReq := &interfaces.TaskPrestartRequest{
		Task:    &structs.Task{}, // needs to be non-nil for Task.User lookup
		TaskDir: taskDir,
	}
	must.NoError(t, hook.Prestart(ctx, prestartReq, &interfaces.TaskPrestartResponse{}))

	// No listener may have been stored on the hook.
	must.Nil(t, hook.ln)

	// Nothing may exist at the secrets dir path.
	_, statErr := os.Stat(longDir)
	must.Error(t, statErr)

	// Stop has to soft-fail as well.
	stopReq := &interfaces.TaskStopRequest{
		TaskDir: taskDir,
	}
	must.NoError(t, hook.Stop(ctx, stopReq, &interfaces.TaskStopResponse{}))

	// Still nothing at the secrets dir path.
	_, statErr = os.Stat(longDir)
	must.Error(t, statErr)
}

// TestAPIHook_Ok asserts that the Task API Hook creates a working socket in
// Prestart and that the socket no longer accepts connections after Stop.
func TestAPIHook_Ok(t *testing.T) {
	ci.Parallel(t)

	// If this test fails it may be because TempDir() + /api.sock is longer than
	// the unix socket path length limit (sun_path) in which case the test should
	// use a different temporary directory on that platform.
	dst := t.TempDir()

	// Accept a single connection, write "ok" to it, and close the connection.
	srv := testAPIListenerRegistrar{
		cb: func(ln net.Listener) error {
			conn, err := ln.Accept()
			if err != nil {
				return err
			}
			// Deferred so the connection is released even when the write fails.
			defer conn.Close()

			_, err = conn.Write([]byte("ok"))
			return err
		},
	}

	ctx := context.Background()
	logger := testlog.HCLogger(t)
	h := newAPIHook(ctx, srv, logger)

	req := &interfaces.TaskPrestartRequest{
		Task: &structs.Task{
			User: "nobody",
		},
		TaskDir: &allocdir.TaskDir{
			SecretsDir: dst,
		},
	}
	resp := &interfaces.TaskPrestartResponse{}

	err := h.Prestart(ctx, req, resp)
	must.NoError(t, err)

	// File should have been created
	sockDst := apiSocketPath(req.TaskDir)

	// Stat and chown fail on Windows, so skip these checks
	if runtime.GOOS != "windows" {
		stat, err := os.Stat(sockDst)
		must.NoError(t, err)
		must.True(t, stat.Mode()&fs.ModeSocket != 0,
			must.Sprintf("expected %q to be a unix socket but got %s", sockDst, stat.Mode()))

		// The lookup error is deliberately ignored: a missing "nobody" user
		// simply skips the ownership checks below.
		nobody, _ := users.Lookup("nobody")
		if syscall.Getuid() == 0 && nobody != nil {
			t.Logf("root and nobody exists: testing file perms")

			// We're root and nobody exists! Check perms
			must.Eq(t, fs.FileMode(0o600), stat.Mode().Perm())

			sysStat, ok := stat.Sys().(*syscall.Stat_t)
			must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t on %s but found %T",
				runtime.GOOS, stat.Sys()))

			nobodyUID, err := strconv.Atoi(nobody.Uid)
			must.NoError(t, err)
			must.Eq(t, nobodyUID, int(sysStat.Uid))
		}
	}

	// Assert the listener is working
	conn, err := net.Dial("unix", sockDst)
	must.NoError(t, err)
	buf := make([]byte, 2)
	_, err = conn.Read(buf)
	must.NoError(t, err)
	must.Eq(t, []byte("ok"), buf)
	conn.Close()

	// Assert stop cleans up
	stopReq := &interfaces.TaskStopRequest{
		TaskDir: req.TaskDir,
	}
	stopResp := &interfaces.TaskStopResponse{}
	err = h.Stop(ctx, stopReq, stopResp)
	must.NoError(t, err)

	// The socket should no longer accept connections
	_, err = net.Dial("unix", sockDst)
	must.Error(t, err)
}
5 changes: 4 additions & 1 deletion client/allocrunner/taskrunner/task_runner_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (tr *TaskRunner) initHooks() {
newArtifactHook(tr, tr.getter, hookLogger),
newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
newDeviceHook(tr.devicemanager, hookLogger),
newAPIHook(tr.shutdownCtx, tr.clientConfig.APIListenerRegistrar, hookLogger),
}

// If the task has a CSI block, add the hook.
Expand Down Expand Up @@ -431,7 +432,9 @@ func (tr *TaskRunner) stop() error {
tr.logger.Trace("running stop hook", "name", name, "start", start)
}

req := interfaces.TaskStopRequest{}
req := interfaces.TaskStopRequest{
TaskDir: tr.taskDir,
}

origHookState := tr.hookState(name)
if origHookState != nil {
Expand Down
19 changes: 19 additions & 0 deletions client/config/config.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package config

import (
"context"
"errors"
"fmt"
"net"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -301,10 +303,27 @@ type Config struct {
// used for template functions which require access to the Nomad API.
TemplateDialer *bufconndialer.BufConnWrapper

// APIListenerRegistrar allows the client to register listeners created at
// runtime (eg the Task API) with the agent's HTTP server. Since the agent
// creates the HTTP server *after* the client starts, we have to use this shim to
// pass listeners back to the agent.
// This is the same design as the bufconndialer but for the
// http.Serve(listener) API instead of the net.Dial API.
APIListenerRegistrar APIListenerRegistrar

// Artifact configuration from the agent's config file.
Artifact *ArtifactConfig
}

// APIListenerRegistrar is implemented by anything that can serve the HTTP API
// on a listener supplied at runtime.
type APIListenerRegistrar interface {
// Serve the HTTP API on the provided listener.
//
// The context is accepted because Serve may be called before the HTTP
// server has been initialized. If the context is canceled before the HTTP
// server is initialized, the context's error will be returned.
Serve(context.Context, net.Listener) error
}

// ClientTemplateConfig is configuration on the client specific to template
// rendering
type ClientTemplateConfig struct {
Expand Down
11 changes: 11 additions & 0 deletions client/config/testing.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package config

import (
"context"
"io/ioutil"
"net"
"os"
"path/filepath"
"time"
Expand Down Expand Up @@ -74,5 +76,14 @@ func TestClientConfig(t testing.T) (*Config, func()) {
// Same as default; necessary for task Event messages
conf.MaxKillTimeout = 30 * time.Second

// Provide a stub APIListenerRegistrar implementation
conf.APIListenerRegistrar = NoopAPIListenerRegistrar{}

return conf, cleanup
}

// NoopAPIListenerRegistrar is a stub registrar whose Serve method does
// nothing.
type NoopAPIListenerRegistrar struct{}

// Serve ignores both the context and the listener and returns nil
// immediately.
func (NoopAPIListenerRegistrar) Serve(context.Context, net.Listener) error {
	return nil
}
Loading