diff --git a/econml/inference/_inference.py b/econml/inference/_inference.py index 5c59103de..1fff3e906 100644 --- a/econml/inference/_inference.py +++ b/econml/inference/_inference.py @@ -1179,8 +1179,8 @@ class PopulationSummaryResults: """ - def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha, value, decimals, tol, - output_names=None, treatment_names=None): + def __init__(self, pred, pred_stderr, mean_pred_stderr, d_t, d_y, alpha=0.1, + value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None): self.pred = pred self.pred_stderr = pred_stderr self.mean_pred_stderr = mean_pred_stderr @@ -1237,13 +1237,13 @@ def stderr_mean(self): raise AttributeError("Only point estimates are available!") return np.sqrt(np.mean(self.pred_stderr**2, axis=0)) - def zstat(self, *, value=0): + def zstat(self, *, value=None): """ Get the z statistic of the mean point estimate of each treatment on each outcome for sample X. Parameters ---------- - value: optinal float (default=0) + value: optional float (default=0) The mean value of the metric you'd like to test under null hypothesis. Returns @@ -1258,13 +1258,13 @@ def zstat(self, *, value=0): zstat = (self.mean_point - value) / self.stderr_mean return zstat - def pvalue(self, *, value=0): + def pvalue(self, *, value=None): """ Get the p value of the z test of each treatment on each outcome for sample X. Parameters ---------- - value: optinal float (default=0) + value: optional float (default=0) The mean value of the metric you'd like to test under null hypothesis. Returns @@ -1275,10 +1275,11 @@ def pvalue(self, *, value=0): the corresponding singleton dimensions in the output will be collapsed (e.g. if both are vectors, then the output of this method will be a scalar) """ + value = self.value if value is None else value pvalue = norm.sf(np.abs(self.zstat(value=value)), loc=0, scale=1) * 2 return pvalue - def conf_int_mean(self, *, alpha=.1): + def conf_int_mean(self, *, alpha=None): """ Get the confidence interval of the mean point estimate of each treatment on each outcome for sample X. @@ -1323,7 +1324,7 @@ def std_point(self): """ return np.std(self.pred, axis=0) - def percentile_point(self, *, alpha=.1): + def percentile_point(self, *, alpha=None): """ Get the confidence interval of the point estimate of each treatment on each outcome for sample X. @@ -1346,7 +1347,7 @@ def percentile_point(self, *, alpha=.1): upper_percentile_point = np.percentile(self.pred, (1 - alpha / 2) * 100, axis=0) return lower_percentile_point, upper_percentile_point - def conf_int_point(self, *, alpha=.1, tol=.001): + def conf_int_point(self, *, alpha=None, tol=None): """ Get the confidence interval of the point estimate of each treatment on each outcome for sample X. @@ -1389,7 +1390,7 @@ def stderr_point(self): """ return np.sqrt(self.stderr_mean**2 + self.std_point**2) - def summary(self, alpha=0.1, value=0, decimals=3, tol=0.001, output_names=None, treatment_names=None): + def summary(self, alpha=None, value=None, decimals=None, tol=None, output_names=None, treatment_names=None): """ Output the summary inferences above. diff --git a/econml/tests/test_inference.py b/econml/tests/test_inference.py index 6aebd7412..d977ec1c1 100644 --- a/econml/tests/test_inference.py +++ b/econml/tests/test_inference.py @@ -288,6 +288,18 @@ def test_can_summarize(self): inference=BootstrapInference(5) ).summary(1) + def test_alpha(self): + Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W + est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression()) + est.fit(Y, T, X=X, W=W) + + # ensure alpha is passed + lb, ub = est.const_marginal_ate_interval(X, alpha=1) + assert (lb == ub).all() + + lb, ub = est.const_marginal_ate_interval(X) + assert (lb != ub).all() + def test_inference_with_none_stderr(self): Y, T, X, W = TestInference.Y, TestInference.T, TestInference.X, TestInference.W est = DML(model_y=LinearRegression(),