diff --git a/configs/experiment/tofu_test.yaml b/configs/experiment/tofu_test.yaml index 8d56ad17..44996f36 100644 --- a/configs/experiment/tofu_test.yaml +++ b/configs/experiment/tofu_test.yaml @@ -27,7 +27,9 @@ full_data_config: tofu_full # Baskerville kwargs use_bask: true -model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget +model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/models +data_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/datasets +wandb_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/wandb bask: walltime: '0-5:0:0' gpu_number: 1 diff --git a/scripts/train.py b/scripts/train.py index 72d5e512..c0fa4829 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -26,7 +26,7 @@ logging.getLogger().setLevel(logging.INFO) -def main(experiment_path, model_cache_dir): +def main(experiment_path): # Step 0: get start time start_time = get_datetime_str() @@ -54,7 +54,6 @@ def main(experiment_path, model_cache_dir): peft_kwargs=experiment_config.model_config.peft_kwargs, **experiment_config.model_config.model_kwargs, add_padding_token=experiment_config.model_config.add_padding_token, - cache_dir=model_cache_dir, ) # Step 6: Load and prepreprocess data @@ -165,15 +164,9 @@ def main(experiment_path, model_cache_dir): help="Name of experiment yaml file contained in configs/experiment", required=True, ) - parser.add_argument( - "--model_cache_dir", - type=str, - help="Folder path to cache downloaded model in", - required=True, - ) # Step 2: process kwargs args = parser.parse_args() # Step 3: pass to and call main - main(args.experiment_name, args.model_cache_dir) + main(args.experiment_name) diff --git a/src/arcsf/config/experiment.py b/src/arcsf/config/experiment.py index b8637163..e4acee2c 100644 --- a/src/arcsf/config/experiment.py +++ b/src/arcsf/config/experiment.py @@ -104,6 +104,8 @@ def write_train_script( array_number: int, script_dir: Path, model_cache_dir: str, + data_cache_dir: str, + wandb_cache_dir: str, ): train_script = template.render( job_name=f"{top_config_name}_{job_type}", @@ -114,6 +116,8 @@ def write_train_script( script_name="scripts/train.py", experiment_file=f"{top_config_name}/{job_type}", model_cache_dir=model_cache_dir, + data_cache_dir=data_cache_dir, + wandb_cache_dir=wandb_cache_dir, ) # Create directory for train scripts if it doesn't exist save_dir = script_dir / top_config_name @@ -237,6 +241,8 @@ def generate_experiment_configs(top_config_name: str) -> None: array_number=n_jobs - 1, script_dir=script_dir, model_cache_dir=top_config["model_cache_dir"], + data_cache_dir=top_config["data_cache_dir"], + wandb_cache_dir=top_config["wandb_cache_dir"], ) diff --git a/src/arcsf/config/jobscript_template.sh b/src/arcsf/config/jobscript_template.sh index a9c4d8b4..606672ef 100644 --- a/src/arcsf/config/jobscript_template.sh +++ b/src/arcsf/config/jobscript_template.sh @@ -21,4 +21,8 @@ conda activate ${CONDA_ENV_PATH} # Run script echo "${SLURM_JOB_ID}: Job ${SLURM_ARRAY_TASK_ID} in the array" -python {{ script_name }} --experiment_name "{{ experiment_file }}_${SLURM_ARRAY_TASK_ID}" --model_cache_dir "{{ model_cache_dir }} +export HF_HOME="{{ model_cache_dir }}" +export HF_DATASETS_CACHE="{{ data_cache_dir }}" +export WANDB_CACHE_DIR="{{ wandb_cache_dir }}" +export WANDB_DATA_DIR="{{ wandb_cache_dir }}" +python {{ script_name }} --experiment_name "{{ experiment_file }}_${SLURM_ARRAY_TASK_ID}" diff --git a/src/arcsf/models/model.py b/src/arcsf/models/model.py index 7610fbb1..b80be767 100644 --- a/src/arcsf/models/model.py +++ b/src/arcsf/models/model.py @@ -42,10 +42,7 @@ def load_model_and_tokenizer( add_token_to_model = True # Load Model - model = AutoModelForCausalLM.from_pretrained( - model_id, - **model_kwargs, - ) + model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) # If padding token added, add to model too if add_token_to_model: diff --git a/tests/configs/experiment/dummy_top_config.yaml b/tests/configs/experiment/dummy_top_config.yaml index 1b3646fa..5b4730ab 100644 --- a/tests/configs/experiment/dummy_top_config.yaml +++ b/tests/configs/experiment/dummy_top_config.yaml @@ -17,7 +17,9 @@ full_data_config: example_tofu_full # Baskerville kwargs use_bask: true -model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget +model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/models +data_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/datasets +wandb_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/wandb bask: walltime: '0-5:0:0' gpu_number: 1