diff --git a/mambular/models/sklearn_distributional.py b/mambular/models/sklearn_distributional.py index 8a4cca6..63e9727 100644 --- a/mambular/models/sklearn_distributional.py +++ b/mambular/models/sklearn_distributional.py @@ -528,7 +528,7 @@ def fit( return self - def predict(self, X): + def predict(self, X, raw=False): """ Predicts target values for the given input samples using the fitted model. @@ -565,8 +565,12 @@ def predict(self, X): with torch.no_grad(): predictions = self.model(cat_tensors, num_tensors) + if not raw: + return self.model.family(predictions).cpu().numpy() + # Convert predictions to NumPy array and return - return predictions.cpu().numpy() + else: + return predictions.cpu().numpy() def evaluate(self, X, y_true, metrics=None, distribution_family=None): """ @@ -600,7 +604,7 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None): metrics = self.get_default_metrics(distribution_family) # Make predictions - predictions = self.predict(X) + predictions = self.predict(X, raw=False) # Initialize dictionary to store results scores = {} diff --git a/mambular/utils/distributions.py b/mambular/utils/distributions.py index 81d569a..58378e3 100644 --- a/mambular/utils/distributions.py +++ b/mambular/utils/distributions.py @@ -3,7 +3,7 @@ import numpy as np -class BaseDistribution: +class BaseDistribution(torch.nn.Module): """ The base class for various statistical distributions, providing a common interface and utilities. @@ -23,6 +23,8 @@ class BaseDistribution: """ def __init__(self, name, param_names): + super(BaseDistribution, self).__init__() + self._name = name self.param_names = param_names self.param_count = len(param_names) @@ -97,6 +99,24 @@ def evaluate_nll(self, y_true, y_pred): "NLL": nll_loss_tensor.detach().numpy(), } + def forward(self, predictions): + """ + Apply the appropriate transformations to the predicted parameters. + + Parameters: + predictions (torch.Tensor): The predicted parameters of the distribution. + + Returns: + torch.Tensor: A tensor with transformed parameters. + """ + transformed_params = [] + for idx, param_name in enumerate(self.param_names): + transform_func = self.get_transform( + getattr(self, f"{param_name}_transform", "none") + ) + transformed_params.append(transform_func(predictions[:, idx]).unsqueeze(1)) + return torch.cat(transformed_params, dim=1) + class NormalDistribution(BaseDistribution): """ @@ -119,11 +139,11 @@ def __init__(self, name="Normal", mean_transform="none", var_transform="positive super().__init__(name, param_names) self.mean_transform = self.get_transform(mean_transform) - self.var_transform = self.get_transform(var_transform) + self.variance_transform = self.get_transform(var_transform) def compute_loss(self, predictions, y_true): mean = self.mean_transform(predictions[:, self.param_names.index("mean")]) - variance = self.var_transform( + variance = self.variance_transform( predictions[:, self.param_names.index("variance")] )