-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigs.py
154 lines (129 loc) · 4.9 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
import dataclasses
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Iterable, Tuple, Union, cast, List
from omegaconf import OmegaConf
DataClass = Any
DataClassType = Any
import dataclasses
from typing import Any, Dict, Type, TypeVar, Union
T = TypeVar("T")
class DataclassArgParser:
"""Utility class to populate dataclasses from dictionaries."""
@staticmethod
def _populate_dataclass_from_dict(cls: Type[T], inputs: Dict[str, Any]) -> T:
"""
Populates a dataclass instance from a dictionary.
Handles nested dataclasses by recursively populating them.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"{cls} is not a dataclass")
# Extract field names and types
field_names = {field.name: field.type for field in dataclasses.fields(cls)}
# Create an instance of the dataclass
obj = cls()
for key, value in inputs.items():
if key in field_names:
field_type = field_names[key]
if dataclasses.is_dataclass(field_type):
# Recursively populate nested dataclasses
nested_obj = DataclassArgParser._populate_dataclass_from_dict(
field_type, value
)
setattr(obj, key, nested_obj)
else:
# Assign value directly for non-dataclass fields
setattr(obj, key, value)
return obj
@staticmethod
def _populate_dataclass_from_flat_dict(cls: Type[T], inputs: Dict[str, Any]) -> T:
"""
Populates a dataclass instance from a flat dictionary.
Nested fields are expected to use dot notation (e.g., 'field.subfield').
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"{cls} is not a dataclass")
# Extract field names and types
field_names = {field.name: field.type for field in dataclasses.fields(cls)}
# Create an instance of the dataclass
obj = cls()
for key, value in inputs.items():
parts = key.split(".")
current_obj = obj
for i, part in enumerate(parts):
if i == len(parts) - 1: # Last part is the actual value
if part in field_names:
setattr(current_obj, part, value)
else:
# Navigate or create nested objects
if hasattr(current_obj, part):
next_obj = getattr(current_obj, part)
else:
next_obj_type = field_names.get(part)
next_obj = next_obj_type()
setattr(current_obj, part, next_obj)
current_obj = next_obj
return obj
@dataclass
class ConfigBase:
"""Base class that should handle parsing from command line,
json, dicts.
"""
@classmethod
def parse_from_command_line(cls):
return omegaconf_parse(cls)
@classmethod
def parse_from_file(cls, path: str):
oc = OmegaConf.load(path)
return cls.parse_from_dict(OmegaConf.to_container(oc))
@classmethod
def parse_from_command_line_deprecated(cls):
result = DataclassArgParser(
cls, fromfile_prefix_chars="@"
).parse_args_into_dataclasses()
if len(result) > 1:
raise RuntimeError(
f"The following arguments were not recognized: {result[1:]}"
)
return result[0]
@classmethod
def parse_from_dict(cls, inputs):
return DataclassArgParser._populate_dataclass_from_dict(cls, inputs.copy())
@classmethod
def parse_from_flat_dict(cls, inputs):
return DataclassArgParser._populate_dataclass_from_flat_dict(cls, inputs.copy())
def save(self, path: str):
with open(path, "w") as f:
OmegaConf.save(config=self, f=f)
@dataclass
class VicRegConfig:
lambda_invariance: float = 25.0
mu_variance: float = 25.0
nu_covariance: float = 1.0
# New JEPAConfig class
@dataclass
class JEPAConfig(ConfigBase):
embed_dim: int = 256
wall_embed_dim: int = 128
action_dim: int = 2
action_hidden_dim: int = 32
in_c: int = 2
out_c: int = 64
epochs: int = 20
batch_size: int = 1024
optimizer_type: str = 'adamw'
scheduler_type: str = 'linear'
learning_rate: float = 0.001
model_type: str = 'JEPA'
data_path: str = '/scratch/DL24FA/train'
vicreg_loss: VicRegConfig = field(default_factory=VicRegConfig)
action_reg_hidden_dim: str = '32'
lambda_reg: int = 0.2
delta_gen: int = 1
encoder_backbone: str = "resnet18.a1_in1k"
teacher_forcing: str = 'True'
return_enc: str = 'True'
pred_flattened: str = 'False'
feature_index: int = 1
# You can add more configuration parameters as needed