diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py
index b3fc8eba10..5cdb120285 100644
--- a/flash_examples/speech_recognition.py
+++ b/flash_examples/speech_recognition.py
@@ -29,7 +29,7 @@
 )
 
 # 2. Build the task
-model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
+model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h", learning_rate=1e-5)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
@@ -38,7 +38,8 @@
 # 4. Predict on audio files!
 datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
 predictions = trainer.predict(model, datamodule=datamodule)
-print(predictions)
+
+print("Predictions: ", predictions)
 
 # 5. Save the model!
 trainer.save_checkpoint("speech_recognition_model.pt")