Skip to content

Commit

Permalink
make linear response more robust and fix bug with predictions (#5080)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Oct 15, 2021
1 parent ce447cc commit e03f5bf
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 30 deletions.
19 changes: 16 additions & 3 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from pandas import DataFrame, Series

from pymc.distributions.distribution import NoDistribution

Expand Down Expand Up @@ -93,8 +94,8 @@ class BART(NoDistribution):
Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between 1
and 3.
response : str
How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
``mix`` (default).
How the leaf_node values are computed. Available options are ``constant`` (default),
``linear`` or ``mix``.
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Expand All @@ -109,12 +110,13 @@ def __new__(
m=50,
alpha=0.25,
k=2,
response="mix",
response="constant",
split_prior=None,
**kwargs,
):

cls.all_trees = []
X, Y = preprocess_XY(X, Y)

bart_op = type(
f"BART_{name}",
Expand Down Expand Up @@ -143,3 +145,14 @@ def __new__(
@classmethod
def dist(cls, *params, **kwargs):
return super().dist(params, **kwargs)


def preprocess_XY(X, Y):
    """Return ``X`` and ``Y`` as float NumPy arrays.

    Parameters
    ----------
    X : pandas.DataFrame, pandas.Series, numpy.ndarray, list or tuple
        Data passed as ``X``. pandas containers are unwrapped with
        ``to_numpy()``; lists and tuples are converted with ``np.asarray``.
    Y : pandas.DataFrame, pandas.Series, numpy.ndarray, list or tuple
        Data passed as ``Y``; handled the same way as ``X``.

    Returns
    -------
    X, Y : tuple
        Both inputs after ``astype(float)``, so the caller always receives
        new float arrays and the originals are left untouched.
    """
    if isinstance(Y, (Series, DataFrame)):
        Y = Y.to_numpy()
    elif isinstance(Y, (list, tuple)):
        # Plain sequences have no ``astype``; wrap them so the cast below works.
        Y = np.asarray(Y)

    if isinstance(X, (Series, DataFrame)):
        X = X.to_numpy()
    elif isinstance(X, (list, tuple)):
        X = np.asarray(X)

    # Any other object is expected to provide ``astype`` itself.
    Y = Y.astype(float)
    X = X.astype(float)
    return X, Y
39 changes: 18 additions & 21 deletions pymc/bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np

from aesara import function as aesara_function
from pandas import DataFrame, Series

from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc.bart.bart import BARTRV
Expand Down Expand Up @@ -127,11 +126,13 @@ class PGBART(ArrayStepShared):
def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
_log.warning("BART is experimental. Use with caution.")
model = modelcontext(model)
initial_values = model.initial_point
initial_values = model.recompute_initial_point()
value_bart = inputvars(vars)[0]
self.bart = model.values_to_rvs[value_bart].owner.op

self.X, self.Y, self.missing_data = preprocess_XY(self.bart.X, self.bart.Y)
self.X = self.bart.X
self.Y = self.bart.Y
self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.alpha = self.bart.alpha
self.k = self.bart.k
Expand Down Expand Up @@ -342,16 +343,6 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
particle.old_likelihood_logp = new_likelihood


# Unwrap pandas inputs to NumPy arrays and report whether X contains NaNs.
# NOTE(review): this three-value variant sits under the pgbart.py diff header and
# looks like the side of the hunk that the commit removes - confirm against the
# raw diff before relying on it.
def preprocess_XY(X, Y):
# Series/DataFrame inputs are converted with to_numpy(); other objects pass through as-is.
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
X = X.to_numpy()
# True when at least one entry of X is NaN.
missing_data = np.any(np.isnan(X))
# Only Y is cast to float here; X is returned with whatever dtype it arrived with.
Y = Y.astype(float)
return X, Y, missing_data


class SampleSplittingVariable:
def __init__(self, alpha_prior):
"""
Expand Down Expand Up @@ -493,16 +484,19 @@ def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, respon
linear_params = None
if Y_mu_pred.size == 0:
return 0, linear_params
elif Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
else:
if response == "constant":
norm = normal.random() * mu_std
if Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
elif response == "constant":
mu_mean = mean(Y_mu_pred) / m
elif response == "linear":
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
mu_mean = Y_fit / m
draw = normal.random() * mu_std + mu_mean
return draw, linear_params
linear_params[2] = norm

draw = norm + mu_mean
return draw, linear_params


def fast_mean():
Expand Down Expand Up @@ -532,11 +526,14 @@ def linear_fit(X, Y):
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
a = ybar - b * xbar
if np.all(X == xbar):
b = 0
else:
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)

a = ybar - b * xbar
Y_fit = a + b * X
return Y_fit, (a, b)
return Y_fit, [a, b, 0]

try:
from numba import jit
Expand Down
11 changes: 6 additions & 5 deletions pymc/bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,13 @@ def predict_out_of_sample(self, X, m):
Value of the leaf value where the unobserved point lies.
"""
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
if leaf_node.linear_params is None:
linear_params = leaf_node.linear_params
if linear_params is None:
return leaf_node.value
else:
x = X[split_variable].item()
y_x = leaf_node.linear_params[0] + leaf_node.linear_params[1] * x
return y_x / m
y_x = (linear_params[0] + linear_params[1] * x) / m
return y_x + linear_params[2]

def _traverse_tree(self, x, node_index=0, split_variable=None):
"""
Expand All @@ -136,10 +137,10 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
split_variable = current_node.idx_split_variable
if x[split_variable] <= current_node.split_value:
left_child = current_node.get_idx_left_child()
current_node, _ = self._traverse_tree(x, left_child, split_variable)
current_node, split_variable = self._traverse_tree(x, left_child, split_variable)
else:
right_child = current_node.get_idx_right_child()
current_node, _ = self._traverse_tree(x, right_child, split_variable)
current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
return current_node, split_variable

def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
Expand Down
3 changes: 2 additions & 1 deletion pymc/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_bart_vi():
)
var_imp /= var_imp.sum()
assert var_imp[0] > var_imp[1:].sum()
np.testing.assert_almost_equal(var_imp.sum(), 1)
assert_almost_equal(var_imp.sum(), 1)


def test_bart_random():
Expand All @@ -62,6 +62,7 @@ def test_bart_random():
rng = RandomState(12345)
pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])

assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
assert pred_all.shape == (2, 50)
assert pred_first.shape == (10,)

Expand Down

0 comments on commit e03f5bf

Please sign in to comment.