Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Update Speech Recognition example (with learning_rate) #1272

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions flash_examples/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h", learning_rate=1e-5)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
predict_datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=predict_datamodule)
krshrimali marked this conversation as resolved.
Show resolved Hide resolved

print("Predictions: ", predictions)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")