Skip to content

Commit 959e02b

Browse files
feat: add sklearn compatibility with __sklearn_tags__ method
This commit adds proper scikit-learn compatibility by implementing `__sklearn_tags__` in BaseModel and model subclasses. It also adds direct property accessors for model coefficients and improves model handling in SQL generation utilities.
1 parent 1d9d36a commit 959e02b

File tree

6 files changed

+231
-31
lines changed

6 files changed

+231
-31
lines changed

ml2sql/utils/modelling/models/base_model.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,44 @@ def get_params(self, deep=True):
171171
return self.params
172172
if hasattr(self.model, 'get_params'):
173173
return self.model.get_params(deep)
174-
return self.params
174+
return self.params
175+
176+
def __sklearn_tags__(self):
177+
"""
178+
Get the sklearn tags for this model.
179+
180+
This method is required for compatibility with scikit-learn's estimator interface.
181+
It delegates to the underlying model if it exists, otherwise returns a default set of tags.
182+
183+
Returns
184+
-------
185+
dict
186+
Dictionary of tags describing the model.
187+
"""
188+
if self.model is not None and hasattr(self.model, '__sklearn_tags__'):
189+
return self.model.__sklearn_tags__()
190+
elif self.model is not None and hasattr(self.model, '_get_tags'):
191+
# For older scikit-learn versions
192+
return self.model._get_tags()
193+
else:
194+
# Default tags
195+
return {
196+
'allow_nan': False,
197+
'binary_only': False,
198+
'multilabel': False,
199+
'multioutput': False,
200+
'multioutput_only': False,
201+
'no_validation': False,
202+
'non_deterministic': False,
203+
'pairwise': False,
204+
'preserves_dtype': [],
205+
'poor_score': False,
206+
'requires_fit': True,
207+
'requires_positive_X': False,
208+
'requires_positive_y': False,
209+
'requires_y': True,
210+
'stateless': False,
211+
'X_types': ['2darray'],
212+
'_skip_test': False,
213+
'_xfail_checks': False
214+
}

ml2sql/utils/modelling/models/ebm.py

+59-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,46 @@ class EBMModel(BaseModel):
1616
model from the interpret package.
1717
"""
1818

19+
def __sklearn_tags__(self):
20+
"""
21+
Get the sklearn tags for this model.
22+
23+
This method is required for compatibility with scikit-learn's estimator interface.
24+
It delegates to the underlying model if it exists, otherwise returns a default set of tags.
25+
26+
Returns
27+
-------
28+
dict
29+
Dictionary of tags describing the model.
30+
"""
31+
if self.model is not None and hasattr(self.model, '__sklearn_tags__'):
32+
return self.model.__sklearn_tags__()
33+
elif self.model is not None and hasattr(self.model, '_get_tags'):
34+
# For older scikit-learn versions
35+
return self.model._get_tags()
36+
else:
37+
# Default tags
38+
return {
39+
'allow_nan': False,
40+
'binary_only': False,
41+
'multilabel': False,
42+
'multioutput': False,
43+
'multioutput_only': False,
44+
'no_validation': False,
45+
'non_deterministic': False,
46+
'pairwise': False,
47+
'preserves_dtype': [],
48+
'poor_score': False,
49+
'requires_fit': True,
50+
'requires_positive_X': False,
51+
'requires_positive_y': False,
52+
'requires_y': True,
53+
'stateless': False,
54+
'X_types': ['2darray'],
55+
'_skip_test': False,
56+
'_xfail_checks': False
57+
}
58+
1959
def train(self, X_train, y_train, model_type):
2060
"""
2161
Train an Explainable Boosting Machine (EBM) model on the given training data.
@@ -142,7 +182,7 @@ def trainModel(X_train, y_train, params, model_type):
142182
"""
143183
Legacy function for backward compatibility.
144184
145-
Creates and trains an EBMModel instance.
185+
Creates and trains an EBM model directly without using the EBMModel wrapper.
146186
147187
Parameters
148188
----------
@@ -157,11 +197,26 @@ def trainModel(X_train, y_train, params, model_type):
157197
158198
Returns
159199
-------
160-
clf : EBMModel
200+
clf : ExplainableBoostingClassifier or ExplainableBoostingRegressor
161201
Trained EBM model.
162202
"""
163-
model = EBMModel(params)
164-
return model.train(X_train, y_train, model_type).model
203+
if "feature_names" not in params.keys():
204+
params["feature_names"] = X_train.columns
205+
206+
if model_type == "regression":
207+
clf = ExplainableBoostingRegressor(**params)
208+
elif model_type == "classification":
209+
clf = ExplainableBoostingClassifier(**params)
210+
else:
211+
logger.warning("Only regression or classification available")
212+
raise ValueError("Invalid model_type. Must be 'regression' or 'classification'.")
213+
214+
clf.fit(X_train, y_train)
215+
216+
logger.info(f"Model params:\n {clf.get_params}")
217+
logger.info("Trained explainable boosting machine")
218+
219+
return clf
165220

166221

167222
def featureExplanationSave(clf, given_name, file_type):

