Training & Fine-tuning
Training a model in Hezar works much like it does in other libraries, and is often even simpler. As mentioned before, every Hezar model is also a PyTorch module, so training a Hezar model really means training a PyTorch model, with some handy extra features on top. Let’s dive in.
Setup
In this example, we’re going to train a DistilBERT-based sentiment analysis model on a dataset of text/sentiment pairs collected from SnappFood and Digikala user comments.
Import everything needed
First things first, let’s import the required stuff.
from hezar.models import DistilBertTextClassification, DistilBertTextClassificationConfig
from hezar.data import Dataset
from hezar.trainer import Trainer, TrainerConfig
from hezar.preprocessors import Preprocessor
Define paths
Let’s define our paths to the datasets, tokenizer, etc.
DATASET_PATH = "hezarai/sentiment-dksf" # dataset path on the Hub
BASE_MODEL_PATH = "hezarai/distilbert-base-fa" # used as model backbone weights and tokenizer
Datasets
We can easily load our desired datasets from the Hub.
train_dataset = Dataset.load(DATASET_PATH, split="train", tokenizer_path=BASE_MODEL_PATH)
eval_dataset = Dataset.load(DATASET_PATH, split="test", tokenizer_path=BASE_MODEL_PATH)
Model
Let’s build our model along with its tokenizer.
Build the model
model = DistilBertTextClassification(DistilBertTextClassificationConfig(id2label=train_dataset.config.id2label))
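The model above is configured with id2label, taken from the training dataset’s config. As in most classification setups, this mapping fixes how many classes the model predicts and what each class index means. The pure-Python sketch below illustrates the idea; the label names are hypothetical placeholders, not the actual labels of hezarai/sentiment-dksf, and decode_prediction is a hypothetical helper, not part of Hezar.

```python
# Hypothetical label set for illustration only; the real mapping comes from
# train_dataset.config.id2label.
id2label = {0: "negative", 1: "neutral", 2: "positive"}

# The inverse mapping, used when converting string labels to class indices.
label2id = {label: idx for idx, label in id2label.items()}

# A classification head has one output per label.
num_labels = len(id2label)


def argmax(scores):
    """Return the index of the largest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


def decode_prediction(logits):
    """Hypothetical helper: map per-class scores (logits) to a label string."""
    return id2label[argmax(logits)]


print(num_labels)                           # 3
print(label2id["positive"])                 # 2
print(decode_prediction([0.1, -1.2, 2.3]))  # positive
```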
Load the tokenizer
The tokenizer can be loaded from the base model path.
tokenizer = Preprocessor.load(BASE_MODEL_PATH)
Trainer
Hezar has a general-purpose Trainer class that covers most needs. You can customize almost every part of it, but for now we’ll stick with the base Trainer class.
Trainer Config
All training properties are defined in the trainer’s config. Since we’re training a text classification model, we set task to "text_classification". The other parameters can be customized as well, as shown below:
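train_config = TrainerConfig(
    output_dir="distilbert-fa-sentiment-analysis-dksf",  # local directory for training outputs
    task="text_classification",  # must match the model type
    device="cuda",  # train on the GPU
    init_weights_from=BASE_MODEL_PATH,  # start from the pretrained DistilBERT weights
    batch_size=8,
    num_epochs=5,
    metrics=["f1"],  # reported during training and evaluation
    num_dataloader_workers=0,  # load data in the main process
    seed=42,
    optimizer="adamw",
    learning_rate=2e-5,
    weight_decay=0.0,
    scheduler="reduce_on_plateau",
    use_amp=False,  # no automatic mixed precision
    save_freq=1,
)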
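Setting scheduler="reduce_on_plateau" selects a schedule that lowers the learning rate once a monitored value stops improving, in the style of PyTorch’s ReduceLROnPlateau. The pure-Python sketch below shows only the core idea (mode "min", no threshold or cooldown); the factor, the patience, and which metric is monitored are assumptions chosen for illustration, not Hezar’s actual defaults, and the loss values are hypothetical.

```python
class ReduceOnPlateau:
    """Minimal reduce-on-plateau schedule (mode "min"): lower the learning rate
    when the monitored value has not improved for more than `patience` steps.

    factor and patience here are illustrative values, not Hezar's defaults.
    """

    def __init__(self, lr, factor=0.1, patience=1):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_steps = 0

    def step(self, value):
        if value < self.best:
            # Improvement: remember the new best value and reset the counter.
            self.best = value
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
            if self.num_bad_steps > self.patience:
                # Plateau detected: shrink the learning rate and start counting again.
                self.lr *= self.factor
                self.num_bad_steps = 0
        return self.lr


# Hypothetical per-epoch validation losses.
scheduler = ReduceOnPlateau(lr=2e-5, factor=0.1, patience=1)
for loss in [0.50, 0.40, 0.41, 0.42, 0.39]:
    lr = scheduler.step(loss)
    print(f"loss={loss:.2f} -> lr={lr:.1e}")
```

With these hypothetical numbers, the learning rate stays at 2e-5 until the loss fails to improve for two consecutive epochs, then drops to 2e-6.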
Set up the Trainer
Now that we have our training config, we can set up the Trainer.
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=tokenizer,
)
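The data_collator argument controls how individual samples are merged into a batch. For text classification, a collator typically pads every tokenized sequence to the length of the longest one in the batch and builds a matching attention mask. The function below is a hypothetical pure-Python sketch of that idea, not Hezar’s collator; pad_id=0 and the field names token_ids, attention_mask, and labels are assumptions.

```python
def pad_collate(samples, pad_id=0):
    """Hypothetical collator: pad token ids to the longest sample in the batch.

    Each sample is a dict with "token_ids" (list of ints) and "labels" (int).
    Returns a dict of lists: padded token ids, an attention mask
    (1 = real token, 0 = padding), and the labels.
    """
    max_len = max(len(s["token_ids"]) for s in samples)
    batch = {"token_ids": [], "attention_mask": [], "labels": []}
    for s in samples:
        ids = s["token_ids"]
        num_pad = max_len - len(ids)
        batch["token_ids"].append(ids + [pad_id] * num_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * num_pad)
        batch["labels"].append(s["labels"])
    return batch


batch = pad_collate([
    {"token_ids": [5, 17, 42], "labels": 1},
    {"token_ids": [8, 3], "labels": 0},
])
print(batch["token_ids"])       # [[5, 17, 42], [8, 3, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```

A real collator would also convert these lists into tensors before they reach the model.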
Start Training
trainer.train()
Epoch: 1/5 100%|####################################| 3576/3576 [07:07<00:00, 8.37batch/s, f1=0.732, loss=0.619]
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 38.64batch/s, f1=0.8, loss=0.473]
Epoch: 2/5 100%|####################################| 3576/3576 [07:00<00:00, 8.50batch/s, f1=0.807, loss=0.47]
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 39.87batch/s, f1=0.838, loss=0.419]
Epoch: 3/5 100%|####################################| 3576/3576 [07:01<00:00, 8.48batch/s, f1=0.864, loss=0.348]
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 39.97batch/s, f1=0.875, loss=0.346]
Epoch: 4/5 100%|####################################| 3576/3576 [06:57<00:00, 8.56batch/s, f1=0.919, loss=0.227]
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 38.84batch/s, f1=0.875, loss=0.381]
Epoch: 5/5 100%|####################################| 3576/3576 [07:02<00:00, 8.46batch/s, f1=0.943, loss=0.156]
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 39.71batch/s, f1=0.887, loss=0.446]
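The progress bars can be sanity-checked with a little arithmetic. With batch_size=8, 3576 training batches per epoch means the training split holds between 28,601 and 28,608 samples (assuming the last, possibly partial, batch is kept), and 3576 batches in 7:07 (427 s) is consistent with the reported 8.37 batch/s. The log also hints at mild overfitting: training F1 keeps rising while evaluation loss bottoms out at epoch 3 and climbs afterwards. The sketch below re-reads the evaluation numbers printed above and shows that the "best" epoch depends on which metric you monitor.

```python
# Evaluation results copied from the training log above (epoch -> metrics).
eval_log = {
    1: {"f1": 0.800, "loss": 0.473},
    2: {"f1": 0.838, "loss": 0.419},
    3: {"f1": 0.875, "loss": 0.346},
    4: {"f1": 0.875, "loss": 0.381},
    5: {"f1": 0.887, "loss": 0.446},
}

best_by_loss = min(eval_log, key=lambda e: eval_log[e]["loss"])  # lower is better
best_by_f1 = max(eval_log, key=lambda e: eval_log[e]["f1"])      # higher is better
print(best_by_loss, best_by_f1)  # 3 5

# Sanity-check the progress bar: batches per epoch and throughput.
batch_size, num_batches, seconds = 8, 3576, 7 * 60 + 7
min_samples = (num_batches - 1) * batch_size + 1  # last batch may be partial
max_samples = num_batches * batch_size
print(min_samples, max_samples)         # 28601 28608
print(round(num_batches / seconds, 2))  # 8.37
```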
Evaluate
trainer.evaluate()
Evaluating... 100%|####################################| 290/290 [00:07<00:00, 39.46batch/s, f1=0.887, loss=0.445]
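The f1 value reported during training and by trainer.evaluate() combines precision and recall per class. For a multi-class task the per-class scores have to be averaged somehow; the sketch below uses macro averaging (the unweighted mean over classes). That choice is an assumption for illustration, since the averaging mode Hezar uses for its f1 metric isn’t stated here, and the labels are hypothetical.

```python
def f1_macro(y_true, y_pred):
    """Macro-averaged F1: compute F1 per class, then take the unweighted mean."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


# Hypothetical gold labels and predictions for a 3-class problem.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(f1_macro(y_true, y_pred), 3))  # 0.656
```

Here the per-class F1 scores are 0.5, 0.8, and about 0.667, so their unweighted mean is about 0.656.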
Push everything
Finally, you can push the trained model to the Hub. This uploads the model, the model config, the preprocessor, the trainer config, and so on.
trainer.push_to_hub("arxyzan/distilbert-fa-sentiment-dksf")
Advanced concepts
For more details, see the in-depth Trainer guide.