Format with Black
palonso committed Jul 12, 2023
1 parent 852ebfe commit 4a06cfe
Showing 7 changed files with 172 additions and 147 deletions.
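
Nearly every hunk below is a mechanical rewrite produced by Black (a few config values also change). As a minimal illustrative sketch (not code from this repository), the two rules behind most of the diff are collapsing short literals that fit within Black's default 88-character line length and normalizing string quotes:

# Before Black: a short dict split across lines, a single-quoted string.
datamodule = {
    'clip_length': 10
}

# After Black: the literal fits on one line, and quotes become double quotes.
datamodule = {"clip_length": 10}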
56 changes: 15 additions & 41 deletions config_updates.py
@@ -11,6 +11,8 @@ def mini_train():
     "limit training/validation to 5 batches for debugging."
     trainer = dict(limit_train_batches=5, limit_val_batches=5)
 
+    datamodule = {"groundtruth_val": "discogs/gt_val_all_400l_super_clean.pk"}
+
 # Experiments from
 # EFFICIENT SUPERVISED TRAINING OF AUDIO TRANSFORMERS FOR MUSIC REPRESENTATION LEARNING
 
@@ -21,9 +23,7 @@ def mini_train():
 def maest_10s_random_weights_pretrain():
     "time encodings for up to 10 seconds, and random initialization"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476",
@@ -36,9 +36,7 @@ def maest_10s_random_weights_pretrain():
 def maest_10s_from_deit_pretrain():
     "time encodings for up to 10 seconds and initialization to the DeiT weights"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_deit_bd_p16_384",
@@ -51,9 +49,7 @@ def maest_10s_from_deit_pretrain():
 def maest_10s_from_passt_pretrain():
     "time encodings for up to 10 seconds and initialization to the PaSST weights"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476",
@@ -66,9 +62,7 @@ def maest_10s_from_passt_pretrain():
 def maest_10s_random_weights_inference():
     "time encodings for up to 10 seconds, and random initialization"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476",
@@ -77,17 +71,13 @@ def maest_10s_random_weights_inference():
         "s_patchout_t": 30,
     }
 
-    inference = {
-        "n_block": 11
-    }
+    inference = {"n_block": 11}
 
 @ex.named_config
 def maest_10s_from_deit_inference():
     "time encodings for up to 10 seconds and initialization to the DeiT weights"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_deit_bd_p16_384",
@@ -96,17 +86,13 @@ def maest_10s_from_deit_inference():
         "s_patchout_t": 30,
     }
 
-    inference = {
-        "n_block": 11
-    }
+    inference = {"n_block": 11}
 
 @ex.named_config
 def maest_10s_from_passt_inference():
     "time encodings for up to 10 seconds and initialization to the PaSST weights"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476",
@@ -115,9 +101,7 @@ def maest_10s_from_passt_inference():
         "s_patchout_t": 30,
     }
 
-    predict = {
-        "transformer_block": 7
-    }
+    predict = {"transformer_block": 7}
 
 # Section 4.3. Effect of the input sequence length
 ##################################################
@@ -126,9 +110,7 @@ def maest_10s_from_passt_inference():
 def passt_discogs_5sec():
     "time encodings for up to 5 seconds"
 
-    datamodule = {
-        "clip_length": 5
-    }
+    datamodule = {"clip_length": 5}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476",
@@ -141,9 +123,7 @@ def passt_discogs_5sec():
 def passt_discogs_10sec():
     "time encodings for up to 10 seconds"
 
-    datamodule = {
-        "clip_length": 10
-    }
+    datamodule = {"clip_length": 10}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476_discogs",
@@ -155,9 +135,7 @@ def passt_discogs_10sec():
 def passt_discogs_20sec():
     "time encodings for up to 20 seconds"
 
-    datamodule = {
-        "clip_length": 20
-    }
+    datamodule = {"clip_length": 20}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476_discogs",
@@ -169,17 +147,14 @@ def passt_discogs_20sec():
 def passt_discogs_30sec():
     "time encodings for up to 30 seconds"
 
-    datamodule = {
-        "clip_length": 30
-    }
+    datamodule = {"clip_length": 30}
 
     net = {
         "arch": "passt_s_swa_p16_128_ap476_discogs",
         "input_t": 30 * 16000 // 256,
         "s_patchout_t": 90,
     }
 
-
 @ex.named_config
 def passt_discogs_10sec_fe():
     "use PaSST model pretrained on Audioset (with SWA) ap=476; time encodings for up to 10 seconds with frequency-wise embeddings"
@@ -213,7 +188,6 @@ def teacher_target():
         teacher_target_threshold=0.45,
     )
 
-
 @ex.named_config
 def passt_discogs5sec_inference():
     "extracting embeddings with the 5secs model"
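The functions in config_updates.py are Sacred named configs: selecting one overlays its dicts onto the experiment's defaults. A self-contained sketch of the mechanism (a hypothetical demo.py, not this repository's experiment file):

from sacred import Experiment

ex = Experiment("demo")

@ex.config
def default():
    datamodule = {"clip_length": 30}  # default value

@ex.named_config
def clip_10s():
    "time encodings for up to 10 seconds"
    datamodule = {"clip_length": 10}  # applied only when this config is named

@ex.automain
def main(datamodule):
    print(datamodule["clip_length"])

Running `python demo.py with clip_10s` prints 10; without the named config it prints 30.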
10 changes: 7 additions & 3 deletions discogs/datamodule.py
@@ -91,7 +91,7 @@ def __iter__(self):
         indices = list(self.sampler)
         if self.epoch == 0:
             _logger.info(f"DistributedSamplerWrapper: {indices[:10]}")
-        indices = indices[self.rank: self.total_size: self.num_replicas]
+        indices = indices[self.rank : self.total_size : self.num_replicas]
         return iter(indices)
 
 
@@ -185,7 +185,9 @@ def get_ft_weighted_sampler(
         _logger.info(f"WORLD_SIZE: {world_size}")
         _logger.info(f"LOCAL_RANK: {local_rank}")
 
-        sample_weights = self.get_ft_cls_balanced_sample_weights(groundtruth=groundtruth)
+        sample_weights = self.get_ft_cls_balanced_sample_weights(
+            groundtruth=groundtruth
+        )
 
         return DistributedSamplerWrapper(
             sampler=WeightedRandomSampler(
@@ -268,7 +270,9 @@ def test_dataloader(
         num_workers,
         norm,
     ):
-        ds = DiscogsDatasetExhaustive(groundtruth_test, base_dir, clip_length=clip_length)
+        ds = DiscogsDatasetExhaustive(
+            groundtruth_test, base_dir, clip_length=clip_length
+        )
 
         if norm["do"]:
             ds = PreprocessDataset(ds, self.get_norm_func())
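The only change to the sampler slice is whitespace: Black spaces out slice colons when the operands are expressions. Behaviorally it still deals indices out round-robin across replicas, as this standalone sketch shows:

# Rank r of n replicas takes every n-th index starting at r.
indices = list(range(8))
total_size, num_replicas = 8, 2
for rank in range(num_replicas):
    print(rank, indices[rank : total_size : num_replicas])
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]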
37 changes: 15 additions & 22 deletions discogs/dataset.py
@@ -72,8 +72,7 @@ def __del__(self):
         self.dataset_file = None
 
     def __getitem__(self, index):
-        """Load waveform and target of an audio clip.
-        """
+        """Load waveform and target of an audio clip."""
 
         filename = self.filenames[index]
         target = self.groundtruth[filename].astype("float16")
@@ -185,8 +184,7 @@ def __init__(
         self.teacher_target_threshold = teacher_target_threshold
 
     def __getitem__(self, index):
-        """Load waveform and target of an audio clip.
-        """
+        """Load waveform and target of an audio clip."""
 
         filename = self.filenames[index]
         target = self.groundtruth[filename].astype("float16")
@@ -202,14 +200,12 @@ def __getitem__(self, index):
             teacher_target = expit(teacher_target)
 
         # hard teacher target with a threshold of 0.45
-        hard_teacher_target = (
-            teacher_target > self.teacher_target_threshold
-        ).astype("float16")
+        hard_teacher_target = (teacher_target > self.teacher_target_threshold).astype(
+            "float16"
+        )
         # if no class is activated, set the highest activation
         if not np.sum(hard_teacher_target):
-            hard_teacher_target = np.zeros(
-                hard_teacher_target.shape, dtype="float16"
-            )
+            hard_teacher_target = np.zeros(hard_teacher_target.shape, dtype="float16")
             hard_teacher_target[np.argmax(teacher_target)] = 1.0
 
         return melspectrogram, str(filename), target, hard_teacher_target
@@ -239,7 +235,9 @@ def __init__(
             n_bands=n_bands,
         )
         self.hop_size = (
-            self.melspectrogram_size // 2 if half_overlapped_inference else self.melspectrogram_size
+            self.melspectrogram_size // 2
+            if half_overlapped_inference
+            else self.melspectrogram_size
         )
         self.half_overlap = half_overlapped_inference
 
@@ -266,8 +264,7 @@ def __init__(
         self.length = len(self.filenames_with_patch)
 
     def __getitem__(self, index):
-        """Load waveform and target of an audio clip.
-        """
+        """Load waveform and target of an audio clip."""
 
         filename, offset = self.filenames_with_patch[index]
         target = self.groundtruth[filename].astype("float16")
@@ -311,8 +308,7 @@ def __init__(
         self.teacher_target_threshold = teacher_target_threshold
 
     def __getitem__(self, index):
-        """Load waveform and target of an audio clip.
-        """
+        """Load waveform and target of an audio clip."""
 
         filename, offset = self.filenames_with_patch[index]
         target = self.groundtruth[filename].astype("float16")
@@ -329,14 +325,12 @@ def __getitem__(self, index):
             teacher_target = expit(teacher_target)
 
         # hard teacher target with a threshold of 0.45
-        hard_teacher_target = (
-            teacher_target > self.teacher_target_threshold
-        ).astype("float16")
+        hard_teacher_target = (teacher_target > self.teacher_target_threshold).astype(
+            "float16"
+        )
         # if no class is activated, set the highest activation
         if not np.sum(hard_teacher_target):
-            hard_teacher_target = np.zeros(
-                hard_teacher_target.shape, dtype="float16"
-            )
+            hard_teacher_target = np.zeros(hard_teacher_target.shape, dtype="float16")
             hard_teacher_target[np.argmax(teacher_target)] = 1.0
 
         return melspectrogram, str(filename), target, hard_teacher_target
@@ -346,4 +340,3 @@ def __getitem__(self, index):
 def print_conf(_config):
     _logger.info(f"Config of {dataset_ing.path}, with id: {id(dataset_ing)}")
     _logger.info(_config)
-
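The reflowed teacher-target code in dataset.py is behavior-preserving. Its logic: squash teacher logits with a sigmoid, threshold at 0.45, and if nothing survives, activate the single strongest class. A standalone sketch with made-up logits:

import numpy as np
from scipy.special import expit

teacher_logits = np.array([-2.0, -0.3, -1.5])
teacher_target = expit(teacher_logits)  # sigmoid -> [0.12, 0.43, 0.18]
hard = (teacher_target > 0.45).astype("float16")
if not np.sum(hard):  # no activation above the threshold
    hard = np.zeros(hard.shape, dtype="float16")
    hard[np.argmax(teacher_target)] = 1.0  # fall back to the strongest class
print(hard)  # [0. 1. 0.]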
26 changes: 15 additions & 11 deletions ex_discogs.py
@@ -26,13 +26,16 @@
     datamodule_ing,
 )
 
-ex = Experiment("ex", ingredients=[
-    dataset_ing,
-    datamodule_ing,
-    net_ingredient,
-    maest_ing,
-])
-ex.observers.append(FileStorageObserver('exp_logs'))
+ex = Experiment(
+    "ex",
+    ingredients=[
+        dataset_ing,
+        datamodule_ing,
+        net_ingredient,
+        maest_ing,
+    ],
+)
+ex.observers.append(FileStorageObserver("exp_logs"))
 _logger = logging.getLogger("ex_discogs")
 
 
@@ -43,7 +46,7 @@ def default_conf():
 
     trainer = {
         "max_epochs": 130,
-        "devices": 1,
+        "devices": 4,
         "sync_batchnorm": True,
         "precision": "16-mixed",
         "limit_train_batches": None,
@@ -52,6 +55,7 @@ def default_conf():
         "log_every_n_steps": 50,
         "reload_dataloaders_every_n_epochs": 1,
         "strategy": "ddp_find_unused_parameters_true",
+        "default_root_dir": "exp_logs",
     }
 
     predict = {
@@ -72,14 +76,12 @@ def main(_run, _config, _log, _rnd, _seed):
     module = MAEST()
     data = DiscogsDataModule()
 
-
     trainer.fit(module, data)
     return {"done": True}
 
 
 @ex.command
 def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
-
     modul = MAEST()
     modul = modul.cuda()
     batch_size = speed_test_batch_size
@@ -164,7 +166,9 @@ def predict(_run, _config, _log, _rnd, _seed, output_name=""):
         for po_type in ("indices", "interleaved"):
             if _config["net"][f"s_patchout_{po_dim}_{po_type}"]:
                 removed_bands = "_".join(
-                    np.array(_config["net"][f"s_patchout_{po_dim}_{po_type}"]).astype("str")
+                    np.array(_config["net"][f"s_patchout_{po_dim}_{po_type}"]).astype(
+                        "str"
+                    )
                 )
                 subdir2 += f"_patchout_{po_dim}_{po_type}" + removed_bands
 
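The trainer dict in default_conf() mirrors PyTorch Lightning 2.x Trainer arguments (the "16-mixed" precision string and the strategy name are Lightning >= 2.0 conventions). A hedged sketch of how such a dict can be expanded into a Trainer; the repository may wire this up differently:

import pytorch_lightning as pl

# Subset of the keys from default_conf(); each maps to a Trainer argument.
trainer_conf = {
    "max_epochs": 130,
    "devices": 4,
    "precision": "16-mixed",
    "strategy": "ddp_find_unused_parameters_true",
    "default_root_dir": "exp_logs",  # checkpoints and logs land here
}
trainer = pl.Trainer(**trainer_conf)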
6 changes: 4 additions & 2 deletions helpers/swa_callback.py
@@ -143,8 +143,10 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"):
         )
 
     def setup(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] =
-        None
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        stage: Optional[str] = None,
     ):
         # copy the model before moving it to accelerator device.
         self._average_model = deepcopy(pl_module.net)
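Since this commit is mostly pure formatting, it should be behavior-neutral. Black itself verifies this by checking that the reformatted source parses to an equivalent AST; a simplified version of that check (Black's real one also normalizes docstring whitespace, which the docstring hunks above would need):

import ast

def same_ast(old_src: str, new_src: str) -> bool:
    # Line breaks and quote style do not survive parsing, so equal
    # AST dumps imply equivalent code.
    return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))

print(same_ast("x = {\n    'a': 1,\n}\n", 'x = {"a": 1}\n'))  # True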
(Diffs for the remaining two changed files are not rendered in this view.)