Skip to content

Commit

Permalink
Add file getting, sending to cli.py
Browse files Browse the repository at this point in the history
Convert device connection into a separate function

Upgrade file sending to be in line with file getting

regulate spacing between functions to two lines (there's been a mix of 1 and 2 lines so far)

Remove leftover comment

Add docstrings, typehinting
  • Loading branch information
SteveMicroNova committed Nov 1, 2024
1 parent ad189ce commit 36774e0
Showing 1 changed file with 127 additions and 34 deletions.
161 changes: 127 additions & 34 deletions admin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import logging

from os import getenv
from os import getenv, path, makedirs, walk
from uuid import UUID
from time import sleep
from typing import Optional
Expand Down Expand Up @@ -69,6 +69,46 @@ def get_tunnel(tunnel_id) -> Optional[dict]:
return None


def connect_tunnel(c, tunnel_id):
"""Connects to a live support tunnel using the tunnel_id, and returns a connection to the device"""
# Care should be exercised here; we're taking data from a remote source and using it to
# run shell commands. Validate every last bit of data.
t = get_tunnel(tunnel_id)
assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected"
dip = device_ip(IPv4Network(t['network'])).ip
assert t['support_user'].isalnum()
assert t['support_user'].isascii()
support_user = t['support_user']

# set up local
c.run("gcloud compute config-ssh", hide="both")
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username']

# set up connection to bastion
ts = Connection(
host = str(get_ts_instance_public_ip(tunnel_id)),
user = ts_user,
connect_kwargs={"auth_timeout": 120}
)

# grab the ssh private key on the tunnel server
ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both")
assert ssh_privkey

# set up connection to destination device
device = Connection(
host=str(dip),
user=support_user,
gateway=ts,
connect_kwargs = {
"pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)),
}
)

return device


@task
def show(c, tunnel_id):
""" Show a single tunnel's details """
Expand All @@ -84,6 +124,7 @@ def list(c):
res.raise_for_status()
print(res.text)


@task
def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[WireguardKey] = None):
""" Create a tunnel server.
Expand Down Expand Up @@ -143,7 +184,7 @@ def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[Wiregua
connect_kwargs={"auth_timeout": 120} # long for 2FA
)

# things get hacky when being concerned with local ssh keys and all -
# things get hacky when being concerned with local ssh keys and all -
# the below configures things to "just work", every time.
c.run("gcloud compute config-ssh", hide="both")
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
Expand Down Expand Up @@ -226,46 +267,98 @@ def stop(c, tunnel_id):
# being lazy and overzealous at the same time - we'll just garbage-college its resources.
gc(c)


@task
def connect(c, tunnel_id, command="/bin/bash", pty=True):
""" Connect to a remote device, identified by a tunnel. """
# Care should be exercised here; we're taking data from a remote source and using it to
# run shell commands. Validate every last bit of data.
t = get_tunnel(tunnel_id)
assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected"
dip = device_ip(IPv4Network(t['network'])).ip
assert t['support_user'].isalnum()
assert t['support_user'].isascii()
support_user = t['support_user']
device = connect_tunnel(c, tunnel_id)
# and finally execute a shell
device.sudo(command, pty=pty)

# set up local
c.run("gcloud compute config-ssh", hide="both")
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username']

# set up connection to bastion
ts = Connection(
host = str(get_ts_instance_public_ip(tunnel_id)),
user = ts_user,
connect_kwargs={"auth_timeout": 120} # long for 2FA
)
def is_remote_directory(remote_device: Connection, remote_path: str):
"""Determines if a remote path is a directory using a bash command over a tunnel connection, returns true or false"""
result = remote_device.run(f"if [ -d '{remote_path}' ]; then echo 'directory'; else echo 'file'; fi", hide=True).stdout.strip()
return result == 'directory'

# grab the ssh private key on the tunnel server
ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both")
assert ssh_privkey

# set up connection to destination device
device = Connection(
host=str(dip),
user=support_user,
gateway=ts,
connect_kwargs = {
"pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)),
}
)
def send_file(remote_device: Connection, local_file: str, remote_dir: str):
"""Copies local file to remote directory via a tunnel connection"""
filename = path.basename(local_file)
remote_path = path.join(remote_dir, filename).replace("./", "") # path.join adds an unneccessary ./ when combining things with their own directory, making the prints down the line look odd

# Support user lacks permissions to send file to just any directory, scrape the filename and send to an intermediary and then sudo mv it to the proper location
print(f"Uploading {local_file} to {remote_path}...")
temp_remote = f"/tmp/{filename}"
remote_device.put(local=local_file, remote=temp_remote, preserve_mode=True)
remote_device.sudo(f"mv {temp_remote} {remote_path}", pty=False)
print(f"File successfully uploaded to {remote_path}")


def send_directory(remote_device: Connection, local_dir: str, remote_dir: str):
"""Recursively copies everything in a local directory to a remote directory via a tunnel connection"""
for root, _, files in walk(local_dir):
relative_root = path.relpath(root, local_dir)
remote_root = path.join(remote_dir, relative_root)

# Create the corresponding remote directory
remote_device.sudo(f"mkdir -p {remote_root}")

for file in files:
local_file_path = path.join(root, file)
send_file(remote_device, local_file_path, remote_root)
print("") # Empty print to break up "sending file", "file sent" prints so full directory transfers more human readable


@task
def sfile(c, tunnel_id, local, remote):
""" Connect to a remote device and upload a local file or directory to the remote directory.
Requests are formatted as tunnel_id, local file/directory, and remote directory for installation. """
device = connect_tunnel(c, tunnel_id)

# Send either a single file or an entire directory
if path.isdir(local):
send_directory(device, local, remote)
else:
send_file(device, local, remote)


def get_file(remote_device: Connection, local_dir: str, remote_file: str):
"""Copies a remote file to a local directory using a tunnel connection"""
filename = path.basename(remote_file)
local_path = path.join(local_dir, filename)

print(f"Downloading {remote_file} to {local_path}...")
remote_device.get(remote=remote_file, local=local_path, preserve_mode=True)
print(f"File successfully downloaded to {local_path}.")


def get_directory(remote_device, remote_dir, local_dir):
"""Recursively copies everything in a remote directory to the local directory using a tunnel connection"""
# List all files and directories under the remote directory
result = remote_device.run(f"find {remote_dir} -type d -or -type f", hide=True).stdout.splitlines()

for item in result:
relative_path = path.relpath(item, remote_dir)
local_item_path = path.join(local_dir, relative_path).replace("\\", "/")

if is_remote_directory(remote_device, item):
# Create the corresponding local directory
makedirs(local_item_path, exist_ok=True)
else:
get_file(remote_device, item, path.dirname(local_item_path))


@task
def gfile(c, tunnel_id, remote, local):
""" Connect to a remote device and download a remote file or directory to the local directory.
Requests are formatted as tunnel_id, remote file/directory, and local directory for installation. """
device = connect_tunnel(c, tunnel_id)
if is_remote_directory(device, remote):
get_directory(device, remote, local)
else:
get_file(device, remote, local)

# and finally execute a shell
device.sudo(command, pty=pty)

@task
def command(c, tunnel_id, command):
Expand Down

0 comments on commit 36774e0

Please sign in to comment.