Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write all tables to a single FITS/HDF5 file #425

Merged
merged 14 commits into from
Feb 24, 2021
Merged
1 change: 1 addition & 0 deletions docs/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Installing using pip or conda will automatically install or update these core
dependencies if necessary. SkyPy also has a number of optional dependencies
that enable additional features:

- `h5py <https://www.h5py.org/>`_
- `speclite <https://speclite.readthedocs.io/>`_

To install SkyPy with all optional dependencies using pip:
Expand Down
2 changes: 1 addition & 1 deletion docs/pipeline/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fits files:

.. code-block:: bash

$ skypy examples/galaxies/sdss_photometry.yml --format fits
$ skypy examples/galaxies/sdss_photometry.yml sdss_photometry.fits

Config files are written in YAML format and read using the
`~skypy.pipeline.load_skypy_yaml` function. Each entry in the config specifies
Expand Down
2 changes: 1 addition & 1 deletion examples/galaxies/plot_photometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#
# .. code-block:: bash
#
# $ skypy examples/galaxies/sdss_photometry.yml --format fits
# $ skypy examples/galaxies/sdss_photometry.yml sdss_photometry.fits
#
# or in a python script using the :class:`Pipeline <skypy.pipeline.Pipeline>`
# class as demonstrated in the `SDSS Photometry`_ section below. For more
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ test =
pytest-rerunfailures
speclite>=0.11
all =
h5py
speclite>=0.11
docs =
sphinx-astropy
Expand Down
63 changes: 51 additions & 12 deletions skypy/pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ._config import load_skypy_yaml
from ._items import Item, Call, Ref
import networkx
import pathlib


