Skip to content

Commit

Permalink
switch to using env variables to set cache paths
Browse files Browse the repository at this point in the history
  • Loading branch information
jack89roberts committed Aug 1, 2024
1 parent 2c7d604 commit f600908
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 16 deletions.
4 changes: 3 additions & 1 deletion configs/experiment/tofu_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ full_data_config: tofu_full

# Baskerville kwargs
use_bask: true
model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget
model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/models
data_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/datasets
wandb_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/wandb
bask:
walltime: '0-5:0:0'
gpu_number: 1
Expand Down
11 changes: 2 additions & 9 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
logging.getLogger().setLevel(logging.INFO)


def main(experiment_path, model_cache_dir):
def main(experiment_path):
# Step 0: get start time
start_time = get_datetime_str()

Expand Down Expand Up @@ -54,7 +54,6 @@ def main(experiment_path, model_cache_dir):
peft_kwargs=experiment_config.model_config.peft_kwargs,
**experiment_config.model_config.model_kwargs,
add_padding_token=experiment_config.model_config.add_padding_token,
cache_dir=model_cache_dir,
)

# Step 6: Load and prepreprocess data
Expand Down Expand Up @@ -165,15 +164,9 @@ def main(experiment_path, model_cache_dir):
help="Name of experiment yaml file contained in configs/experiment",
required=True,
)
parser.add_argument(
"--model_cache_dir",
type=str,
help="Folder path to cache downloaded model in",
required=True,
)

# Step 2: process kwargs
args = parser.parse_args()

# Step 3: pass to and call main
main(args.experiment_name, args.model_cache_dir)
main(args.experiment_name)
6 changes: 6 additions & 0 deletions src/arcsf/config/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def write_train_script(
array_number: int,
script_dir: Path,
model_cache_dir: str,
data_cache_dir: str,
wandb_cache_dir: str,
):
train_script = template.render(
job_name=f"{top_config_name}_{job_type}",
Expand All @@ -114,6 +116,8 @@ def write_train_script(
script_name="scripts/train.py",
experiment_file=f"{top_config_name}/{job_type}",
model_cache_dir=model_cache_dir,
data_cache_dir=data_cache_dir,
wandb_cache_dir=wandb_cache_dir,
)
# Create directory for train scripts if it doesn't exist
save_dir = script_dir / top_config_name
Expand Down Expand Up @@ -237,6 +241,8 @@ def generate_experiment_configs(top_config_name: str) -> None:
array_number=n_jobs - 1,
script_dir=script_dir,
model_cache_dir=top_config["model_cache_dir"],
data_cache_dir=top_config["data_cache_dir"],
wandb_cache_dir=top_config["wandb_cache_dir"],
)


Expand Down
6 changes: 5 additions & 1 deletion src/arcsf/config/jobscript_template.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ conda activate ${CONDA_ENV_PATH}

# Run script
echo "${SLURM_JOB_ID}: Job ${SLURM_ARRAY_TASK_ID} in the array"
python {{ script_name }} --experiment_name "{{ experiment_file }}_${SLURM_ARRAY_TASK_ID}" --model_cache_dir "{{ model_cache_dir }}
export HF_HOME="{{ model_cache_dir }}"
export HF_DATASETS_CACHE="{{ data_cache_dir }}"
export WANDB_CACHE_DIR="{{ wandb_cache_dir }}"
export WANDB_DATA_DIR="{{ wandb_cache_dir }}"
python {{ script_name }} --experiment_name "{{ experiment_file }}_${SLURM_ARRAY_TASK_ID}"
5 changes: 1 addition & 4 deletions src/arcsf/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ def load_model_and_tokenizer(
add_token_to_model = True

# Load Model
model = AutoModelForCausalLM.from_pretrained(
model_id,
**model_kwargs,
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

# If padding token added, add to model too
if add_token_to_model:
Expand Down
4 changes: 3 additions & 1 deletion tests/configs/experiment/dummy_top_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ full_data_config: example_tofu_full

# Baskerville kwargs
use_bask: true
model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget
model_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/models
data_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/datasets
wandb_cache_dir: /bask/projects/v/vjgo8416-sltv-forget/caches/wandb
bask:
walltime: '0-5:0:0'
gpu_number: 1
Expand Down

0 comments on commit f600908

Please sign in to comment.