@@ -13,6 +13,78 @@ class LinearRegressionModel(BaseModel):
13
13
models from the interpret package.
14
14
"""
15
15
16
+ def __sklearn_tags__ (self ):
17
+ """
18
+ Get the sklearn tags for this model.
19
+
20
+ This method is required for compatibility with scikit-learn's estimator interface.
21
+ It delegates to the underlying model if it exists, otherwise returns a default set of tags.
22
+
23
+ Returns
24
+ -------
25
+ dict
26
+ Dictionary of tags describing the model.
27
+ """
28
+ if self .model is not None and hasattr (self .model , '__sklearn_tags__' ):
29
+ return self .model .__sklearn_tags__ ()
30
+ elif self .model is not None and hasattr (self .model , '_get_tags' ):
31
+ # For older scikit-learn versions
32
+ return self .model ._get_tags ()
33
+ else :
34
+ # Default tags
35
+ return {
36
+ 'allow_nan' : False ,
37
+ 'binary_only' : False ,
38
+ 'multilabel' : False ,
39
+ 'multioutput' : False ,
40
+ 'multioutput_only' : False ,
41
+ 'no_validation' : False ,
42
+ 'non_deterministic' : False ,
43
+ 'pairwise' : False ,
44
+ 'preserves_dtype' : [],
45
+ 'poor_score' : False ,
46
+ 'requires_fit' : True ,
47
+ 'requires_positive_X' : False ,
48
+ 'requires_positive_y' : False ,
49
+ 'requires_y' : True ,
50
+ 'stateless' : False ,
51
+ 'X_types' : ['2darray' ],
52
+ '_skip_test' : False ,
53
+ '_xfail_checks' : False
54
+ }
55
+
56
+ @property
57
+ def coef_ (self ):
58
+ """
59
+ Get the coefficients of the model.
60
+
61
+ Returns
62
+ -------
63
+ numpy.ndarray
64
+ Coefficients of the model.
65
+ """
66
+ if self .model is None :
67
+ raise ValueError ("Model has not been trained yet." )
68
+ if not hasattr (self .model , 'sk_model_' ):
69
+ raise AttributeError ("Model does not have sk_model_ attribute." )
70
+ return self .model .sk_model_ .coef_
71
+
72
+ @property
73
+ def intercept_ (self ):
74
+ """
75
+ Get the intercept of the model.
76
+
77
+ Returns
78
+ -------
79
+ float or numpy.ndarray
80
+ Intercept of the model.
81
+ """
82
+ if self .model is None :
83
+ raise ValueError ("Model has not been trained yet." )
84
+ if not hasattr (self .model , 'sk_model_' ):
85
+ raise AttributeError ("Model does not have sk_model_ attribute." )
86
+ return self .model .sk_model_ .intercept_
87
+
16
88
def train (self , X_train , y_train , model_type ):
17
89
"""
18
90
Train a Linear/Logistic Regression model on the given training data.
@@ -133,7 +205,7 @@ def trainModel(X_train, y_train, params, model_type):
133
205
"""
134
206
Legacy function for backward compatibility.
135
207
136
- Creates and trains a LinearRegressionModel instance .
208
+ Creates and trains a Linear/Logistic Regression model directly without using the LinearRegressionModel wrapper .
137
209
138
210
Parameters
139
211
----------
@@ -151,8 +223,24 @@ def trainModel(X_train, y_train, params, model_type):
151
223
clf : LinearRegression or LogisticRegression
152
224
Trained model.
153
225
"""
154
- model = LinearRegressionModel (params )
155
- return model .train (X_train , y_train , model_type ).model
226
+ if model_type == "regression" :
227
+ clf = LinearRegression (** params )
228
+ clf_name = "Linear regression"
229
+ elif model_type == "classification" :
230
+ clf = LogisticRegression (** params )
231
+ # Hard code classes_
232
+ clf .classes_ = list (set (y_train ))
233
+ clf_name = "Logistic regression"
234
+ else :
235
+ logger .warning ("Only regression or classification available" )
236
+ raise ValueError ("Invalid model_type. Must be 'regression' or 'classification'." )
237
+
238
+ clf .fit (X_train , y_train )
239
+
240
+ logger .info (f"Model non default params:\n { clf .kwargs } " )
241
+ logger .info (f"Trained { clf_name .lower ()} " )
242
+
243
+ return clf
156
244
157
245
158
246
def featureExplanationSave (clf , given_name , file_type ):
0 commit comments