diff --git a/.travis.yml b/.travis.yml
index ac277bd525..58abf99e9e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,12 +5,13 @@ python:
   - '2.7'
   - '3.6'
 before_install:
-  - pip install --upgrade pip setuptools
+  - pip install --upgrade pip setuptools wheel
 install:
-  - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,minuit,develop] # Ensure right version of NumPy installed
+  - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,minuit,develop]
+  - pip freeze
 script:
   - pyflakes pyhf
-  - pytest --ignore tests/benchmarks/
+  - pytest -r sx --ignore tests/benchmarks/ --ignore tests/test_notebooks.py
 after_success: coveralls

 # always test (on both 'push' and 'pr' builds in Travis)
@@ -31,21 +32,31 @@ env:
 jobs:
   include:
+    - name: "Python 2.7 Notebook Tests"
+      python: '2.7'
+      script:
+        - pytest tests/test_notebooks.py
+    - name: "Python 3.6 Notebook Tests"
+      python: '3.6'
+      script:
+        - pytest tests/test_notebooks.py
     - stage: benchmark
       python: '3.6'
      before_install:
-        - pip install --upgrade pip setuptools
+        - pip install --upgrade pip setuptools wheel
      install:
-        - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,develop]
-      script: pytest --benchmark-sort=mean tests/benchmarks/
+        - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,minuit,develop]
+        - pip freeze
+      script: pytest -r sx --benchmark-sort=mean tests/benchmarks/
     - stage: docs
       python: '3.6'
       before_install:
         - sudo apt-get update
         - sudo apt-get -qq install pandoc
-        - pip install --upgrade pip setuptools
+        - pip install --upgrade pip setuptools wheel
       install:
-        - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,develop]
+        - pip install --ignore-installed -U -q -e .[tensorflow,torch,mxnet,minuit,develop]
+        - pip freeze
       script:
         - python -m doctest README.md
         - cd docs && make html && cd -
diff --git a/docs/examples/notebooks/histosys-pytorch.ipynb b/docs/examples/notebooks/histosys-pytorch.ipynb
index 4578842d9e..09f75636ef 100644
--- a/docs/examples/notebooks/histosys-pytorch.ipynb
+++ b/docs/examples/notebooks/histosys-pytorch.ipynb
@@ -9,9 +9,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Populating the interactive namespace from numpy and matplotlib\n"
+     ]
+    }
+   ],
    "source": [
     "%pylab inline"
    ]
   },
@@ -24,36 +32,23 @@
    "source": [
     "import pyhf\n",
     "from pyhf import Model\n",
-    "from pyhf.simplemodels import hepdata_like"
+    "from pyhf.simplemodels import hepdata_like\n",
+    "\n",
+    "import tensorflow as tf"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      " [120.0, 180.0, 100.0, 225.0]\n",
+      "[120.0, 180.0, 100.0, 225.0]\n",
       "[[0, 10], [0, 10], [0, 10]]\n",
-      "['mu', 'uncorr_bkguncrt']\n",
-      "---\n",
-      "as numpy\n",
-      "-----\n",
-      "<type 'numpy.ndarray'> [-22.87785012]\n",
-      "---\n",
-      "as pytorch\n",
-      "-----\n",
-      "<class 'torch.autograd.variable.Variable'> Variable containing:\n",
-      "-22.8778\n",
-      "[torch.FloatTensor of size 1]\n",
-      "\n",
-      "---\n",
-      "as tensorflow\n",
-      "-----\n",
-      "<type 'list'> [-22.877851486206055]\n"
+      "['mu', 'uncorr_bkguncrt']\n"
      ]
     }
    ],
@@ -76,54 +71,74 @@
     "\n",
     "print(data)\n",
     "print(par_bounds)\n",
-    "print(pdf.config.par_order)\n",
-    "\n",
-    "\n",
-    "print '---\\nas numpy\\n-----'\n",
-    "pyhf.tensorlib = pyhf.numpy_backend(poisson_from_normal = True)\n",
-    "v = pdf.logpdf(init_pars,data)\n",
-    "print type(v),v\n",
-    "\n",
-    "print '---\\nas pytorch\\n-----'\n",
-    "pyhf.tensorlib = pyhf.pytorch_backend()\n",
-    "v = pdf.logpdf(init_pars,data)\n",
-    "print type(v),v\n",
-    "\n",
-    "\n",
-    "print '---\\nas tensorflow\\n-----'\n",
-    "import tensorflow as tf\n",
-    "pyhf.tensorlib = pyhf.tensorflow_backend()\n",
-    "v = pdf.logpdf(init_pars,data)\n",
-    "\n",
-    "pyhf.tensorlib.session = tf.Session()\n",
-    "print type(v),pyhf.tensorlib.tolist(v)"
+    "print(pdf.config.par_order)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "# NumPy\n",
+      "<class 'numpy.ndarray'> [-23.57960517]\n",
+      "\n",
+      "# TensorFlow\n",
+      "<class 'tensorflow.python.framework.ops.Tensor'> Tensor(\"mul:0\", shape=(1,), dtype=float32)\n",
+      "\n",
+      "# PyTorch\n",
+      "<class 'torch.Tensor'> tensor([-23.5796])\n",
+      "\n",
+      "# MXNet\n",
+      "<class 'mxnet.ndarray.ndarray.NDArray'> \n",
+      "[-23.57959]\n",
+      "<NDArray 1 @cpu(0)>\n"
+     ]
+    }
+   ],
+   "source": [
+    "backends = [\n",
+    "    pyhf.tensor.numpy_backend(),\n",
+    "    pyhf.tensor.tensorflow_backend(session=tf.Session()),\n",
+    "    pyhf.tensor.pytorch_backend(),\n",
+    "    pyhf.tensor.mxnet_backend()\n",
+    "]\n",
+    "names = [\n",
+    "    'NumPy',\n",
+    "    'TensorFlow',\n",
+    "    'PyTorch',\n",
+    "    'MXNet'\n",
+    "]\n",
+    "\n",
+    "for backend,name in zip(backends,names):\n",
+    "    print('\\n# {name}'.format(name=name))\n",
+    "    pyhf.set_backend(backend)\n",
+    "    v = pdf.logpdf(init_pars,data)\n",
+    "    print(type(v), v)"
+   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 2",
+  "display_name": "Python 3",
   "language": "python",
-  "name": "python2"
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
-   "version": 2
+   "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
-  "pygments_lexer": "ipython2",
-  "version": "2.7.14"
+  "pygments_lexer": "ipython3",
+  "version": "3.6.6"
  }
 },
 "nbformat": 4,
diff --git a/pyhf/__init__.py b/pyhf/__init__.py
index 6ea70b0ac9..a68dfd37d2 100644
--- a/pyhf/__init__.py
+++ b/pyhf/__init__.py
@@ -4,38 +4,39 @@
 optimizer = optimize.scipy_optimizer()
 default_optimizer = optimizer

+
 def get_backend():
     """
     Get the current backend and the associated optimizer

-    Returns:
-        backend, optimizer
-
     Example:
         >>> import pyhf
         >>> pyhf.get_backend()
         (<pyhf.tensor.numpy_backend.numpy_backend object at 0x...>, <pyhf.optimize.opt_scipy.scipy_optimizer object at 0x...>)

+    Returns:
+        backend, optimizer
     """
     global tensorlib
     global optimizer
     return tensorlib, optimizer

-def set_backend(backend, custom_optimizer = None):
+
+def set_backend(backend, custom_optimizer=None):
     """
     Set the backend and the associated optimizer

+    Example:
+        >>> import pyhf
+        >>> import tensorflow as tf
+        >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=tf.Session()))
+
     Args:
         backend: One of the supported pyhf backends: NumPy,
                  TensorFlow, PyTorch, and MXNet

     Returns:
         None
-
-    Example:
-        >>> import pyhf
-        >>> import tensorflow as tf
-        >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=tf.Session()))
     """
     global tensorlib
     global optimizer
@@ -54,5 +55,6 @@ def set_backend(backend, custom_optimizer = None):
     if custom_optimizer:
         optimizer = custom_optimizer

+
 from .pdf import Model
 __all__ = ["Model", "utils", "modifiers"]
diff --git a/pyhf/commandline.py b/pyhf/commandline.py
index 24a25b3af2..beeaff70fc 100644
--- a/pyhf/commandline.py
+++ b/pyhf/commandline.py
@@ -6,6 +6,7 @@
 import json
 import os
 import jsonpatch
+import sys

 from . import readxml
 from . import writexml
@@ -28,36 +29,39 @@ def xml2json(entrypoint_xml, basedir, output_file, track_progress):
     if output_file is None:
         print(json.dumps(spec, indent=4, sort_keys=True))
     else:
-        json.dump(spec, open(output_file, 'w+'), indent=4, sort_keys=True)
+        with open(output_file, 'w+') as out_file:
+            json.dump(spec, out_file, indent=4, sort_keys=True)
         log.debug("Written to {0:s}".format(output_file))
+    sys.exit(0)


 @pyhf.command()
-@click.argument('workspace', default = '-')
-@click.argument('xmlfile', default = '-')
-@click.option('--specroot', default = click.Path(exists = True))
-@click.option('--dataroot', default = click.Path(exists = True))
-def json2xml(workspace,xmlfile,specroot,dataroot):
-    specstream = click.open_file(workspace)
-    outstream = click.open_file(xmlfile,'w')
-    d = json.load(specstream)
-
-    outstream.write(writexml.writexml(d,specroot,dataroot,'').decode('utf-8'))
+@click.argument('workspace', default='-')
+@click.argument('xmlfile', default='-')
+@click.option('--specroot', default=click.Path(exists=True))
+@click.option('--dataroot', default=click.Path(exists=True))
+def json2xml(workspace, xmlfile, specroot, dataroot):
+    with click.open_file(workspace, 'r') as specstream:
+        d = json.load(specstream)
+    with click.open_file(xmlfile, 'w') as outstream:
+        outstream.write(writexml.writexml(d, specroot, dataroot,'').decode('utf-8'))
+    sys.exit(0)


 @pyhf.command()
-@click.argument('workspace', default = '-')
+@click.argument('workspace', default='-')
 @click.option('--output-file', help='The location of the output json file. If not specified, prints to screen.', default=None)
 @click.option('--measurement', default=None)
-@click.option('-p','--patch', multiple = True)
+@click.option('-p', '--patch', multiple=True)
 @click.option('--qualify-names/--no-qualify-names', default=False)
 def cls(workspace, output_file, measurement, qualify_names, patch):
-    specstream = click.open_file(workspace)
-    d = json.load(specstream)
+    with click.open_file(workspace, 'r') as specstream:
+        d = json.load(specstream)
     measurements = d['toplvl']['measurements']
     measurement_names = [m['name'] for m in measurements]
     measurement_index = 0
     log.debug('measurements defined:\n\t{0:s}'.format('\n\t'.join(measurement_names)))
     if measurement and measurement not in measurement_names:
         log.error('no measurement by name \'{0:s}\' exists, pick from one of the valid ones above'.format(measurement))
+        sys.exit(1)
     else:
         if not measurement and len(measurements) > 1:
             log.warning('multiple measurements defined. Taking the first measurement.')
@@ -68,7 +72,8 @@ def cls(workspace, output_file, measurement, qualify_names, patch):
     log.debug('calculating CLs for measurement {0:s}'.format(measurements[measurement_index]['name']))
     spec = {'channels':d['channels']}
     for p in patch:
-        p = jsonpatch.JsonPatch(json.loads(click.open_file(p).read()))
+        with click.open_file(p, 'r') as read_file:
+            p = jsonpatch.JsonPatch(json.loads(read_file.read()))
         spec = p.apply(spec)
     p = Model(spec, poiname=measurements[measurement_index]['config']['poi'], qualify_names=qualify_names)
     result = runOnePoint(1.0, sum((d['data'][c['name']] for c in d['channels']),[]) + p.config.auxdata, p)
@@ -76,5 +81,7 @@ def cls(workspace, output_file, measurement, qualify_names, patch):
     if output_file is None:
         print(json.dumps(result, indent=4, sort_keys=True))
     else:
-        json.dump(result, open(output_file, 'w+'), indent=4, sort_keys=True)
+        with open(output_file, 'w+') as out_file:
+            json.dump(result, out_file, indent=4, sort_keys=True)
         log.debug("Written to {0:s}".format(output_file))
+    sys.exit(0)
diff --git a/pyhf/exceptions/__init__.py b/pyhf/exceptions/__init__.py
index d0ecb88c71..e42dbddaea 100644
--- a/pyhf/exceptions/__init__.py
+++ b/pyhf/exceptions/__init__.py
@@ -1,5 +1,10 @@
 import sys

+class InvalidMeasurement(Exception):
+    """
+    InvalidMeasurement is raised when a specified measurement is invalid given the specification.
+    """
+
 class InvalidNameReuse(Exception):
     pass
diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index e2b4fcf9b7..95699ec600 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -340,35 +340,60 @@ def einsum(self, subscripts, *operands):
         raise NotImplementedError("mxnet::einsum is not implemented.")
         return self.astensor([])

     def poisson(self, n, lam):
-        """
-        The continous approximation to the probability density function of the Poisson
-        distribution given the parameters evaluated at `n`.
+        r"""
+        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
+        to the probability mass function of the Poisson distribution evaluated
+        at :code:`n` given the parameter :code:`lam`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.mxnet_backend())
+            >>> pyhf.tensorlib.poisson(5., 6.)
+
+            [0.16062315]
+            <NDArray 1 @cpu(0)>

         Args:
-            n (Number or Tensor): The value at which to evaluate the Poisson distribution p.d.f.
+            n (Number or Tensor): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                   (the observed number of events)
             lam (Number or Tensor): The mean of the Poisson distribution p.d.f.
                                     (the expected number of events)

         Returns:
-            MXNet NDArray: Value of N(n|lam, sqrt(lam)), the continous approximation to Poisson(n|lam).
+            MXNet NDArray: Value of the continuous approximation to Poisson(n|lam)
         """
-        return self.normal(n, lam, self.sqrt(lam))
+        n = self.astensor(n)
+        lam = self.astensor(lam)
+
+        # This is currently copied directly from PyTorch's source until a better
+        # way can be found to do this in MXNet
+        # https://github.com/pytorch/pytorch/blob/39520ffec15ab7e97691fed048de1832e83785e8/torch/distributions/poisson.py#L59-L63
+        return nd.exp((nd.log(lam) * n) - lam - nd.gammaln(n + 1.))

     def normal(self, x, mu, sigma):
-        """
-        The probability density function of the Normal distribution given the parameters
-        evaluated at `x`.
+        r"""
+        The probability density function of the Normal distribution evaluated
+        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
+        of :code:`sigma`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.mxnet_backend())
+            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
+
+            [0.35206532]
+            <NDArray 1 @cpu(0)>

         Args:
-            x (Number or Tensor): The point at which to evaluate the Normal distribution p.d.f.
-            mu (Number or Tensor): The mean of the Normal distribution p.d.f.
-            sigma(Number or Tensor): The standard deviation of the Normal distribution p.d.f.
+            x (Number or Tensor): The value at which to evaluate the Normal distribution p.d.f.
+            mu (Number or Tensor): The mean of the Normal distribution
+            sigma (Number or Tensor): The standard deviation of the Normal distribution

         Returns:
-            MXNet NDArray: Value of N(x|mu, sigma).
+            MXNet NDArray: Value of Normal(x|mu, sigma).
         """
         x = self.astensor(x)
         mu = self.astensor(mu)
@@ -376,7 +401,7 @@ def normal(self, x, mu, sigma):

         # This is currently copied directly from PyTorch's source until a better
         # way can be found to do this in MXNet
-        # https://github.com/pytorch/pytorch/blob/master/torch/distributions/normal.py#L61-L66
+        # https://github.com/pytorch/pytorch/blob/39520ffec15ab7e97691fed048de1832e83785e8/torch/distributions/normal.py#L70-L76
         def log_prob(value, loc, scale):
             variance = scale ** 2
             log_scale = math.log(scale) if isinstance(
diff --git a/pyhf/tensor/numpy_backend.py b/pyhf/tensor/numpy_backend.py
index 809c864f36..ce4ada9325 100644
--- a/pyhf/tensor/numpy_backend.py
+++ b/pyhf/tensor/numpy_backend.py
@@ -1,12 +1,15 @@
 import numpy as np
 import logging
-from scipy.special import gammaln, xlogy
+from scipy.special import gammaln
 from scipy.stats import norm

 log = logging.getLogger(__name__)

+
 class numpy_backend(object):
+    """NumPy backend for pyhf"""
+
     def __init__(self, **kwargs):
-        self.pois_from_norm = kwargs.get('poisson_from_normal',False)
+        pass

     def clip(self, tensor_in, min, max):
         """
@@ -157,12 +160,52 @@ def einsum(self, subscripts, *operands):
         return np.einsum(subscripts, *operands)

     def poisson(self, n, lam):
+        r"""
+        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
+        to the probability mass function of the Poisson distribution evaluated
+        at :code:`n` given the parameter :code:`lam`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.numpy_backend())
+            >>> pyhf.tensorlib.poisson(5., 6.)
+            0.16062314104797995
+
+        Args:
+            n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
+                                     (the observed number of events)
+            lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
+                                       (the expected number of events)
+
+        Returns:
+            NumPy float: Value of the continuous approximation to Poisson(n|lam)
+        """
         n = np.asarray(n)
-        if self.pois_from_norm:
-            return self.normal(n,lam, self.sqrt(lam))
-        return np.exp(xlogy(n, lam) - lam - gammaln(n + 1.))
+        lam = np.asarray(lam)
+        return np.exp(n * np.log(lam) - lam - gammaln(n + 1.))

     def normal(self, x, mu, sigma):
+        r"""
+        The probability density function of the Normal distribution evaluated
+        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
+        of :code:`sigma`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.numpy_backend())
+            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
+            0.3520653267642995
+
+        Args:
+            x (`tensor` or `float`): The value at which to evaluate the Normal distribution p.d.f.
+            mu (`tensor` or `float`): The mean of the Normal distribution
+            sigma (`tensor` or `float`): The standard deviation of the Normal distribution
+
+        Returns:
+            NumPy float: Value of Normal(x|mu, sigma)
+        """
         return norm.pdf(x, loc=mu, scale=sigma)

     def normal_cdf(self, x, mu=0, sigma=1):
diff --git a/pyhf/tensor/pytorch_backend.py b/pyhf/tensor/pytorch_backend.py
index 0a3df608d3..f4c6fd564d 100644
--- a/pyhf/tensor/pytorch_backend.py
+++ b/pyhf/tensor/pytorch_backend.py
@@ -3,7 +3,10 @@
 import logging
 log = logging.getLogger(__name__)

+
 class pytorch_backend(object):
+    """PyTorch backend for pyhf"""
+
     def __init__(self, **kwargs):
         pass

@@ -167,11 +170,53 @@ def einsum(self, subscripts, *operands):
         ops = tuple(self.astensor(op) for op in operands)
         return torch.einsum(subscripts, ops)

     def poisson(self, n, lam):
-        return self.normal(n,lam, self.sqrt(lam))
+        r"""
+        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
+        to the probability mass function of the Poisson distribution evaluated
+        at :code:`n` given the parameter :code:`lam`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.pytorch_backend())
+            >>> pyhf.tensorlib.poisson(5., 6.)
+            tensor([0.1606])
+
+        Args:
+            n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
+                                     (the observed number of events)
+            lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
+                                       (the expected number of events)
+
+        Returns:
+            PyTorch FloatTensor: Value of the continuous approximation to Poisson(n|lam)
+        """
+        n = self.astensor(n)
+        lam = self.astensor(lam)
+        return torch.exp(torch.distributions.Poisson(lam).log_prob(n))

     def normal(self, x, mu, sigma):
+        r"""
+        The probability density function of the Normal distribution evaluated
+        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
+        of :code:`sigma`.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend(pyhf.tensor.pytorch_backend())
+            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
+            tensor([0.3521])
+
+        Args:
+            x (`tensor` or `float`): The value at which to evaluate the Normal distribution p.d.f.
+            mu (`tensor` or `float`): The mean of the Normal distribution
+            sigma (`tensor` or `float`): The standard deviation of the Normal distribution
+
+        Returns:
+            PyTorch FloatTensor: Value of Normal(x|mu, sigma)
+        """
         x = self.astensor(x)
         mu = self.astensor(mu)
         sigma = self.astensor(sigma)
diff --git a/pyhf/tensor/tensorflow_backend.py b/pyhf/tensor/tensorflow_backend.py
index baeef1ecb1..19c24a23ec 100644
--- a/pyhf/tensor/tensorflow_backend.py
+++ b/pyhf/tensor/tensorflow_backend.py
@@ -1,10 +1,13 @@
 import logging
 import tensorflow as tf
+# import tensorflow_probability as tfp

 log = logging.getLogger(__name__)


 class tensorflow_backend(object):
+    """TensorFlow backend for pyhf"""
+
     def __init__(self, **kwargs):
         self.session = kwargs.get('session')

@@ -18,10 +21,10 @@ def clip(self, tensor_in, min, max):
         >>> import tensorflow as tf
         >>> sess = tf.Session()
         ...
-        >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend())
+        >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=sess))
         >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
         >>> with sess.as_default():
-        ...     pyhf.tensorlib.clip(a, -1, 1).eval()
+        ...     sess.run(pyhf.tensorlib.clip(a, -1, 1))
         ...
         array([-1., -1.,  0.,  1.,  1.], dtype=float32)
@@ -191,18 +194,72 @@ def einsum(self, subscripts, *operands):
             operands: list of array_like, these are the tensors for the operation

         Returns:
-            tensor: the calculation based on the Einstein summation convention
+            TensorFlow Tensor: the calculation based on the Einstein summation convention
         """
         return tf.einsum(subscripts, *operands)

     def poisson(self, n, lam):
-        # could be changed to actual Poisson easily
-        return self.normal(n, lam, self.sqrt(lam))
+        r"""
+        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
+        to the probability mass function of the Poisson distribution evaluated
+        at :code:`n` given the parameter :code:`lam`.
+
+        Example:
+
+            >>> import pyhf
+            >>> import tensorflow as tf
+            >>> sess = tf.Session()
+            >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=sess))
+            ...
+            >>> with sess.as_default():
+            ...     sess.run(pyhf.tensorlib.poisson(5., 6.))
+            ...
+            array([0.16062315], dtype=float32)
+
+        Args:
+            n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
+                                     (the observed number of events)
+            lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
+                                       (the expected number of events)
+
+        Returns:
+            TensorFlow Tensor: Value of the continuous approximation to Poisson(n|lam)
+        """
+        n = self.astensor(n)
+        lam = self.astensor(lam)
+        # return tf.exp(tfp.distributions.Poisson(lam).log_prob(n))
+        return tf.exp(tf.contrib.distributions.Poisson(lam).log_prob(n))

     def normal(self, x, mu, sigma):
+        r"""
+        The probability density function of the Normal distribution evaluated
+        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
+        of :code:`sigma`.
+
+        Example:
+
+            >>> import pyhf
+            >>> import tensorflow as tf
+            >>> sess = tf.Session()
+            >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=sess))
+            ...
+            >>> with sess.as_default():
+            ...     sess.run(pyhf.tensorlib.normal(0.5, 0., 1.))
+            ...
+            array([0.35206532], dtype=float32)
+
+        Args:
+            x (`tensor` or `float`): The value at which to evaluate the Normal distribution p.d.f.
+            mu (`tensor` or `float`): The mean of the Normal distribution
+            sigma (`tensor` or `float`): The standard deviation of the Normal distribution
+
+        Returns:
+            TensorFlow Tensor: Value of Normal(x|mu, sigma)
+        """
         x = self.astensor(x)
         mu = self.astensor(mu)
         sigma = self.astensor(sigma)
+        # normal = tfp.distributions.Normal(mu, sigma)
         normal = tf.distributions.Normal(mu, sigma)
         return normal.prob(x)
@@ -218,7 +275,7 @@ def normal_cdf(self, x, mu=0, sigma=1):
         ...
         >>> pyhf.set_backend(pyhf.tensor.tensorflow_backend())
         >>> with sess.as_default():
-        ...     pyhf.tensorlib.normal_cdf(0.8).eval()
+        ...     sess.run(pyhf.tensorlib.normal_cdf(0.8))
         ...
         array([0.7881446], dtype=float32)

@@ -233,5 +290,6 @@ def normal_cdf(self, x, mu=0, sigma=1):
         x = self.astensor(x)
         mu = self.astensor(mu)
         sigma = self.astensor(sigma)
+        # normal = tfp.distributions.Normal(mu, sigma)
         normal = tf.distributions.Normal(mu, sigma)
         return normal.cdf(x)
diff --git a/pyhf/utils.py b/pyhf/utils.py
index 78e5e1e23c..150e3611f4 100644
--- a/pyhf/utils.py
+++ b/pyhf/utils.py
@@ -1,9 +1,11 @@
-import json, jsonschema
+import json
+import jsonschema
 import pkg_resources

 from .exceptions import InvalidSpecification
 from . import get_backend

+
 def get_default_schema():
     r"""
     Returns the absolute filepath default schema for pyhf. This usually points
@@ -13,10 +15,12 @@ def get_default_schema():
         Schema File Path: a string containing the absolute path to the default
         schema file.
     """
-    return pkg_resources.resource_filename(__name__,'data/spec.json')
+    return pkg_resources.resource_filename(__name__, 'data/spec.json')


 SCHEMA_CACHE = {}
+
+
 def load_schema(schema):
     global SCHEMA_CACHE
     try:
@@ -24,7 +28,8 @@ def load_schema(schema):
     except KeyError:
         pass

-    SCHEMA_CACHE[schema] = json.load(open(schema))
+    with open(schema) as json_schema:
+        SCHEMA_CACHE[schema] = json.load(json_schema)
     return SCHEMA_CACHE[schema]

@@ -35,9 +40,11 @@ def validate(spec, schema):
     except jsonschema.ValidationError as err:
         raise InvalidSpecification(err)

+
 def loglambdav(pars, data, pdf):
     return -2 * pdf.logpdf(pars, data)

+
 def qmu(mu, data, pdf, init_pars, par_bounds):
     r"""
     The test statistic, :math:`q_{\mu}`, for establishing an upper
@@ -76,12 +83,14 @@ def qmu(mu, data, pdf, init_pars, par_bounds):
     qmu = tensorlib.where(muhatbhat[pdf.config.poi_index] > mu, [0], qmu)
     return qmu

+
 def generate_asimov_data(asimov_mu, data, pdf, init_pars, par_bounds):
     _, optimizer = get_backend()
     bestfit_nuisance_asimov = optimizer.constrained_bestfit(
         loglambdav, asimov_mu, data, pdf, init_pars, par_bounds)
     return pdf.expected_data(bestfit_nuisance_asimov)

+
 def pvals_from_teststat(sqrtqmu_v, sqrtqmuA_v):
     r"""
     The :math:`p`-values for signal strength :math:`\mu` and Asimov strength :math:`\mu'`
@@ -114,7 +123,8 @@ def pvals_from_teststat(sqrtqmu_v, sqrtqmuA_v):
     CLs = CLsb / CLb
     return CLsb, CLb, CLs

-def runOnePoint(muTest, data, pdf, init_pars = None, par_bounds = None):
+
+def runOnePoint(muTest, data, pdf, init_pars=None, par_bounds=None):
     r"""
     Computes test statistics (and expected statistics) for a single value
     of the parameter of interest
diff --git a/setup.py b/setup.py
index 77098527c5..1e9f7d65ae 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@
         ],
         'tensorflow':[
             'tensorflow>=1.10.0',
+            # 'tensorflow-probability>=0.3.0', # Causing trouble with Travis CI, but *should* be used
             'numpy<=1.14.5,>=1.14.0', # Lower of 1.14.0 instead of 1.13.3 to ensure doctest pass
             'setuptools<=39.1.0',
         ],
diff --git a/tests/benchmarks/test_benchmark.py b/tests/benchmarks/test_benchmark.py
index ec31823f9f..8178b4d518 100644
--- a/tests/benchmarks/test_benchmark.py
+++ b/tests/benchmarks/test_benchmark.py
@@ -1,6 +1,5 @@
 import pyhf
 from pyhf.simplemodels import hepdata_like
-import tensorflow as tf
 import numpy as np
 import pytest

@@ -64,11 +63,6 @@ def generate_source_poisson(n_bins):


 def runOnePoint(pdf, data):
-    if isinstance(pyhf.tensorlib, pyhf.tensor.tensorflow_backend):
-        # Reset the TensorFlow graph and session for each run
-        tf.reset_default_graph()
-        pyhf.tensorlib.session = tf.Session()
-
     return pyhf.utils.runOnePoint(1.0, data, pdf,
                                   pdf.config.suggested_init(),
                                   pdf.config.suggested_bounds())

@@ -80,19 +74,7 @@ def runOnePoint(pdf, data):


 @pytest.mark.parametrize('n_bins', bins, ids=bin_ids)
-@pytest.mark.parametrize('backend',
-                         [
-                             pyhf.tensor.numpy_backend(poisson_from_normal=True),
-                             pyhf.tensor.tensorflow_backend(session=tf.Session()),
-                             pyhf.tensor.pytorch_backend(),
-                             # pyhf.tensor.mxnet_backend(),
-                         ],
-                         ids=[
-                             'numpy',
-                             'tensorflow',
-                             'pytorch',
-                             # 'mxnet',
-                         ])
+@pytest.mark.skip_mxnet
 def test_runOnePoint(benchmark, backend, n_bins):
     """
     Benchmark the performance of pyhf.runOnePoint()
@@ -106,8 +88,6 @@ def test_runOnePoint(benchmark, backend, n_bins):
     Returns:
         None
     """
-    pyhf.set_backend(backend)
-
     source = generate_source_static(n_bins)
     pdf = hepdata_like(source['bindata']['sig'],
                        source['bindata']['bkg'],
diff --git a/tests/test_backend_consistency.py b/tests/test_backend_consistency.py
index 3107546056..a9005bd998 100644
--- a/tests/test_backend_consistency.py
+++ b/tests/test_backend_consistency.py
@@ -94,7 +94,7 @@ def test_runOnePoint_q_mu(n_bins,
     data = source['bindata']['data'] + pdf.config.auxdata

     backends = [
-        pyhf.tensor.numpy_backend(poisson_from_normal=True),
+        pyhf.tensor.numpy_backend(),
         pyhf.tensor.tensorflow_backend(session=tf.Session()),
         pyhf.tensor.pytorch_backend(),
         # mxnet_backend()
diff --git a/tests/test_import.py b/tests/test_import.py
index 9a820cc4e7..71095e7342 100644
--- a/tests/test_import.py
+++ b/tests/test_import.py
@@ -4,9 +4,10 @@
 import pytest
 import numpy as np

+
 def test_import_prepHistFactory():
     parsed_xml = pyhf.readxml.parse('validation/xmlimport_input/config/example.xml',
-                             'validation/xmlimport_input/')
+                                    'validation/xmlimport_input/')

     # build the spec, strictly checks properties included
     spec = {'channels': parsed_xml['channels']}
@@ -16,14 +17,14 @@ def test_import_prepHistFactory():
                  in parsed_xml['data'][k['name']]] + pdf.config.auxdata

     channels = {channel['name'] for channel in pdf.spec['channels']}
-    samples = {channel['name']: [sample['name'] for sample in channel['samples']] for channel in pdf.spec['channels']}
-
+    samples = {channel['name']: [sample['name']
+                                 for sample in channel['samples']] for channel in pdf.spec['channels']}

     ###
-    ### signal overallsys
-    ### bkg1 overallsys (stat ignored)
-    ### bkg2 stateror (2 bins)
-    ### bkg2 overallsys
+    # signal overallsys
+    # bkg1 overallsys (stat ignored)
+    # bkg2 stateror (2 bins)
+    # bkg2 overallsys

     assert 'channel1' in channels
     assert 'signal' in samples['channel1']
@@ -31,15 +32,16 @@ def test_import_prepHistFactory():
     assert 'background2' in samples['channel1']

     assert pdf.spec['channels'][0]['samples'][2]['modifiers'][0]['type'] == 'staterror'
-    assert pdf.spec['channels'][0]['samples'][2]['modifiers'][0]['data'] == [0,10.]
+    assert pdf.spec['channels'][0]['samples'][2]['modifiers'][0]['data'] == [0, 10.]

     assert pdf.spec['channels'][0]['samples'][1]['modifiers'][0]['type'] == 'staterror'
-    assert all(np.isclose(pdf.spec['channels'][0]['samples'][1]['modifiers'][0]['data'],[5.0, 0.0]))
+    assert all(np.isclose(
+        pdf.spec['channels'][0]['samples'][1]['modifiers'][0]['data'], [5.0, 0.0]))

     assert pdf.expected_actualdata(
         pdf.config.suggested_init()).tolist() == [120.0, 110.0]

-    assert pdf.config.auxdata_order == ['syst1', 'staterror_channel1', 'syst2', 'syst3']
+    assert pdf.config.auxdata_order == ['syst1', 'staterror_channel1', 'syst2', 'syst3']

     assert data == [122.0, 112.0, 0.0, 1.0, 1.0, 0.0, 0.0]
@@ -48,9 +50,10 @@ def test_import_prepHistFactory():
     assert pdf.expected_data(
         pars, include_auxdata=False).tolist() == [140, 120]

+
 def test_import_histosys():
     parsed_xml = pyhf.readxml.parse('validation/xmlimport_input2/config/example.xml',
-                             'validation/xmlimport_input2')
+                                    'validation/xmlimport_input2')

     # build the spec, strictly checks properties included
     spec = {'channels': parsed_xml['channels']}
@@ -59,7 +62,8 @@ def test_import_histosys():
     data = [binvalue for k in pdf.spec['channels'] for binvalue
             in parsed_xml['data'][k['name']]] + pdf.config.auxdata

-    channels = {channel['name']:channel for channel in pdf.spec['channels']}
-    samples = {channel['name']: [sample['name'] for sample in channel['samples']] for channel in pdf.spec['channels']}
+    channels = {channel['name']: channel for channel in pdf.spec['channels']}
+    samples = {channel['name']: [sample['name']
+                                 for sample in channel['samples']] for channel in pdf.spec['channels']}

     assert channels['channel2']['samples'][0]['modifiers'][0]['type'] == 'histosys'
diff --git a/tests/test_pdf.py b/tests/test_pdf.py
index e38b35591d..37f3caeb16 100644
--- a/tests/test_pdf.py
+++ b/tests/test_pdf.py
@@ -16,7 +16,7 @@ def test_numpy_pdf_inputs(backend):
             "sig": [10.0]
         }
     }
-    pdf = pyhf.simplemodels.hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])
+    pdf = pyhf.simplemodels.hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])

     pars = pdf.config.suggested_init()
     data = source['bindata']['data'] + pdf.config.auxdata
@@ -243,6 +243,7 @@ def test_pdf_integration_shapesys(backend):
     pars[pdf.config.par_slice('mu')], pars[pdf.config.par_slice('bkg_norm')] = [[0.0], [0.9,1.1]]
     assert pdf.expected_data(pars, include_auxdata = False).tolist() == [100*0.9,150*1.1]

+
 def test_invalid_modifier():
     spec = {
         'channels': [
@@ -263,6 +264,7 @@ def test_invalid_modifier():
     with pytest.raises(pyhf.exceptions.InvalidModifier):
         pyhf.pdf._ModelConfig.from_spec(spec)

+
 def test_invalid_modifier_name_resuse():
     spec = {
         'channels': [
@@ -288,7 +290,6 @@ def test_invalid_modifier_name_resuse():
         ]
     }
     with pytest.raises(pyhf.exceptions.InvalidNameReuse):
-        pdf = pyhf.Model(spec, poiname = 'reused_name')
-
-    pdf = pyhf.Model(spec, poiname = 'reused_name', qualify_names = True)
+        pdf = pyhf.Model(spec, poiname = 'reused_name')

+    pdf = pyhf.Model(spec, poiname = 'reused_name', qualify_names = True)
diff --git a/tests/test_scripts.py b/tests/test_scripts.py
index 8b97ebd376..105bca3ed5 100644
--- a/tests/test_scripts.py
+++ b/tests/test_scripts.py
@@ -60,6 +60,14 @@ def test_import_prepHistFactory_and_cls(tmpdir, script_runner):
         assert 'CLs_obs' in d
         assert 'CLs_exp' in d

+        tmp_out = tmpdir.join('{0:s}_output.json'.format(measurement))
+        # make sure output file works too
+        command += ' --output-file {0:s}'.format(tmp_out.strpath)
+        ret = script_runner.run(*shlex.split(command))
+        assert ret.success
+        d = json.load(tmp_out)
+        assert 'CLs_obs' in d
+        assert 'CLs_exp' in d

 def test_import_and_export(tmpdir, script_runner):
     temp = tmpdir.join("parsed_output.json")
@@ -88,7 +96,7 @@ def test_patch(tmpdir, script_runner):

     import io
     command = 'pyhf cls {0:s} --patch -'.format(temp.strpath,patch.strpath)
-    pipefile = io.StringIO(patchcontent) #python 2.7 pytest-files are not file-like enough
+    pipefile = io.StringIO(patchcontent)  # python 2.7 pytest-files are not file-like enough
     ret = script_runner.run(*shlex.split(command), stdin = pipefile)
     print(ret.stderr)
     assert ret.success
@@ -107,3 +115,12 @@ def test_patch_fail(tmpdir, script_runner):
     ret = script_runner.run(*shlex.split(command))
     assert not ret.success

+
+def test_bad_measurement_name(tmpdir, script_runner):
+    temp = tmpdir.join("parsed_output.json")
+    command = 'pyhf xml2json validation/xmlimport_input/config/example.xml --basedir validation/xmlimport_input/ --output-file {0:s}'.format(temp.strpath)
+    ret = script_runner.run(*shlex.split(command))
+
+    command = 'pyhf cls {0:s} --measurement "a-fake-measurement-name"'.format(temp.strpath)
+    ret = script_runner.run(*shlex.split(command))
+    assert not ret.success
+    #assert 'no measurement by name' in ret.stderr # numpy swallows the log.error() here, dunno why
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 7842c9ac1d..cf5b2d261c 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -1,8 +1,8 @@
 import pyhf
-from pyhf.tensor.pytorch_backend import pytorch_backend
 from pyhf.tensor.numpy_backend import numpy_backend
 from pyhf.tensor.tensorflow_backend import tensorflow_backend
+from pyhf.tensor.pytorch_backend import pytorch_backend
 from pyhf.tensor.mxnet_backend import mxnet_backend
 from pyhf.simplemodels import hepdata_like

@@ -46,12 +46,14 @@ def test_common_tensor_backends(backend):
            == [[1, 1, 1], [2, 3, 4], [5, 6, 7]]
     assert list(map(tb.tolist, tb.simple_broadcast([1], [2, 3, 4], [5, 6, 7]))) \
            == [[1, 1, 1], [2, 3, 4], [5, 6, 7]]
-    assert tb.tolist(tb.ones((4,5))) == [[1.]*5]*4
-    assert tb.tolist(tb.zeros((4,5))) == [[0.]*5]*4
-    assert tb.tolist(tb.abs([-1,-2])) == [1,2]
+    assert tb.tolist(tb.ones((4, 5))) == [[1.] * 5] * 4
+    assert tb.tolist(tb.zeros((4, 5))) == [[0.] * 5] * 4
+    assert tb.tolist(tb.abs([-1, -2])) == [1, 2]
     with pytest.raises(Exception):
         tb.simple_broadcast([1], [2, 3], [5, 6, 7])

+    # poisson(lambda=0) is not defined, should return NaN
+    assert tb.tolist(pyhf.tensorlib.poisson([0, 0, 1, 1], [0, 1, 0, 1])) == pytest.approx([np.nan, 0.3678794503211975, 0.0, 0.3678794503211975], nan_ok=True)

 def test_einsum(backend):
     tb = pyhf.tensorlib
@@ -64,13 +66,15 @@ def test_einsum(backend):
     assert np.all(tb.tolist(tb.einsum('ij->ji',x)) == np.asarray(x).T.tolist())
     assert tb.tolist(tb.einsum('i,j->ij',tb.astensor([1,1,1]),tb.astensor([1,2,3]))) == [[1,2,3]]*3

+
 def test_pdf_eval():
     tf_sess = tf.Session()
-    backends = [numpy_backend(poisson_from_normal=True),
-                pytorch_backend(),
-                tensorflow_backend(session=tf_sess),
-                mxnet_backend() #no einsum in mxnet
-                ]
+    backends = [
+        numpy_backend(),
+        tensorflow_backend(session=tf_sess),
+        pytorch_backend(),
+        mxnet_backend()
+    ]

     values = []
     for b in backends:
@@ -114,15 +118,17 @@ def test_pdf_eval():
         v1 = pdf.logpdf(pdf.config.suggested_init(), data)
         values.append(pyhf.tensorlib.tolist(v1)[0])

-    assert np.std(values) < 1e-6
+    assert np.std(values) < 5e-5


 def test_pdf_eval_2():
     tf_sess = tf.Session()
-    backends = [numpy_backend(poisson_from_normal=True),
-                pytorch_backend(),
-                tensorflow_backend(session=tf_sess),
-                mxnet_backend()]
+    backends = [
+        numpy_backend(),
+        tensorflow_backend(session=tf_sess),
+        pytorch_backend(),
+        mxnet_backend()
+    ]

     values = []
     for b in backends:
@@ -145,4 +151,4 @@ def test_pdf_eval_2():
         v1 = pdf.logpdf(pdf.config.suggested_init(), data)
         values.append(pyhf.tensorlib.tolist(v1)[0])

-    assert np.std(values) < 1e-6
+    assert np.std(values) < 5e-5
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 566ee79fa6..e4798bd39d 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -2,9 +2,11 @@
 import json
 import pytest

+
 @pytest.fixture(scope='module')
 def source_1bin_example1():
-    return json.load(open('validation/data/1bin_example1.json'))
+    with open('validation/data/1bin_example1.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -160,7 +162,8 @@ def setup_1bin_normsys(source=source_1bin_normsys(),
 @pytest.fixture(scope='module')
 def source_2bin_histosys_example2():
-    return json.load(open('validation/data/2bin_histosys_example2.json'))
+    with open('validation/data/2bin_histosys_example2.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -238,7 +241,8 @@ def setup_2bin_histosys(source=source_2bin_histosys_example2(),
 @pytest.fixture(scope='module')
 def source_2bin_2channel_example1():
-    return json.load(open('validation/data/2bin_2channel_example1.json'))
+    with open('validation/data/2bin_2channel_example1.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -330,7 +334,8 @@ def setup_2bin_2channel(source=source_2bin_2channel_example1(),
 @pytest.fixture(scope='module')
 def source_2bin_2channel_couplednorm():
-    return json.load(open('validation/data/2bin_2channel_couplednorm.json'))
+    with open('validation/data/2bin_2channel_couplednorm.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -434,7 +439,8 @@ def setup_2bin_2channel_couplednorm(
 @pytest.fixture(scope='module')
 def source_2bin_2channel_coupledhisto():
-    return json.load(open('validation/data/2bin_2channel_coupledhisto.json'))
+    with open('validation/data/2bin_2channel_coupledhisto.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -547,7 +553,8 @@ def setup_2bin_2channel_coupledhistosys(
 @pytest.fixture(scope='module')
 def source_2bin_2channel_coupledshapefactor():
-    return json.load(open('validation/data/2bin_2channel_coupledshapefactor.json'))
+    with open('validation/data/2bin_2channel_coupledshapefactor.json') as read_json:
+        return json.load(read_json)


 @pytest.fixture(scope='module')
@@ -690,4 +697,3 @@ def test_validation(setup):
         setup['expected']['config']['par_bounds']

     validate_runOnePoint(pdf, data, setup['mu'], setup['expected']['result'])
-
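
A minimal sketch of the behaviour this patch standardizes: every backend now uses the continuous approximation Poisson(n|lam) = exp(n*log(lam) - lam - lgamma(n+1)) rather than the old optional Normal approximation, and the TensorFlow backend is constructed with a session. The expected values below are the ones quoted in the new docstrings; this assumes a checkout of this revision with the tensorflow extra installed:

import pyhf
import tensorflow as tf

# NumPy backend evaluates eagerly to plain floats
pyhf.set_backend(pyhf.tensor.numpy_backend())
print(pyhf.tensorlib.poisson(5., 6.))      # 0.16062314104797995
print(pyhf.tensorlib.normal(0.5, 0., 1.))  # 0.3520653267642995

# TensorFlow backend returns graph nodes; evaluate them in the session
# that was handed to the backend
sess = tf.Session()
pyhf.set_backend(pyhf.tensor.tensorflow_backend(session=sess))
poisson_tensor = pyhf.tensorlib.poisson(5., 6.)
print(sess.run(poisson_tensor))            # array([0.16062315], dtype=float32)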
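
The lam = 0 edge case asserted in tests/test_tensor.py follows directly from that formula: n*log(lam) is -inf for n > 0 (so the density is 0) and 0*(-inf) is NaN for n = 0. A standalone sketch with NumPy/SciPy reproducing those values (the helper name poisson_continuous is illustrative, not part of pyhf; np.errstate silences the expected log(0) warnings):

import numpy as np
from scipy.special import gammaln

def poisson_continuous(n, lam):
    # exp(n * log(lam) - lam - lgamma(n + 1)), as in numpy_backend.poisson
    n, lam = np.asarray(n, dtype=float), np.asarray(lam, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.exp(n * np.log(lam) - lam - gammaln(n + 1.))

print(poisson_continuous(5., 6.))                      # 0.16062314104797995
print(poisson_continuous([0, 0, 1, 1], [0, 1, 0, 1]))  # [nan 0.36787944 0. 0.36787944]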