ml2sql/utils/modelling/models/l_regression.py

+91-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,78 @@ class LinearRegressionModel(BaseModel):
1313
models from the interpret package.
1414
"""
1515

16+
def __sklearn_tags__(self):
17+
"""
18+
Get the sklearn tags for this model.
19+
20+
This method is required for compatibility with scikit-learn's estimator interface.
21+
It delegates to the underlying model if it exists, otherwise returns a default set of tags.
22+
23+
Returns
24+
-------
25+
dict
26+
Dictionary of tags describing the model.
27+
"""
28+
if self.model is not None and hasattr(self.model, '__sklearn_tags__'):
29+
return self.model.__sklearn_tags__()
30+
elif self.model is not None and hasattr(self.model, '_get_tags'):
31+
# For older scikit-learn versions
32+
return self.model._get_tags()
33+
else:
34+
# Default tags
35+
return {
36+
'allow_nan': False,
37+
'binary_only': False,
38+
'multilabel': False,
39+
'multioutput': False,
40+
'multioutput_only': False,
41+
'no_validation': False,
42+
'non_deterministic': False,
43+
'pairwise': False,
44+
'preserves_dtype': [],
45+
'poor_score': False,
46+
'requires_fit': True,
47+
'requires_positive_X': False,
48+
'requires_positive_y': False,
49+
'requires_y': True,
50+
'stateless': False,
51+
'X_types': ['2darray'],
52+
'_skip_test': False,
53+
'_xfail_checks': False
54+
}
55+
56+
@property
57+
def coef_(self):
58+
"""
59+
Get the coefficients of the model.
60+
61+
Returns
62+
-------
63+
numpy.ndarray
64+
Coefficients of the model.
65+
"""
66+
if self.model is None:
67+
raise ValueError("Model has not been trained yet.")
68+
if not hasattr(self.model, 'sk_model_'):
69+
raise AttributeError("Model does not have sk_model_ attribute.")
70+
return self.model.sk_model_.coef_
71+
72+
@property
73+
def intercept_(self):
74+
"""
75+
Get the intercept of the model.
76+
77+
Returns
78+
-------
79+
float or numpy.ndarray
80+
Intercept of the model.
81+
"""
82+
if self.model is None:
83+
raise ValueError("Model has not been trained yet.")
84+
if not hasattr(self.model, 'sk_model_'):
85+
raise AttributeError("Model does not have sk_model_ attribute.")
86+
return self.model.sk_model_.intercept_
87+
1688
def train(self, X_train, y_train, model_type):
1789
"""
1890
Train a Linear/Logistic Regression model on the given training data.
@@ -133,7 +205,7 @@ def trainModel(X_train, y_train, params, model_type):
133205
"""
134206
Legacy function for backward compatibility.
135207
136-
Creates and trains a LinearRegressionModel instance.
208+
Creates and trains a Linear/Logistic Regression model directly without using the LinearRegressionModel wrapper.
137209
138210
Parameters
139211
----------
@@ -151,8 +223,24 @@ def trainModel(X_train, y_train, params, model_type):
151223
clf : LinearRegression or LogisticRegression
152224
Trained model.
153225
"""
154-
model = LinearRegressionModel(params)
155-
return model.train(X_train, y_train, model_type).model
226+
if model_type == "regression":
227+
clf = LinearRegression(**params)
228+
clf_name = "Linear regression"
229+
elif model_type == "classification":
230+
clf = LogisticRegression(**params)
231+
# Hard code classes_
232+
clf.classes_ = list(set(y_train))
233+
clf_name = "Logistic regression"
234+
else:
235+
logger.warning("Only regression or classification available")
236+
raise ValueError("Invalid model_type. Must be 'regression' or 'classification'.")
237+
238+
clf.fit(X_train, y_train)
239+
240+
logger.info(f"Model non default params:\n {clf.kwargs}")
241+
logger.info(f"Trained {clf_name.lower()}")
242+
243+
return clf
156244

157245

158246
def featureExplanationSave(clf, given_name, file_type):

