Speech Recognition
In this tutorial, we’ll fine-tune the Whisper model on the Persian portion of the Common Voice dataset.
Note that even the small Whisper variant is large: training it requires at least 12 GB of GPU memory (VRAM).
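Before starting, it can be useful to check whether your GPU meets the 12 GB guideline. The sketch below is a helper of our own, not part of Hezar. It assumes PyTorch is installed, inspects only the first CUDA device, and treats "GB" as decimal gigabytes (10**9 bytes).

```python
def has_enough_vram(total_bytes: int, required_gb: float = 12.0) -> bool:
    """Return True if a device with `total_bytes` of memory meets the requirement (decimal GB)."""
    return total_bytes >= required_gb * 10**9


def check_gpu(required_gb: float = 12.0) -> bool:
    """Hypothetical helper: check the first CUDA device against the VRAM guideline."""
    try:
        import torch  # assumption: PyTorch is the backend in use
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    # total_memory is reported in bytes
    total = torch.cuda.get_device_properties(0).total_memory
    return has_enough_vram(total, required_gb)


print("GPU meets the VRAM guideline:", check_gpu())
```

The threshold is only a rough guide: actual memory use depends on batch size, mixed precision, and sequence lengths.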
Let’s import the required modules:
from hezar.models import Model
from hezar.data import Dataset
from hezar.trainer import Trainer, TrainerConfig
Define the base model path:
base_model_path = "hezarai/whisper-small"
Dataset
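As mentioned, we’ll use the Persian samples of the Common Voice dataset, which is available on Hezar’s Hugging Face Hub page. Passing preprocessor=base_model_path makes the dataset use the base model’s preprocessors, so audio and transcripts are processed the way the model expects, and labels_max_length=64 sets the maximum length of the tokenized transcripts (labels).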
dataset_path = "hezarai/common-voice-13-fa"
train_dataset = Dataset.load(dataset_path, preprocessor=base_model_path, split="train", labels_max_length=64)
eval_dataset = Dataset.load(dataset_path, preprocessor=base_model_path, split="test", labels_max_length=64)
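To make labels_max_length more concrete, the helper below illustrates a common way of bounding label sequences: truncate anything longer than the limit and pad anything shorter, with a mask marking real tokens. It is a hypothetical illustration, not Hezar’s implementation, and pad_id=0 is a placeholder (the real padding id comes from the model’s tokenizer).

```python
from typing import List, Sequence, Tuple


def pad_or_truncate(
    token_ids: Sequence[int], max_length: int = 64, pad_id: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: force a label sequence to exactly `max_length` tokens."""
    ids = list(token_ids[:max_length])  # drop tokens beyond the limit
    mask = [1] * len(ids)               # 1 marks a real token
    padding = max_length - len(ids)
    ids += [pad_id] * padding           # fill the remainder with the pad id
    mask += [0] * padding               # 0 marks padding
    return ids, mask


ids, mask = pad_or_truncate([5, 6, 7], max_length=5)
print(ids, mask)
```

In a scheme like this, transcripts longer than the limit lose their final tokens, so the limit should comfortably exceed the typical transcript length.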
Model
We’ll load the model (with its preprocessors) from the base model’s path.
model = Model.load(base_model_path)
Training
Let’s configure the trainer using TrainerConfig:
train_config = TrainerConfig(
    output_dir="whisper-small-fa-commonvoice",
    task="speech_recognition",
    mixed_precision="bf16",
    resume_from_checkpoint=True,
    gradient_accumulation_steps=8,
    batch_size=4,
    log_steps=100,
    save_steps=1000,
    num_epochs=5,
    metrics=["cer", "wer"],
)
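Since the model is large and bigger batch sizes can cause GPU out-of-memory (OOM) errors, we keep batch_size at 4 and set gradient_accumulation_steps to 8, so gradients from 8 consecutive batches are accumulated before each optimizer update. On a single GPU, this gives an effective batch size of 4 × 8 = 32. To reduce memory usage, we set the mixed precision to BFloat16 (bf16), which requires a GPU that supports it (for example, NVIDIA Ampere or newer). Saving checkpoints between steps (save_steps=1000) is recommended so that training is easier to resume; resume_from_checkpoint=True lets the trainer continue from the latest saved checkpoint.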
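The idea behind gradient accumulation can be checked with a toy example. The sketch below is a generic plain-Python illustration, not Hezar’s Trainer code: for a mean-reduced loss and equal-sized micro-batches, summing each micro-batch gradient divided by the number of accumulation steps reproduces the gradient of one large batch of 4 × 8 = 32 samples. The one-parameter model and the data are made up for the demonstration.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the one-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size, accumulation_steps):
    """Accumulate scaled micro-batch gradients, as done before a single optimizer step."""
    total = 0.0
    for step in range(accumulation_steps):
        start = step * micro_batch_size
        mb_x = xs[start:start + micro_batch_size]
        mb_y = ys[start:start + micro_batch_size]
        # Scale by 1/accumulation_steps so the sum equals the full-batch mean gradient
        total += grad_mse(w, mb_x, mb_y) / accumulation_steps
    return total


batch_size = 4
gradient_accumulation_steps = 8
effective_batch_size = batch_size * gradient_accumulation_steps

# Toy data: y = 3x + 1, evaluated at a deliberately wrong weight
xs = [float(i) for i in range(effective_batch_size)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, batch_size, gradient_accumulation_steps)
print(effective_batch_size, math.isclose(full, accum, rel_tol=1e-9))
```

Only one micro-batch of 4 samples needs to be in memory at a time, which is why accumulation helps avoid OOM while keeping a larger effective batch.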
Every 100 steps (log_steps=100), the moving average of the training loss is logged to TensorBoard.
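As an illustration of periodic loss logging, the class below averages the loss over each window of log_steps steps and records the value that would be written to TensorBoard. It is a hypothetical sketch: Hezar’s trainer may define its moving average differently (for example, as a running average over the whole epoch), and this class does not write to TensorBoard itself.

```python
from typing import List, Optional, Tuple


class LossLogger:
    """Hypothetical helper: average the training loss over each window of `log_steps` steps."""

    def __init__(self, log_steps: int = 100):
        self.log_steps = log_steps
        self.total = 0.0
        self.count = 0
        # (step, average loss) pairs that a real trainer would send to TensorBoard
        self.history: List[Tuple[int, float]] = []

    def update(self, step: int, loss: float) -> Optional[float]:
        self.total += loss
        self.count += 1
        if step % self.log_steps == 0:
            avg = self.total / self.count
            self.history.append((step, avg))
            self.total, self.count = 0.0, 0  # start a new window
            return avg
        return None


logger = LossLogger(log_steps=100)
for step in range(1, 201):
    logger.update(step, 2.0 if step <= 100 else 1.0)
print(logger.history)
```

Averaging over a window smooths out the step-to-step noise of individual batch losses, which makes the TensorBoard curve easier to read.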
Let’s create the Trainer and start it!
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
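During evaluation, the trainer reports the cer and wer metrics set in the config. Both are based on the Levenshtein edit distance: WER counts word-level substitutions, deletions, and insertions divided by the number of reference words, and CER does the same over characters. The sketch below computes them for a single sentence pair. It is an illustration only; Hezar’s metric implementation may differ (for example, by normalizing text or computing a corpus-level score over the whole evaluation set).

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two sequences (dynamic programming, two rows)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    """Word error rate for one sentence pair."""
    ref_words = reference.split()
    if not ref_words:
        raise ValueError("reference must not be empty")
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference: str, hypothesis: str) -> float:
    """Character error rate for one sentence pair (spaces count as characters here)."""
    if not reference:
        raise ValueError("reference must not be empty")
    return edit_distance(list(reference), list(hypothesis)) / len(reference)


ref = "the cat sat on the mat"
hyp = "the cat sit on mat"
print(f"WER={wer(ref, hyp):.3f} CER={cer(ref, hyp):.3f}")
```

Lower is better for both metrics; CER is often more informative than WER for languages such as Persian, where a single wrong character marks a whole word as an error.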
Push to Hub
Optionally, you can push the model, along with the other files saved by the Trainer, to the Hugging Face Hub:
trainer.push_to_hub("<path/to/model>", commit_message="Upload an awesome ASR model!")