Skip to content

Commit

Permalink
no print statements
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored and twiecki committed Jul 26, 2021
1 parent 04cdd96 commit cdc6a39
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ repos:
- id: pylint
args: [--rcfile=.pylintrc]
files: ^pymc3/
- repo: https://github.com/MarcoGorelli/madforhooks
rev: 0.2.1
hooks:
- id: no-print-statements
files: ^pymc3/
- repo: local
hooks:
- id: check-no-tests-are-ignored
Expand Down
2 changes: 1 addition & 1 deletion pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def perform(self, node, inputs, outputs, params=None):
log_det = np.sum(np.log(np.abs(s)))
z[0] = np.asarray(log_det, dtype=x.dtype)
except Exception:
print(f"Failed to compute logdet of {x}.")
print(f"Failed to compute logdet of {x}.", file=sys.stdout)
raise

def grad(self, inputs, g_outputs):
Expand Down
9 changes: 5 additions & 4 deletions pymc3/sampling_jax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: skip-file
import os
import re
import sys
import warnings

xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
Expand Down Expand Up @@ -206,7 +207,7 @@ def sample_numpyro_nuts(
rv_samples.name = rv.name
sample_outputs.append(rv_samples)

print("Compiling...")
print("Compiling...", file=sys.stdout)

tic1 = pd.Timestamp.now()
_sample = compile_rv_inplace(
Expand All @@ -219,14 +220,14 @@ def sample_numpyro_nuts(
)
tic2 = pd.Timestamp.now()

print("Compilation time = ", tic2 - tic1)
print("Compilation time = ", tic2 - tic1, file=sys.stdout)

print("Sampling...")
print("Sampling...", file=sys.stdout)

*mcmc_samples, leapfrogs_taken = _sample()
tic3 = pd.Timestamp.now()

print("Sampling time = ", tic3 - tic2)
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}

Expand Down
3 changes: 2 additions & 1 deletion pymc3/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
@author: johnsalvatier
"""
import copy
import sys

import aesara.gradient as tg
import numpy as np
Expand Down Expand Up @@ -153,7 +154,7 @@ def dlogp_func(x):
assert isinstance(cost_func.progress, ProgressBar)
cost_func.progress.total = last_v
cost_func.progress.update(last_v)
print()
print(file=sys.stdout)

mx0 = RaveledVars(mx0, x0.point_map_info)

Expand Down

0 comments on commit cdc6a39

Please sign in to comment.