__all__ = [
Expand Down Expand Up @@ -186,27 +187,65 @@ def execute(self, parameters={}):
# Single column assignment
self.state[table][column] = self.evaluate(settings)

def write(self, file_format=None, overwrite=False):
def write(self, filename, overwrite=False):
r'''Write pipeline results to disk.

Parameters
----------
file_format : str
File format used to write tables. Files are written using the
Astropy unified file read/write interface; see [1]_ for supported
file formats. If None (default) tables are not written to file.
filename : str
Name of output file to be written. It must have one of the
supported file extensions for FITS (.fit .fits .fts) or HDF5
(.hdf5 .hd5 .he5 .h5).
overwrite : bool
Whether to overwrite any existing files without warning.
If filename already exists, this flag indicates whether or not to
overwrite it (without warning).
'''

References
suffix = pathlib.Path(filename).suffix.lower()
_fits_suffixes = ('.fit', '.fits', '.fts')
_hdf5_suffixes = ('.hdf5', '.hd5', '.he5', '.h5')

if suffix in _fits_suffixes:
self.write_fits(filename, overwrite)
elif suffix in _hdf5_suffixes:
self.write_hdf5(filename, overwrite)
else:
raise ValueError(f'{suffix} is an unsupported file format. SkyPy supports '
'FITS (' + ' '.join(_fits_suffixes) + ') and '
'HDF5 (' + ' '.join(_hdf5_suffixes) + ').')

def write_fits(self, filename, overwrite=False):
r'''Write pipeline results to a FITS file.

Parameters
----------
.. [1] https://docs.astropy.org/en/stable/io/unified.html
filename : str
Name of output file to be written.
overwrite : bool
If filename already exists, this flag indicates whether or not to
overwrite it (without warning).
'''
from astropy.io.fits import HDUList, PrimaryHDU, table_to_hdu
hdul = [PrimaryHDU()]
for t in self.table_config:
hdu = table_to_hdu(self[t])
hdu.header['EXTNAME'] = t
hdul.append(hdu)
HDUList(hdul).writeto(filename, overwrite=overwrite)

def write_hdf5(self, filename, overwrite=False):
r'''Write pipeline results to a HDF5 file.

Parameters
----------
filename : str
Name of output file to be written.
overwrite : bool
If filename already exists, this flag indicates whether or not to
overwrite it (without warning).
'''
if file_format:
for table in self.table_config.keys():
filename = '.'.join((table, file_format))
self.state[table].write(filename, overwrite=overwrite)
for t in self.table_config:
self[t].write(filename, path=f'tables/{t}', append=True, overwrite=overwrite)

def evaluate(self, value):
'''evaluate an item in the pipeline'''
Expand Down
15 changes: 10 additions & 5 deletions skypy/pipeline/scripts/skypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def main(args=None):
parser = argparse.ArgumentParser(description="SkyPy pipeline driver")
parser.add_argument('--version', action='version', version=skypy_version)
parser.add_argument('config', help='Config file name')
parser.add_argument('-f', '--format', required=False,
choices=['fits', 'hdf5'], help='Table file format')
parser.add_argument('output', help='Output file name')
parser.add_argument('-o', '--overwrite', action='store_true',
help='Whether to overwrite existing files')

Expand All @@ -23,7 +22,13 @@ def main(args=None):
args = parser.parse_args(args or ['--help'])
config = load_skypy_yaml(args.config)

pipeline = Pipeline(config)
pipeline.execute()
pipeline.write(file_format=args.format, overwrite=args.overwrite)
try:
pipeline = Pipeline(config)
pipeline.execute()
if args.output:
pipeline.write(args.output, overwrite=args.overwrite)
except Exception as e:
print(e)
raise SystemExit(2) from e

return(0)
50 changes: 43 additions & 7 deletions skypy/pipeline/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from astropy.cosmology import FlatLambdaCDM, default_cosmology
from astropy.cosmology.core import Cosmology
from astropy.io import fits
from astropy.io.misc.hdf5 import read_table_hdf5
from astropy.table import Table
from astropy.table.column import Column
from astropy.units import Quantity
Expand All @@ -12,6 +13,13 @@
from skypy.pipeline import Pipeline
from skypy.pipeline._items import Call, Ref

# h5py is an optional dependency: record whether it can be imported so that
# tests needing it can be skipped when it is absent.
HAS_H5PY = True
try:
    import h5py  # noqa
except ImportError:
    HAS_H5PY = False


def test_pipeline():

Expand All @@ -36,17 +44,22 @@ def test_pipeline():

pipeline = Pipeline(config)
pipeline.execute()
pipeline.write(file_format='fits')
output_filename = 'output.fits'
pipeline.write(output_filename)
assert len(pipeline['test_table']) == size
assert np.all(pipeline['test_table.column1'] < pipeline['test_table.column2'])
with fits.open('test_table.fits') as hdu:
assert np.all(Table(hdu[1].data) == pipeline['test_table'])
with fits.open(output_filename) as hdu:
assert np.all(Table(hdu['test_table'].data) == pipeline['test_table'])

# Test invalid file extension
with pytest.raises(ValueError):
pipeline.write('output.invalid')

# Check for failure if output files already exist and overwrite is False
pipeline = Pipeline(config)
pipeline.execute()
with pytest.raises(OSError):
pipeline.write(file_format='fits', overwrite=False)
pipeline.write(output_filename, overwrite=False)

# Check that the existing output files are modified if overwrite is True
new_size = 2 * size
Expand All @@ -55,8 +68,8 @@ def test_pipeline():
config['tables']['test_table']['column3'].args = [new_string]
pipeline = Pipeline(config)
pipeline.execute()
pipeline.write(file_format='fits', overwrite=True)
with fits.open('test_table.fits') as hdu:
pipeline.write(output_filename, overwrite=True)
with fits.open(output_filename) as hdu:
assert len(hdu[1].data) == new_size

# Check for failure if 'column1' requires itself creating a cyclic
Expand Down Expand Up @@ -237,7 +250,30 @@ def value_in_cm(q):
np.testing.assert_array_less(pipeline['test_table.lengths_in_cm'], 100)


@pytest.mark.skipif(not HAS_H5PY, reason='Requires h5py')
def test_hdf5():
    """Write a one-table pipeline to HDF5 and compare the table read back."""
    n_rows = 100
    letters = n_rows * 'a'

    # column2 draws its lower bound from column1 via a table reference.
    columns = {
        'column1': Call(np.random.uniform, [], {'size': n_rows}),
        'column2': Call(np.random.uniform, [], {'low': Ref('test_table.column1')}),
        'column3': Call(list, [letters], {}),
    }

    pipeline = Pipeline({'tables': {'test_table': columns}})
    pipeline.execute()
    pipeline.write('output.hdf5')

    # The table is read back from the 'tables/<table name>' path in the file.
    stored = read_table_hdf5('output.hdf5', 'tables/test_table', character_as_bytes=False)
    assert np.all(stored == pipeline['test_table'])


def teardown_module(module):

# Remove fits file generated in test_pipeline
os.remove('test_table.fits')
os.remove('output.fits')

# Remove hdf5 file generated in test_hdf5
if HAS_H5PY:
os.remove('output.hdf5')
26 changes: 17 additions & 9 deletions skypy/pipeline/tests/test_skypy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from astropy.utils.data import get_pkg_data_filename
from contextlib import redirect_stdout
from io import StringIO
import os
import pytest
from skypy import __version__ as skypy_version
from skypy.pipeline.scripts import skypy
Expand All @@ -26,20 +27,27 @@ def test_skypy():
assert version.getvalue().strip() == skypy_version
assert e.value.code == 0

# Missing positional argument 'config'
# Missing positional argument 'output'
with pytest.raises(SystemExit) as e:
skypy.main(['--format', 'fits'])
assert e.value.code == 2

# Invalid file format
with pytest.raises(SystemExit) as e:
skypy.main(['--format', 'invalid', 'config.filename'])
skypy.main(['config.filename'])
assert e.value.code == 2

# Process empty config file
filename = get_pkg_data_filename('data/empty_config.yml')
assert skypy.main([filename]) == 0
assert skypy.main([filename, 'empty.fits']) == 0

# Process test config file
filename = get_pkg_data_filename('data/test_config.yml')
assert skypy.main([filename]) == 0
assert skypy.main([filename, 'test.fits']) == 0

# Invalid file format
with pytest.raises(SystemExit) as e:
skypy.main([filename, 'test.invalid'])
assert e.value.code == 2


def teardown_module(module):
    """Remove the FITS files generated in test_skypy.

    Parameters
    ----------
    module : module
        The test module being torn down; not used here.
    """
    for filename in ('empty.fits', 'test.fits'):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # A test that failed before writing its output leaves nothing to
            # clean up; do not add a teardown error on top of that failure.
            pass