ml2sql/utils/output_scripts/decision_tree_as_code.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,20 @@ def tree_to_sql(tree, file=sys.stdout):
1313
1414
Parameters:
1515
-----------
16-
tree: sklearn decision tree model
16+
tree: sklearn decision tree model or DecisionTreeModel
1717
The decision tree to represent as an SQL function
1818
file: file object, optional (default=sys.stdout)
1919
The file to write the output to. If not specified, prints to console.
2020
"""
21+
# Check if this is our custom model wrapper
22+
if hasattr(tree, 'model'):
23+
actual_tree = tree.model
24+
else:
25+
actual_tree = tree
2126

22-
tree_ = tree.tree_
27+
tree_ = actual_tree.tree_
2328
feature_name = [
24-
tree.feature_names_in_[i] if i != _tree.TREE_UNDEFINED else "undefined!"
29+
actual_tree.feature_names_in_[i] if i != _tree.TREE_UNDEFINED else "undefined!"
2530
for i in tree_.feature
2631
]
2732

@@ -43,11 +48,11 @@ def recurse(node, depth):
4348
recurse(tree_.children_right[node], depth + 1)
4449
print(f"{indent}END", file=file)
4550
else:
46-
if hasattr(tree, "classes_"):
51+
if hasattr(actual_tree, "classes_"):
4752
class_values = tree_.value[node]
4853
samples = tree_.n_node_samples[node]
4954
max_value = int(np.max(class_values))
50-
predicted_class = tree.classes_[np.argmax(class_values)]
55+
predicted_class = actual_tree.classes_[np.argmax(class_values)]
5156

5257
if np.issubdtype(type(predicted_class), np.integer):
5358
print(
@@ -80,4 +85,4 @@ def save_model_and_extras(clf, model_name, post_params):
8085
logger.info("SQL version of decision tree saved")
8186

8287
# If you want to also print to console, you can call the function again without the file parameter
83-
# tree_to_sql(clf)
88+
# tree_to_sql(clf)

ml2sql/utils/output_scripts/ebm_as_code.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -840,11 +840,17 @@ def ebm_to_sql(model_name, df, classes, split=True):
840840

841841

842842
def save_model_and_extras(ebm, model_name, post_params):
843+
# Check if this is our custom model wrapper
844+
if hasattr(ebm, 'model'):
845+
actual_ebm = ebm.model
846+
else:
847+
actual_ebm = ebm
848+
843849
# extract lookup table from EBM
844-
lookup_df = extractLookupTable(ebm, post_params)
850+
lookup_df = extractLookupTable(actual_ebm, post_params)
845851
# In case of regression
846-
if not hasattr(ebm, "classes_"):
847-
ebm.classes_ = [0]
852+
if not hasattr(actual_ebm, "classes_"):
853+
actual_ebm.classes_ = [0]
848854
lookup_df["intercept"] = [lookup_df["intercept"]]
849855

850856
# Write printed output to file
@@ -854,7 +860,7 @@ def save_model_and_extras(ebm, model_name, post_params):
854860
with open(output_path, "w") as f:
855861
with redirect_stdout(f):
856862
model_name = Path(model_name).name
857-
ebm_to_sql(model_name, lookup_df, ebm.classes_, post_params["sql_split"])
863+
ebm_to_sql(model_name, lookup_df, actual_ebm.classes_, post_params["sql_split"])
858864
logger.info("SQL version of EBM saved")
859865

860866

@@ -867,4 +873,4 @@ def save_model_and_extras(ebm, model_name, post_params):
867873
"file_type": "png",
868874
"sql_decimals": 15,
869875
}
870-
save_model_and_extras(ebm, model_name, post_params)
876+
save_model_and_extras(ebm, model_name, post_params)

ml2sql/utils/output_scripts/l_regression_as_code.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def extract_parameters(model):
1111
Extracts model_type, features, coefficients, and intercept from a trained logistic regression model.
1212
1313
Parameters:
14-
- trained_model: The trained logistic regression model object.
14+
- model: The trained model object, either a direct scikit-learn model or our custom wrapper.
1515
1616
Returns:
1717
- model_type: String, either regression of classification
@@ -20,30 +20,36 @@ def extract_parameters(model):
2020
- intercept: Intercept of the logistic regression model.
2121
"""
2222
try:
23+
# Check if this is our custom model wrapper
24+
if hasattr(model, 'model'):
25+
actual_model = model.model
26+
else:
27+
actual_model = model
28+
2329
# Extract model type
24-
if model.__class__.__name__ == "LinearRegression":
30+
if actual_model.__class__.__name__ == "LinearRegression":
2531
model_type = "regression"
2632
pclasses = None
27-
elif len(model.classes_) > 2:
33+
elif len(actual_model.classes_) > 2:
2834
model_type = "multiclass"
29-
pclasses = model.classes_
30-
elif len(model.classes_) == 2:
35+
pclasses = actual_model.classes_
36+
elif len(actual_model.classes_) == 2:
3137
model_type = "binary"
32-
pclasses = model.classes_
38+
pclasses = actual_model.classes_
3339

3440
# Extract features
35-
features = model.feature_names_in_
41+
features = actual_model.feature_names_in_
3642

3743
if model_type == "binary":
38-
coefficients = model.sk_model_.coef_[0]
44+
coefficients = actual_model.coef_[0]
3945
else:
40-
coefficients = model.sk_model_.coef_
46+
coefficients = actual_model.coef_
4147

4248
# Extract intercept
4349
if model_type == "binary":
44-
intercept = model.sk_model_.intercept_[0]
50+
intercept = actual_model.intercept_[0]
4551
else:
46-
intercept = model.sk_model_.intercept_
52+
intercept = actual_model.intercept_
4753

4854
return model_type, pclasses, features, coefficients, intercept
4955

@@ -191,4 +197,4 @@ def save_model_and_extras(clf, model_name, post_params):
191197
intercept,
192198
post_params,
193199
)
194-
logger.info("SQL version of logistic/linear regression saved")
200+
logger.info("SQL version of logistic/linear regression saved")

0 commit comments

Comments
 (0)