Skip to content

Commit

Permalink
fix compatibility issues with nyuv2 experiments (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke authored Dec 3, 2024
1 parent a9b5173 commit 299a481
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 30 deletions.
26 changes: 0 additions & 26 deletions config/llama_weighted_average.yaml

This file was deleted.

6 changes: 5 additions & 1 deletion config/nyuv2_config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
defaults:
- hydra: default
- fabric: auto
- modelpool: nyuv2_modelpool
- method: simple_average
- taskpool: nyuv2_taskpool
- _self_

_target_: fusion_bench.programs.FabricModelFusionProgram
_recursive_: false

fast_dev_run: false # Run a single batch of data to test the model or method
use_lightning: true # Use the fabric to run the experiment
print_config: true # Print the configuration to the console
save_report: false # path to save the result report
fabric: null
trainer:
devices: 1
2 changes: 1 addition & 1 deletion fusion_bench/compat/modelpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ModelPoolFactory:
"""

_modelpool = {
"NYUv2ModelPool": ".nyuv2_modelpool.NYUv2ModelPool",
"NYUv2ModelPool": "fusion_bench.modelpool.nyuv2_modelpool.NYUv2ModelPool",
"huggingface_clip_vision": HuggingFaceClipVisionPool,
"HF_GPT2ForSequenceClassification": GPT2ForSequenceClassificationPool,
"AutoModelPool": ".huggingface_automodel.AutoModelPool",
Expand Down
2 changes: 1 addition & 1 deletion fusion_bench/compat/taskpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TaskPoolFactory:
"dummy": DummyTaskPool,
"clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
"FlanT5GLUETextGenerationTaskPool": ".flan_t5_glue_text_generation.FlanT5GLUETextGenerationTaskPool",
"NYUv2TaskPool": ".nyuv2_taskpool.NYUv2TaskPool",
"NYUv2TaskPool": "fusion_bench.taskpool.nyuv2_taskpool.NYUv2TaskPool",
}

@staticmethod
Expand Down
6 changes: 5 additions & 1 deletion fusion_bench/programs/fabric_fusion_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ def run(self):
self.save_merged_model(merged_model)
if self.taskpool is not None:
report = self.evaluate_merged_model(self.taskpool, merged_model)
print_json(report, print_type=False)
try:
print_json(report, print_type=False)
except Exception as e:
log.warning(f"Failed to pretty print the report: {e}")
print(report)
if self.report_save_path is not None:
# save report (Dict) to a file
# if the directory of `save_report` does not exists, create it
Expand Down
2 changes: 2 additions & 0 deletions fusion_bench/taskpool/nyuv2_taskpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,6 @@ def evaluate(self, encoder: ResnetDilated):
num_workers=self.config.num_workers,
)
report = self.trainer.validate(model, val_loader)
if isinstance(report, list) and len(report) == 1:
report = report[0]
return report

0 comments on commit 299a481

Please sign in to comment.