@@ -97,9 +97,10 @@ class Transformation:
97
97
98
98
Parameters
99
99
----------
100
- priors : dict, optional
100
+ priors : dict[str, Prior] , optional
101
101
Dictionary with the priors for the parameters of the function. The keys should be the
102
- parameter names and the values should be dictionaries with the distribution and kwargs.
102
+ parameter names and the values the priors. If not provided, it will use the default
103
+ priors from the subclass.
103
104
prefix : str, optional
104
105
The prefix for the variables that will be created. If not provided, it will use the prefix
105
106
from the subclass.
@@ -112,12 +113,43 @@ class Transformation:
112
113
lookup_name : str
113
114
114
115
def __init__ (
115
- self , priors : dict [str , Any | Prior ] | None = None , prefix : str | None = None
116
+ self , priors : dict [str , Prior ] | None = None , prefix : str | None = None
116
117
) -> None :
117
118
self ._checks ()
118
119
self .function_priors = priors # type: ignore
119
120
self .prefix = prefix or self .prefix
120
121
122
+ def __repr__ (self ) -> str :
123
+ return (
124
+ f"{ self .__class__ .__name__ } ("
125
+ f"prefix={ self .prefix !r} , "
126
+ f"priors={ self .function_priors } "
127
+ ")"
128
+ )
129
+
130
+ def to_dict (self ) -> dict [str , Any ]:
131
+ """Convert the transformation to a dictionary.
132
+
133
+ Returns
134
+ -------
135
+ dict
136
+ The dictionary defining the transformation.
137
+
138
+ """
139
+ return {
140
+ "lookup_name" : self .lookup_name ,
141
+ "prefix" : self .prefix ,
142
+ "priors" : {
143
+ key : value .to_json () for key , value in self .function_priors .items ()
144
+ },
145
+ }
146
+
147
+ def __eq__ (self , other : Any ) -> bool :
148
+ if not isinstance (other , self .__class__ ):
149
+ return False
150
+
151
+ return self .to_dict () == other .to_dict ()
152
+
121
153
@property
122
154
def function_priors (self ) -> dict [str , Prior ]:
123
155
return self ._function_priors
@@ -137,7 +169,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
137
169
138
170
Parameters
139
171
----------
140
- priors : dict
172
+ priors : dict[str, Prior]
141
173
Dictionary with the new priors for the parameters of the function.
142
174
143
175
Examples
@@ -150,6 +182,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
150
182
from pymc_marketing.prior import Prior
151
183
152
184
class MyTransformation(Transformation):
185
+ lookup_name: str = "my_transformation"
153
186
prefix: str = "transformation"
154
187
function = lambda x, lam: x * lam
155
188
default_priors = {"lam": Prior("Gamma", alpha=3, beta=1)}
@@ -200,6 +233,9 @@ def _has_all_attributes(self) -> None:
200
233
if not hasattr (self , "function" ):
201
234
raise NotImplementedError ("function must be implemented in the subclass" )
202
235
236
+ if not hasattr (self , "lookup_name" ):
237
+ raise NotImplementedError ("lookup_name must be implemented in the subclass" )
238
+
203
239
def _has_defaults_for_all_arguments (self ) -> None :
204
240
function_signature = signature (self .function )
205
241
0 commit comments