Skip to content

Commit

Permalink
Enforce mTLS with inception server (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
pjbgf authored Nov 28, 2024
2 parents 259e3f9 + 817bbf0 commit 178d0c0
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 8 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ require (
golang.org/x/net v0.31.0 // indirect
golang.org/x/text v0.20.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc/security/advancedtls v1.0.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0=
google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA=
google.golang.org/grpc/security/advancedtls v1.0.0 h1:/KQ7VP/1bs53/aopk9QhuPyFAp9Dm9Ejix3lzYkCrDA=
google.golang.org/grpc/security/advancedtls v1.0.0/go.mod h1:o+s4go+e1PJ2AjuQMY5hU82W7lDlefjJA6FqEHRVHWk=
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
49 changes: 46 additions & 3 deletions internal/inception/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ package inception

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"os"
"strings"
"time"

"github.com/qubesome/cli/internal/util/mtls"
pb "github.com/qubesome/cli/pkg/inception/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/credentials"
)

func NewClient(socket string) *Client {
Expand All @@ -22,8 +26,42 @@ type Client struct {
socket string
}

func getCreds() (credentials.TransportCredentials, error) {
caPEM := []byte(os.Getenv("Q_MTLS_CA"))
certPEM := []byte(os.Getenv("Q_MTLS_CERT"))
keyPEM := []byte(os.Getenv("Q_MTLS_KEY"))

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, err
}

certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(caPEM) {
return nil, err
}

creds := credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: certPool,
MinVersion: tls.VersionTLS13,
// The connection is made via unix socket, so generally the
// expected server name will be localhost - unless overridden
// by ServerName.
ServerName: mtls.HostServerName,
})

return creds, nil
}

func (c *Client) XdgOpen(ctx context.Context, url string) error {
conn, err := grpc.NewClient(c.socket, grpc.WithTransportCredentials(insecure.NewCredentials()))
creds, err := getCreds()
if err != nil {
return err
}

conn, err := grpc.NewClient(c.socket,
grpc.WithTransportCredentials(creds))
if err != nil {
return fmt.Errorf("failed to connect to qubesome host: %w", err)
}
Expand All @@ -44,7 +82,12 @@ func (c *Client) XdgOpen(ctx context.Context, url string) error {
}

func (c *Client) Run(ctx context.Context, workload string, args []string) error {
conn, err := grpc.NewClient(c.socket, grpc.WithTransportCredentials(insecure.NewCredentials()))
creds, err := getCreds()
if err != nil {
return err
}

conn, err := grpc.NewClient(c.socket, grpc.WithTransportCredentials(creds))
if err != nil {
return fmt.Errorf("failed to connect to qubesome host: %w", err)
}
Expand Down
19 changes: 16 additions & 3 deletions internal/profiles/profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/qubesome/cli/internal/types"
"github.com/qubesome/cli/internal/util/dbus"
"github.com/qubesome/cli/internal/util/gpu"
"github.com/qubesome/cli/internal/util/mtls"
"github.com/qubesome/cli/internal/util/resolution"
"github.com/qubesome/cli/internal/util/xauth"
"github.com/qubesome/cli/pkg/inception"
Expand Down Expand Up @@ -250,11 +251,15 @@ func Start(runner string, profile *types.Profile, cfg *types.Config) (err error)
return err
}

creds, err := mtls.NewCredentials()
if err != nil {
return err
}
go func() {
defer wg.Done()

server := inception.NewServer(profile, cfg)
err1 := server.Listen(sockPath)
err1 := server.Listen(creds.ServerCert, creds.CA, sockPath)
if err1 != nil {
slog.Debug("error listening to socket", "error", err1)
if err == nil {
Expand All @@ -272,7 +277,9 @@ func Start(runner string, profile *types.Profile, cfg *types.Config) (err error)
return err
}

err = createNewDisplay(binary, profile, strconv.Itoa(int(profile.Display)))
err = createNewDisplay(binary,
creds.CA, creds.ClientPEM, creds.ClientKeyPEM,
profile, strconv.Itoa(int(profile.Display)))
if err != nil {
return err
}
Expand Down Expand Up @@ -356,7 +363,7 @@ func startWindowManager(bin, name, display, wm string) error {
return nil
}

func createNewDisplay(bin string, profile *types.Profile, display string) error {
func createNewDisplay(bin string, ca, cert, key []byte, profile *types.Profile, display string) error {
command := "Xephyr"
res, err := resolution.Primary()
if err != nil {
Expand Down Expand Up @@ -463,6 +470,9 @@ func createNewDisplay(bin string, profile *types.Profile, display string) error
// rely on currently set DISPLAY.
"-e", "DISPLAY",
"-e", "XDG_SESSION_TYPE=X11",
"-e", "Q_MTLS_CA",
"-e", "Q_MTLS_CERT",
"-e", "Q_MTLS_KEY",
"--device", "/dev/dri",
"--security-opt=no-new-privileges:true",
"--cap-drop=ALL",
Expand Down Expand Up @@ -546,6 +556,9 @@ func createNewDisplay(bin string, profile *types.Profile, display string) error

slog.Debug("exec: "+bin, "args", dockerArgs)
cmd := execabs.Command(bin, dockerArgs...)
cmd.Env = append(os.Environ(), "Q_MTLS_CA="+string(ca))
cmd.Env = append(cmd.Env, "Q_MTLS_CERT="+string(cert))
cmd.Env = append(cmd.Env, "Q_MTLS_KEY="+string(key))

output, err := cmd.CombinedOutput()
if err != nil {
Expand Down
159 changes: 159 additions & 0 deletions internal/util/mtls/mtls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package mtls

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"time"
)

const (
validFor = 7 * 24 * time.Hour // 7 days
// ProfileServerName sets the server name for the qubesome profile.
ProfileServerName = "qubesome-profile"
// HostServerName sets the server name for the qubesome host.
HostServerName = "qubesome-host"
)

type Credentials struct {
ServerCert tls.Certificate
CA []byte
ClientPEM []byte
ClientKeyPEM []byte
}

func NewCredentials() (*Credentials, error) {
caCert, caKey, caBytes, err := generateCA()
if err != nil {
return nil, err
}

serverCertBytes, serverKey, err := generateCert(caCert, caKey, true)
if err != nil {
return nil, err
}
serverCertPEM, serverKeyPEM, err := pemEncode(serverCertBytes, serverKey)
if err != nil {
return nil, err
}

serverCert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM)
if err != nil {
return nil, err
}

clientCertBytes, clientKey, err := generateCert(caCert, caKey, false)
if err != nil {
return nil, err
}
clientCertPEM, clientKeyPEM, err := pemEncode(clientCertBytes, clientKey)
if err != nil {
return nil, err
}

ca := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caBytes})
return &Credentials{
ServerCert: serverCert,
CA: ca,
ClientPEM: clientCertPEM,
ClientKeyPEM: clientKeyPEM,
}, nil
}

// generateCA generates an in-memory CA certificate and private key.
func generateCA() (*x509.Certificate, *ecdsa.PrivateKey, []byte, error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to generate CA private key: %w", err)
}

template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
CommonName: "qubesome inception CA",
Organization: []string{"qubesome"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(validFor),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
IsCA: true,
BasicConstraintsValid: true,
SignatureAlgorithm: x509.ECDSAWithSHA256,
}

certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create CA certificate: %w", err)
}

certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to PEM encode certificate: %w", err)
}

cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to parse CA certificate: %w", err)
}

