diff --git a/powersimdata/scenario/scenario.py b/powersimdata/scenario/scenario.py index 277a47392..44a969312 100644 --- a/powersimdata/scenario/scenario.py +++ b/powersimdata/scenario/scenario.py @@ -58,9 +58,6 @@ def err_message(table, text): :param pandas.DataFrame table: scenario table. :param str text: message to print. """ - print("------------------") - print(text) - print("------------------") print( table.to_string( index=False, @@ -77,6 +74,9 @@ def err_message(table, text): ], ) ) + print("------------------") + print(text) + print("------------------") try: int(descriptor) diff --git a/powersimdata/utility/helpers.py b/powersimdata/utility/helpers.py index 935006b58..05175a6d9 100644 --- a/powersimdata/utility/helpers.py +++ b/powersimdata/utility/helpers.py @@ -4,6 +4,25 @@ import sys +class CommandBuilder: + @staticmethod + def copy(src, dest, recursive=False, update=False): + r_flag = "R" if recursive else "" + u_flag = "u" if update else "" + p_flag = "p" + flags = f"-{r_flag}{u_flag}{p_flag}" + return fr"\cp {flags} {src} {dest}" + + @staticmethod + def remove(target, recursive=False, force=False): + r_flag = "r" if recursive else "" + f_flag = "f" if force else "" + if recursive or force: + flags = f"-{r_flag}{f_flag}" + return f"rm {flags} {target}" + return f"rm {target}" + + class MemoryCache: """Wrapper around a dict object that exposes a cache interface. Users should create a separate instance for each distinct use case. diff --git a/powersimdata/utility/tests/test_helpers.py b/powersimdata/utility/tests/test_helpers.py index 29c0cff00..7b75438c1 100644 --- a/powersimdata/utility/tests/test_helpers.py +++ b/powersimdata/utility/tests/test_helpers.py @@ -1,6 +1,11 @@ import pytest -from powersimdata.utility.helpers import MemoryCache, PrintManager, cache_key +from powersimdata.utility.helpers import ( + CommandBuilder, + MemoryCache, + PrintManager, + cache_key, +) def test_print_is_disabled(capsys): @@ -61,3 +66,35 @@ def test_mem_cache_get_returns_copy(): obj = {"key1": 42} cache.put(key, obj) assert id(cache.get(key)) != id(obj) + + +def test_copy_command(): + expected = r"\cp -p source dest" + command = CommandBuilder.copy("source", "dest") + assert expected == command + + expected = r"\cp -Rp source dest" + command = CommandBuilder.copy("source", "dest", recursive=True) + assert expected == command + + expected = r"\cp -up source dest" + command = CommandBuilder.copy("source", "dest", update=True) + assert expected == command + + expected = r"\cp -Rup source dest" + command = CommandBuilder.copy("source", "dest", recursive=True, update=True) + assert expected == command + + +def test_remove_command(): + expected = "rm target" + command = CommandBuilder.remove("target") + assert expected == command + + expected = "rm -r target" + command = CommandBuilder.remove("target", recursive=True) + assert expected == command + + expected = "rm -rf target" + command = CommandBuilder.remove("target", recursive=True, force=True) + assert expected == command diff --git a/powersimdata/utility/transfer_data.py b/powersimdata/utility/transfer_data.py index 646fb4839..678801fb9 100644 --- a/powersimdata/utility/transfer_data.py +++ b/powersimdata/utility/transfer_data.py @@ -8,6 +8,7 @@ from tqdm import tqdm from powersimdata.utility import server_setup +from powersimdata.utility.helpers import CommandBuilder class DataAccess: @@ -80,7 +81,7 @@ class SSHDataAccess(DataAccess): _last_attempt = 0 - def __init__(self, root): + def __init__(self, root=None): """Constructor""" self._ssh = None self._retry_after = 5 @@ -141,12 +142,9 @@ def copy_from(self, file_name, from_dir=None): print(f"Transferring {file_name} from server") to_path = os.path.join(to_dir, file_name) - sftp = self.ssh.open_sftp() - try: + with self.ssh.open_sftp() as sftp: cbk, bar = progress_bar(ascii=True, unit="b", unit_scale=True) sftp.get(from_path, to_path, callback=cbk) - finally: - sftp.close() bar.close() def copy_to(self, file_name, to_dir, change_name_to=None): @@ -170,28 +168,18 @@ def copy_to(self, file_name, to_dir, change_name_to=None): raise IOError(f"{file_name} already exists in {to_dir} on server") print(f"Transferring {from_path} to server") - sftp = self.ssh.open_sftp() - try: + with self.ssh.open_sftp() as sftp: sftp.put(from_path, to_path) - finally: - sftp.close() print(f"--> Deleting {from_path} on local machine") os.remove(from_path) def copy(self, src, dest, recursive=False, update=False): - r_flag = "R" if recursive else "" - u_flag = "u" if update else "" - p_flag = "p" - flags = f"-{r_flag}{u_flag}{p_flag}" - command = f"\cp {flags} {src} {dest}" + command = CommandBuilder.copy(src, dest, recursive, update) return self.execute_command(command) def remove(self, target, recursive=False, force=False): - r_flag = "r" if recursive else "" - f_flag = "f" if force else "" - flags = f"-{r_flag}{f_flag}" - command = f"rm {flags} {target}" + command = CommandBuilder.remove(target, recursive, force) return self.execute_command(command) def execute_command(self, command):