Skip to content

Commit

Permalink
fix distributional parameter activation in predict
Browse files Browse the repository at this point in the history
  • Loading branch information
thielmaf committed May 29, 2024
1 parent 4d9c8a8 commit 6f55bd3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
10 changes: 7 additions & 3 deletions mambular/models/sklearn_distributional.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def fit(

return self

def predict(self, X):
def predict(self, X, raw=False):
"""
Predicts target values for the given input samples using the fitted model.
Expand Down Expand Up @@ -565,8 +565,12 @@ def predict(self, X):
with torch.no_grad():
predictions = self.model(cat_tensors, num_tensors)

if not raw:
return self.model.family(predictions).cpu().numpy()

# Convert predictions to NumPy array and return
return predictions.cpu().numpy()
else:
return predictions.cpu().numpy()

def evaluate(self, X, y_true, metrics=None, distribution_family=None):
"""
Expand Down Expand Up @@ -600,7 +604,7 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None):
metrics = self.get_default_metrics(distribution_family)

# Make predictions
predictions = self.predict(X)
predictions = self.predict(X, raw=False)

# Initialize dictionary to store results
scores = {}
Expand Down
26 changes: 23 additions & 3 deletions mambular/utils/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


class BaseDistribution:
class BaseDistribution(torch.nn.Module):
"""
The base class for various statistical distributions, providing a common interface and utilities.
Expand All @@ -23,6 +23,8 @@ class BaseDistribution:
"""

def __init__(self, name, param_names):
super(BaseDistribution, self).__init__()

self._name = name
self.param_names = param_names
self.param_count = len(param_names)
Expand Down Expand Up @@ -97,6 +99,24 @@ def evaluate_nll(self, y_true, y_pred):
"NLL": nll_loss_tensor.detach().numpy(),
}

def forward(self, predictions):
"""
Apply the appropriate transformations to the predicted parameters.
Parameters:
predictions (torch.Tensor): The predicted parameters of the distribution.
Returns:
torch.Tensor: A tensor with transformed parameters.
"""
transformed_params = []
for idx, param_name in enumerate(self.param_names):
transform_func = self.get_transform(
getattr(self, f"{param_name}_transform", "none")
)
transformed_params.append(transform_func(predictions[:, idx]).unsqueeze(1))
return torch.cat(transformed_params, dim=1)


class NormalDistribution(BaseDistribution):
"""
Expand All @@ -119,11 +139,11 @@ def __init__(self, name="Normal", mean_transform="none", var_transform="positive
super().__init__(name, param_names)

self.mean_transform = self.get_transform(mean_transform)
self.var_transform = self.get_transform(var_transform)
self.variance_transform = self.get_transform(var_transform)

def compute_loss(self, predictions, y_true):
mean = self.mean_transform(predictions[:, self.param_names.index("mean")])
variance = self.var_transform(
variance = self.variance_transform(
predictions[:, self.param_names.index("variance")]
)

Expand Down

0 comments on commit 6f55bd3

Please sign in to comment.