diff --git a/powersimdata/data_access/csv_store.py b/powersimdata/data_access/csv_store.py index c6d3c6da1..93eb1b7ac 100644 --- a/powersimdata/data_access/csv_store.py +++ b/powersimdata/data_access/csv_store.py @@ -61,7 +61,8 @@ def get_table(self): def _get_table(self, filename): self.data_access.copy_from(filename) - return self.data_access.read(filename, callback=lambda f, _: _parse_csv(f)) + with self.data_access.get(filename) as (f, _): + return _parse_csv(f) def commit(self, table, checksum): """Save to local directory and upload if needed diff --git a/powersimdata/data_access/data_access.py b/powersimdata/data_access/data_access.py index 2d467cfb8..aead8d2bf 100644 --- a/powersimdata/data_access/data_access.py +++ b/powersimdata/data_access/data_access.py @@ -1,10 +1,9 @@ -import os import pickle import posixpath +from contextlib import contextmanager from subprocess import Popen import fs -import pandas as pd from fs.multifs import MultiFS from fs.path import basename, dirname from fs.tempfs import TempFS @@ -32,6 +31,10 @@ def get_ssh_fs(root=""): def get_multi_fs(root): + """Create filesystem combining the server (if connected) with profile and scenario + containers in blob storage. The priority is in descending order, so the server will + be used first if possible + """ scenario_data = get_blob_fs("scenariodata") profiles = get_blob_fs("profiles") mfs = MultiFS() @@ -54,74 +57,64 @@ def __init__(self, root): self.join = fs.path.join self.local_fs = None - def read(self, filepath, callback=None): - """Reads data from data store. + @contextmanager + def get(self, filepath): + """Copy file from remote filesystem if needed and read into memory - :param str filepath: path to file, with extension either 'pkl', 'csv', or 'mat'. - :return: (*pandas.DataFrame* or *dict*) -- pkl and csv files will be returned as - a data frame, while a mat file will be returned as a dictionary - :raises ValueError: if extension is unknown. + :param str filepath: path to file + :return: (*tuple*) -- file object and filepath to be handled by caller """ if not self.local_fs.exists(filepath): print(f"{filepath} not found on local machine") from_dir, filename = dirname(filepath), basename(filepath) self.copy_from(filename, from_dir) - if callback is None: - callback = self._read with self.local_fs.openbin(filepath) as f: - return callback(f, filepath) - - def _read(self, f, filepath): - ext = os.path.basename(filepath).split(".")[-1] - if ext == "pkl": - data = pd.read_pickle(f) - elif ext == "csv": - data = pd.read_csv(f, index_col=0, parse_dates=True) - data.columns = data.columns.astype(int) - elif ext == "mat": - # get fully qualified local path to matfile - data = self.local_fs.getsyspath(filepath) - else: - raise ValueError("Unknown extension! %s" % ext) - - return data + filepath = self.local_fs.getsyspath(filepath) + yield f, filepath - def write(self, filepath, data, save_local=True): + def write(self, filepath, data, save_local=True, callback=None): """Write a file to data store. - :param str filepath: path to save data to, with extension either 'pkl', 'csv', or 'mat'. - :param (*pandas.DataFrame* or *dict*) data: data to save + :param str filepath: path to save data to + :param object data: data to save :param bool save_local: whether a copy should also be saved to the local filesystem, if such a filesystem is configured. Defaults to True. + :param callable callback: the specific persistence implementation """ - self._check_file_exists(filepath, should_exist=False) - print("Writing %s" % filepath) - self._write(self.fs, filepath, data) + self._check_file_exists(filepath, should_exist=False, mode="w") + if callback is None: + callback = self._callback - if save_local and self.local_fs is not None: - self._write(self.local_fs, filepath, data) + print("Writing %s" % filepath) + self._write(self.fs, filepath, data, callback) + if save_local: + self._write(self.local_fs, filepath, data, callback) - def _write(self, fs, filepath, data): + def _write(self, fs, filepath, data, callback=None): """Write a file to given data store. - :param fs fs: pyfilesystem to which to write data - :param str filepath: path to save data to, with extension either 'pkl', 'csv', or 'mat'. - :param (*pandas.DataFrame* or *dict*) data: data to save + :param fs.base.FS fs: pyfilesystem to which to write data + :param str filepath: path to save data to + :param object data: data to save + :param callable callback: the specific persistence implementation :raises ValueError: if extension is unknown. """ - ext = os.path.basename(filepath).split(".")[-1] fs.makedirs(dirname(filepath), recreate=True) with fs.openbin(filepath, "w") as f: - if ext == "pkl": - pickle.dump(data, f) - elif ext == "csv": - data.to_csv(f) - elif ext == "mat": - savemat(f, data, appendmat=False) - else: - raise ValueError("Unknown extension! %s" % ext) + callback(f, filepath, data) + + def _callback(self, f, filepath, data): + ext = basename(filepath).split(".")[-1] + if ext == "pkl": + pickle.dump(data, f) + elif ext == "csv": + data.to_csv(f) + elif ext == "mat": + savemat(f, data, appendmat=False) + else: + raise ValueError("Unknown extension! %s" % ext) def copy_from(self, file_name, from_dir=None): """Copy a file from data store to userspace. @@ -174,14 +167,14 @@ def remove(self, pattern, confirm=True): self.fs.glob(pattern).remove() print("--> Done!") - def _check_file_exists(self, path, should_exist=True): + def _check_file_exists(self, path, should_exist=True, mode="r"): """Check that file exists (or not) at the given path :param str path: the relative path to the file :param bool should_exist: whether the file is expected to exist :raises OSError: if the expected condition is not met """ - location, _ = self.fs.which(path) + location, _ = self.fs.which(path, mode) exists = location is not None if should_exist and not exists: remotes = [f[0] for f in self.fs.iterate_fs()] @@ -252,8 +245,7 @@ def __init__(self, root=server_setup.DATA_ROOT_DIR): """Constructor""" super().__init__(root) self._fs = None - self.local_root = server_setup.LOCAL_DIR - self.local_fs = fs.open_fs(self.local_root) + self.local_fs = fs.open_fs(server_setup.LOCAL_DIR) @property def fs(self): @@ -286,7 +278,8 @@ def checksum(self, relative_path): """ self._check_file_exists(relative_path) full_path = self.join(self.root, relative_path) - return self.fs.checksum(full_path) + ssh_fs = self.fs.get_fs("ssh_fs") + return ssh_fs.checksum(full_path) def push(self, file_name, checksum, rename): """Push file to server and verify the checksum matches a prior value @@ -298,7 +291,7 @@ def push(self, file_name, checksum, rename): """ backup = f"{rename}.temp" - self._check_file_exists(backup, should_exist=False) + self._check_file_exists(backup, should_exist=False, mode="w") print(f"Transferring {rename} to server") fs.move.move_file(self.local_fs, file_name, self.fs, backup) @@ -332,7 +325,7 @@ class MemoryDataAccess(SSHDataAccess): def __init__(self): self.local_fs = fs.open_fs("mem://") self._fs = self._get_fs() - self.local_root = self.root = "dummy" + self.root = "foo" self.join = fs.path.join def _get_fs(self): diff --git a/powersimdata/input/input_data.py b/powersimdata/input/input_data.py index 591445e1b..b9b326626 100644 --- a/powersimdata/input/input_data.py +++ b/powersimdata/input/input_data.py @@ -1,3 +1,5 @@ +import os + import pandas as pd from powersimdata.data_access.context import Context @@ -43,6 +45,29 @@ def _check_field(field_name): raise ValueError("Only %s data can be loaded" % " | ".join(possible)) +def _read(f, filepath): + """Read data from file object + + :param io.IOBase f: a file handle + :param str filepath: the filepath corresponding to f + :raises ValueError: if extension is unknown. + :return: object -- the result + """ + ext = os.path.basename(filepath).split(".")[-1] + if ext == "pkl": + data = pd.read_pickle(f) + elif ext == "csv": + data = pd.read_csv(f, index_col=0, parse_dates=True) + data.columns = data.columns.astype(int) + elif ext == "mat": + # get fully qualified local path to matfile + data = os.path.abspath(filepath) + else: + raise ValueError("Unknown extension! %s" % ext) + + return data + + class InputData: """Load input data. @@ -79,7 +104,8 @@ def get_data(self, scenario_info, field_name): cached = _cache.get(key) if cached is not None: return cached - data = self.data_access.read(filepath) + with self.data_access.get(filepath) as (f, path): + data = _read(f, path) _cache.put(key, data) return data diff --git a/powersimdata/output/output_data.py b/powersimdata/output/output_data.py index 811cbd5ce..7f1802c23 100644 --- a/powersimdata/output/output_data.py +++ b/powersimdata/output/output_data.py @@ -32,7 +32,8 @@ def get_data(self, scenario_id, field_name): print("--> Loading %s" % field_name) file_name = scenario_id + "_" + field_name + ".pkl" filepath = "/".join([*server_setup.OUTPUT_DIR, file_name]) - return self._data_access.read(filepath) + with self._data_access.get(filepath) as (f, _): + return pd.read_pickle(f) def _check_field(field_name):