return cert, priv, certBytes, nil
}

// generateCert generates a certificate signed by caCert.
func generateCert(caCert *x509.Certificate, caKey *ecdsa.PrivateKey, isServer bool) ([]byte, *ecdsa.PrivateKey, error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}

template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
Organization: []string{"qubesome"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(validFor),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}

if isServer {
template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
template.DNSNames = []string{HostServerName}
} else {
template.DNSNames = []string{ProfileServerName}
}

certBytes, err := x509.CreateCertificate(rand.Reader, template, caCert, &priv.PublicKey, caKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}

return certBytes, priv, nil
}

// pemEncode encodes the certificate and private key to PEM format.
func pemEncode(certBytes []byte, priv *ecdsa.PrivateKey) ([]byte, []byte, error) {
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes})

privBytes, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})

return certPEM, keyPEM, nil
}
21 changes: 19 additions & 2 deletions pkg/inception/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package inception

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
Expand All @@ -10,8 +12,10 @@ import (
"github.com/qubesome/cli/internal/command"
"github.com/qubesome/cli/internal/qubesome"
"github.com/qubesome/cli/internal/types"
"github.com/qubesome/cli/internal/util/mtls"
pb "github.com/qubesome/cli/pkg/inception/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// NewServer returns a new inception server.
Expand All @@ -32,13 +36,26 @@ type Server struct {
server *grpcServer
}

func (s *Server) Listen(socket string) error {
func (s *Server) Listen(serverCert tls.Certificate, ca []byte, socket string) error {
lis, err := net.Listen("unix", socket)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

gs := grpc.NewServer()
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(ca) {
return fmt.Errorf("failed to append CA from PEM")
}

creds := credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: certPool,
MinVersion: tls.VersionTLS13,
ServerName: mtls.ProfileServerName,
})

gs := grpc.NewServer(grpc.Creds(creds))
pb.RegisterQubesomeHostServer(gs, s.server)

slog.Debug("[server] listening", "addr", lis.Addr())
Expand Down

0 comments on commit 178d0c0

Please sign in to comment.