Skip to content

Commit 4c8f998

Browse files
authored
fix: improve_execution_time_in_kaggle_loop (#279)
* improve_execution_time_in_kaggle_loop * fix CI * fix CI * fix CI
1 parent 26352e1 commit 4c8f998

File tree

18 files changed

+143
-82
lines changed

18 files changed

+143
-82
lines changed

rdagent/app/kaggle/loop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from rdagent.components.workflow.conf import BasePropSetting
88
from rdagent.components.workflow.rd_loop import RDLoop
99
from rdagent.core.developer import Developer
10-
from rdagent.core.exception import ModelEmptyError
10+
from rdagent.core.exception import FactorEmptyError, ModelEmptyError
1111
from rdagent.core.proposal import (
1212
Hypothesis2Experiment,
1313
HypothesisExperiment2Feedback,
@@ -71,7 +71,7 @@ def running(self, prev_out: dict[str, Any]):
7171
logger.log_object(exp, tag="runner result")
7272
return exp
7373

74-
skip_loop_error = (ModelEmptyError,)
74+
skip_loop_error = (ModelEmptyError, FactorEmptyError)
7575

7676

7777
def main(path=None, step_n=None, competition=None):

rdagent/app/qlib_rd_loop/factor_from_report.py

-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
extract_first_page_screenshot_from_pdf,
1212
load_and_process_pdfs_by_langchain,
1313
)
14-
from rdagent.components.workflow.rd_loop import RDLoop
15-
from rdagent.core.exception import FactorEmptyError
1614
from rdagent.core.prompts import Prompts
1715
from rdagent.core.proposal import Hypothesis
1816
from rdagent.log import rdagent_logger as logger

rdagent/components/coder/factor_coder/factor.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,6 @@ def __init__(
8787
self.executed_factor_value_dataframe = executed_factor_value_dataframe
8888
self.raise_exception = raise_exception
8989

90-
@staticmethod
91-
def link_data_to_workspace(data_path: Path, workspace_path: Path):
92-
data_path = Path(data_path).absolute() # in case of relative path that will be invalid when we change cwd.
93-
workspace_path = Path(workspace_path)
94-
for data_file_path in data_path.iterdir():
95-
workspace_data_file_path = workspace_path / data_file_path.name
96-
if workspace_data_file_path.exists():
97-
workspace_data_file_path.unlink()
98-
subprocess.run(
99-
["ln", "-s", data_file_path, workspace_data_file_path],
100-
check=False,
101-
)
102-
10390
def execute(self, store_result: bool = False, data_type: str = "Debug") -> Tuple[str, pd.DataFrame]:
10491
"""
10592
execute the implementation and get the factor value by the following steps:
@@ -154,7 +141,7 @@ def execute(self, store_result: bool = False, data_type: str = "Debug") -> Tuple
154141
source_data_path.mkdir(exist_ok=True, parents=True)
155142
code_path = self.workspace_path / f"factor.py"
156143

157-
self.link_data_to_workspace(source_data_path, self.workspace_path)
144+
self.link_all_files_in_folder_to_workspace(source_data_path, self.workspace_path)
158145

159146
execution_feedback = self.FB_EXECUTION_SUCCEEDED
160147
execution_success = False

rdagent/components/coder/factor_coder/factor_execution_template.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import numpy as np
44
import pandas as pd
55
from factor import feature_engineering_cls
66

7-
if os.path.exists("valid.pkl"):
8-
valid_df = pd.read_pickle("valid.pkl")
7+
if os.path.exists("X_valid.pkl"):
8+
valid_df = pd.read_pickle("X_valid.pkl").head(1000)
99
else:
1010
raise FileNotFoundError("No valid data found.")
1111

rdagent/core/experiment.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import os
34
import shutil
45
import uuid
56
from abc import ABC, abstractmethod
7+
from collections.abc import Sequence
68
from copy import deepcopy
79
from pathlib import Path
8-
from typing import Any, Generic, Sequence, TypeVar
10+
from typing import Any, Generic, TypeVar
911

1012
from rdagent.core.conf import RD_AGENT_SETTINGS
1113

@@ -111,6 +113,16 @@ def prepare(self) -> None:
111113
"""
112114
self.workspace_path.mkdir(parents=True, exist_ok=True)
113115

116+
@staticmethod
117+
def link_all_files_in_folder_to_workspace(data_path: Path, workspace_path: Path) -> None:
118+
data_path = Path(data_path).absolute() # in case of relative path that will be invalid when we change cwd.
119+
workspace_path = Path(workspace_path)
120+
for data_file_path in data_path.iterdir():
121+
workspace_data_file_path = workspace_path / data_file_path.name
122+
if workspace_data_file_path.exists():
123+
workspace_data_file_path.unlink()
124+
os.symlink(data_file_path, workspace_data_file_path)
125+
114126
def inject_code(self, **files: str) -> None:
115127
"""
116128
Inject the code into the folder.

rdagent/core/prompts.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from pathlib import Path # noqa: I001
2-
from typing import Dict
32

43
import yaml
54

65
from rdagent.core.utils import SingletonBaseClass
76

87

9-
class Prompts(SingletonBaseClass, Dict[str, str]):
8+
class Prompts(SingletonBaseClass, dict[str, str]):
109
def __init__(self, file_path: Path) -> None:
1110
super().__init__()
1211
with file_path.open(encoding="utf8") as file:

rdagent/core/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
3030
raise RDAgentException(exception_message)
3131
class_name = [(-1, f"{cls.__module__}.{cls.__name__}")]
3232
args_l = [(i, args[i]) for i in args]
33-
kwargs_l = list(sorted(kwargs.items()))
33+
kwargs_l = sorted(kwargs.items())
3434
all_args = class_name + args_l + kwargs_l
3535
kwargs_hash = hash(tuple(all_args))
3636
if kwargs_hash not in cls._instance_dict:

rdagent/scenarios/data_mining/proposal/model_proposal.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,18 @@ def __init__(self, scen: Scenario) -> Tuple[dict, bool]:
3535
super().__init__(scen)
3636

3737
def prepare_context(self, trace: Trace) -> Tuple[dict, bool]:
38-
hypothesis_feedback = (
39-
Environment(undefined=StrictUndefined)
40-
.from_string(prompt_dict["hypothesis_and_feedback"])
41-
.render(trace=trace)
38+
hypothesis_and_feedback = (
39+
(
40+
Environment(undefined=StrictUndefined)
41+
.from_string(prompt_dict["hypothesis_and_feedback"])
42+
.render(trace=trace)
43+
)
44+
if len(trace.hist) > 0
45+
else "No previous hypothesis and feedback available since it's the first round."
4246
)
4347
context_dict = {
44-
"hypothesis_and_feedback": hypothesis_feedback,
45-
"RAG": "",
48+
"hypothesis_and_feedback": hypothesis_and_feedback,
49+
"RAG": None,
4650
"hypothesis_output_format": prompt_dict["hypothesis_output_format"],
4751
"hypothesis_specification": prompt_dict["model_hypothesis_specification"],
4852
}
@@ -67,9 +71,13 @@ def prepare_context(self, hypothesis: Hypothesis, trace: Trace) -> Tuple[dict, b
6771
experiment_output_format = prompt_dict["model_experiment_output_format"]
6872

6973
hypothesis_and_feedback = (
70-
Environment(undefined=StrictUndefined)
71-
.from_string(prompt_dict["hypothesis_and_feedback"])
72-
.render(trace=trace)
74+
(
75+
Environment(undefined=StrictUndefined)
76+
.from_string(prompt_dict["hypothesis_and_feedback"])
77+
.render(trace=trace)
78+
)
79+
if len(trace.hist) > 0
80+
else "No previous hypothesis and feedback available since it's the first round."
7381
)
7482

7583
experiment_list: List[ModelExperiment] = [t[1] for t in trace.hist]
@@ -84,7 +92,7 @@ def prepare_context(self, hypothesis: Hypothesis, trace: Trace) -> Tuple[dict, b
8492
"hypothesis_and_feedback": hypothesis_and_feedback,
8593
"experiment_output_format": experiment_output_format,
8694
"target_list": model_list,
87-
"RAG": ...,
95+
"RAG": None,
8896
}, True
8997

9098
def convert_response(self, response: str, trace: Trace) -> ModelExperiment:

rdagent/scenarios/kaggle/experiment/meta_tpl/cross_validation_tpl.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
import numpy as np
44
import pandas as pd
5-
import xgboost as xgb
6-
from sklearn.metrics import accuracy_score, matthews_corrcoef
75
from sklearn.model_selection import KFold
8-
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
6+
from sklearn.preprocessing import LabelEncoder
97

108
from rdagent.scenarios.kaggle.experiment.meta_tpl.fea_share_preprocess import preprocess
119

rdagent/scenarios/kaggle/experiment/meta_tpl/fea_share_preprocess.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pandas as pd
24
from sklearn.compose import ColumnTransformer
35
from sklearn.impute import SimpleImputer
@@ -82,6 +84,15 @@ def preprocess_script():
8284
"""
8385
This method applies the preprocessing steps to the training, validation, and test datasets.
8486
"""
87+
if os.path.exists("X_train.pkl"):
88+
X_train = pd.read_pickle("X_train.pkl")
89+
X_valid = pd.read_pickle("X_valid.pkl")
90+
y_train = pd.read_pickle("y_train.pkl")
91+
y_valid = pd.read_pickle("y_valid.pkl")
92+
X_test = pd.read_pickle("X_test.pkl")
93+
passenger_ids = pd.read_pickle("passenger_ids.pkl")
94+
95+
return X_train, X_valid, y_train, y_valid, X_test, passenger_ids
8596
X_train, X_valid, y_train, y_valid = prepreprocess()
8697

8798
# Fit the preprocessor on the training data

rdagent/scenarios/kaggle/experiment/meta_tpl/model/model_rf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def fit(X_train: pd.DataFrame, y_train: pd.Series, X_valid: pd.DataFrame, y_vali
2323
Define and train the Random Forest model. Merge feature selection into the pipeline.
2424
"""
2525
# Initialize the Random Forest model
26-
model = RandomForestClassifier(n_estimators=100, random_state=32)
26+
model = RandomForestClassifier(n_estimators=100, random_state=32, n_jobs=-1)
2727

2828
# Select features (if any feature selection is needed)
2929
X_train_selected = select(X_train)

rdagent/scenarios/kaggle/experiment/meta_tpl/model/model_xgb.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,23 @@
66
import xgboost as xgb
77

88

9-
def select(X):
10-
"""
11-
Select relevant features. To be used in fit & predict function
12-
"""
9+
def select(X: pd.DataFrame) -> pd.DataFrame:
10+
# Ignore feature selection logic
1311
return X
1412

1513

1614
def fit(X_train: pd.DataFrame, y_train: pd.DataFrame, X_valid: pd.DataFrame, y_valid: pd.DataFrame):
1715
"""Define and train the model. Merge feature_select"""
16+
X_train = select(X_train)
17+
X_valid = select(X_valid)
1818
dtrain = xgb.DMatrix(X_train, label=y_train)
1919
dvalid = xgb.DMatrix(X_valid, label=y_valid)
2020

2121
# TODO: for quick running....
22-
params = {}
23-
num_round = 50
22+
params = {
23+
"nthred": -1,
24+
}
25+
num_round = 200
2426

2527
evallist = [(dtrain, "train"), (dvalid, "eval")]
2628
bst = xgb.train(params, dtrain, num_round, evallist)
@@ -32,6 +34,7 @@ def predict(model, X):
3234
"""
3335
Keep feature select's consistency.
3436
"""
37+
X = select(X)
3538
dtest = xgb.DMatrix(X)
3639
y_pred_prob = model.predict(dtest)
3740
return y_pred_prob

rdagent/scenarios/kaggle/experiment/prompts.yaml

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ kg_description_template:
1010
"Target Description": "A description of the target variable to be predicted",
1111
"Competition Features": "A dict of relevant features used in the competition and their descriptions (if available)", # if you are not sure about the meaning of the feature, please add a (guess) before the description. Importantly, your feature name should be exactly the same as the feature name in the dataset!
1212
}
13+
Since there might be very similar column names in the data, such as one_hot_encoded columns, you can use a regex to group them together.
1314
1415
1516
user: |-
@@ -144,7 +145,7 @@ kg_model_interface: |-
144145
from xgboost import DMatrix
145146
146147
147-
def select(self, X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
148+
def select(X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
148149
149150
150151
def fit(
@@ -178,7 +179,7 @@ kg_model_interface: |-
178179
from sklearn.metrics import accuracy_score
179180
180181
181-
def select(self, X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
182+
def select(X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
182183
183184
184185
def fit(
@@ -207,7 +208,7 @@ kg_model_interface: |-
207208
from lightgbm import LGBMClassifier, LGBMRegressor
208209
209210
210-
def select(self, X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
211+
def select(X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
211212
212213
213214
def fit(
@@ -247,7 +248,7 @@ kg_model_interface: |-
247248
return x
248249
249250
250-
def select(self, X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
251+
def select(X: pd.DataFrame) -> pd.DataFrame: ... # Implement feature selection logic
251252
252253
253254
def fit(X_train: pd.DataFrame, y_train: pd.DataFrame, X_valid: pd.DataFrame, y_valid: pd.DataFrame) -> torch.nn.Module:

rdagent/scenarios/kaggle/experiment/scenario.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import io
12
import json
3+
import pickle
24
from pathlib import Path
35

46
import pandas as pd
@@ -93,9 +95,12 @@ def background(self) -> str:
9395
def source_data(self) -> str:
9496
data_folder = Path(FACTOR_IMPLEMENT_SETTINGS.data_folder) / self.competition
9597

96-
if (data_folder / "valid.pkl").exists():
97-
X_valid = pd.read_pickle(data_folder / "valid.pkl")
98-
return X_valid.head()
98+
if (data_folder / "X_valid.pkl").exists():
99+
X_valid = pd.read_pickle(data_folder / "X_valid.pkl")
100+
buffer = io.StringIO()
101+
X_valid.info(verbose=True, buf=buffer, show_counts=True)
102+
data_info = buffer.getvalue()
103+
return data_info
99104

100105
preprocess_experiment = KGFactorExperiment([])
101106
(
@@ -108,8 +113,17 @@ def source_data(self) -> str:
108113
) = preprocess_experiment.experiment_workspace.generate_preprocess_data()
109114

110115
data_folder.mkdir(exist_ok=True, parents=True)
111-
X_valid.to_pickle(data_folder / "valid.pkl")
112-
return X_valid.head()
116+
pickle.dump(X_train, open(data_folder / "X_train.pkl", "wb"))
117+
pickle.dump(X_valid, open(data_folder / "X_valid.pkl", "wb"))
118+
pickle.dump(y_train, open(data_folder / "y_train.pkl", "wb"))
119+
pickle.dump(y_valid, open(data_folder / "y_valid.pkl", "wb"))
120+
pickle.dump(X_test, open(data_folder / "X_test.pkl", "wb"))
121+
pickle.dump(passenger_ids, open(data_folder / "passenger_ids.pkl", "wb"))
122+
123+
buffer = io.StringIO()
124+
X_valid.info(verbose=True, buf=buffer, show_counts=True)
125+
data_info = buffer.getvalue()
126+
return data_info
113127

114128
@property
115129
def output_format(self) -> str:

rdagent/scenarios/kaggle/experiment/workspace.py

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66

77
from rdagent.app.kaggle.conf import KAGGLE_IMPLEMENT_SETTING
8+
from rdagent.components.coder.factor_coder.config import FACTOR_IMPLEMENT_SETTINGS
89
from rdagent.core.experiment import FBWorkspace
910
from rdagent.log import rdagent_logger as logger
1011
from rdagent.utils.env import KGDockerEnv
@@ -58,6 +59,11 @@ def generate_preprocess_data(
5859

5960
def execute(self, run_env: dict = {}, *args, **kwargs) -> str:
6061
logger.info(f"Running the experiment in {self.workspace_path}")
62+
63+
# link the data to the workspace to speed up the preprocessing
64+
source_data_path = Path(FACTOR_IMPLEMENT_SETTINGS.data_folder) / KAGGLE_IMPLEMENT_SETTING.competition
65+
self.link_all_files_in_folder_to_workspace(source_data_path, self.workspace_path)
66+
6167
kgde = KGDockerEnv(KAGGLE_IMPLEMENT_SETTING.competition)
6268
kgde.prepare()
6369

0 commit comments

Comments
 (0)