hezar.trainer.trainer_utils module
- class hezar.trainer.trainer_utils.AverageMeter(name, avg=None, sum=None, count=None, fmt=':f')
Bases: object
Compute and store the average and current value.
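A minimal sketch of how an average meter with this constructor signature typically works. The attribute `val` and the methods `update` and `reset` are assumptions modeled on the common PyTorch-style meter, not confirmed hezar API; `None` values for `avg`, `sum`, and `count` are assumed to mean "start from zero".

```python
class AverageMeter:
    """Hypothetical sketch: tracks the latest value and a count-weighted running average."""

    def __init__(self, name, avg=None, sum=None, count=None, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ":f" or ":.4f"
        self.val = 0.0  # most recent value passed to update()
        # Assumption: None means "start from zero".
        self.avg = avg if avg is not None else 0.0
        self.sum = sum if sum is not None else 0.0
        self.count = count if count is not None else 0

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times (e.g. a batch-mean loss over n samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # With fmt=":f" this builds "{name} {val:f} ({avg:f})".
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)
```

For example, `update(2.0)` followed by `update(4.0, n=3)` leaves `sum == 14.0`, `count == 4`, and `avg == 3.5`, and `str(meter)` gives `loss 4.000000 (3.500000)` for a meter named `"loss"`.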
- class hezar.trainer.trainer_utils.CSVLogger(logs_dir: str, csv_filename: str)
Bases: object
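The class carries no docstring, so the following is only a sketch, assuming that a logger built from `logs_dir` and `csv_filename` appends one row of metrics per call to `<logs_dir>/<csv_filename>`. The `write(logs, step)` method and the `step` column are hypothetical names chosen for illustration; the code uses only the standard `csv` and `os` modules.

```python
import csv
import os


class CSVLogger:
    """Hypothetical sketch: append one row of metrics per call to a CSV file."""

    def __init__(self, logs_dir: str, csv_filename: str):
        self.logs_dir = logs_dir
        self.filepath = os.path.join(logs_dir, csv_filename)
        os.makedirs(logs_dir, exist_ok=True)

    def write(self, logs: dict, step: int) -> None:
        """Append `logs` as one row, prefixed with the step number."""
        row = {"step": step, **logs}
        is_new = not os.path.exists(self.filepath)
        with open(self.filepath, "a", newline="") as f:
            # Assumes every call passes the same metric keys in the same order;
            # DictWriter raises ValueError on unexpected keys.
            writer = csv.DictWriter(f, fieldnames=list(row.keys()))
            if is_new:
                writer.writeheader()  # header only on the first write
            writer.writerow(row)
```

Writing the header only when the file does not yet exist lets a resumed run keep appending to the same file, at the cost of requiring a fixed set of columns.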
- class hezar.trainer.trainer_utils.TrainerState(epoch: int = 1, total_epochs: int | None = None, global_step: int = 0, epoch_step: int = 0, loss_tracker_sum: float = 0.0, loss_tracker_avg: float = 0.0, metric_for_best_checkpoint: str | None = None, best_metric_value: float | None = None, best_checkpoint: str | None = None, logs_dir: str | None = None)
Bases: object
A trainer state is a container for the values that are updated throughout training; it is saved whenever a checkpoint is written.
- Parameters:
epoch – Current epoch number
total_epochs – Total number of epochs to train the model
global_step – Number of update steps taken so far; one step is a full training step on a single batch
epoch_step – Number of update steps taken in the current epoch
loss_tracker_sum – Running sum of the loss tracker
loss_tracker_avg – Running mean of the loss tracker
metric_for_best_checkpoint – Metric key used to choose the best checkpoint (also set in the TrainerConfig)
best_metric_value – Metric value of the best checkpoint saved so far
best_checkpoint – Path to the best model checkpoint saved so far
logs_dir – Path to the logs directory
- best_checkpoint: str = None
- best_metric_value: float = None
- epoch: int = 1
- epoch_step: int = 0
- global_step: int = 0
- logs_dir: str = None
- loss_tracker_avg: float = 0.0
- loss_tracker_sum: float = 0.0
- metric_for_best_checkpoint: str = None
- total_epochs: int = None
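A minimal sketch of how these fields might be maintained during a training loop and persisted alongside a checkpoint. It declares a stand-alone dataclass with the documented field names and defaults; the helpers `on_train_step`, `on_epoch_end`, `save_state`, and `load_state`, the per-epoch reset of the loss tracker, the "lower is better" rule for the best metric, and the JSON file format are all assumptions for illustration, not hezar's actual implementation.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class TrainerState:
    # Field names and defaults mirror the documented signature.
    epoch: int = 1
    total_epochs: Optional[int] = None
    global_step: int = 0
    epoch_step: int = 0
    loss_tracker_sum: float = 0.0
    loss_tracker_avg: float = 0.0
    metric_for_best_checkpoint: Optional[str] = None
    best_metric_value: Optional[float] = None
    best_checkpoint: Optional[str] = None
    logs_dir: Optional[str] = None


# Hypothetical helpers; not part of hezar's API.
def on_train_step(state: TrainerState, loss: float) -> None:
    """Advance the step counters and the running loss after one batch."""
    state.global_step += 1
    state.epoch_step += 1
    state.loss_tracker_sum += loss
    # Assumption: the running mean is taken over the steps of the current epoch.
    state.loss_tracker_avg = state.loss_tracker_sum / state.epoch_step


def on_epoch_end(state: TrainerState, metrics: dict, checkpoint_path: str) -> None:
    """Update the best checkpoint (assuming lower is better), then start a new epoch."""
    key = state.metric_for_best_checkpoint
    if key is not None:
        value = metrics[key]
        if state.best_metric_value is None or value < state.best_metric_value:
            state.best_metric_value = value
            state.best_checkpoint = checkpoint_path
    state.epoch += 1
    state.epoch_step = 0
    state.loss_tracker_sum = 0.0
    state.loss_tracker_avg = 0.0


def save_state(state: TrainerState, path: str) -> None:
    """Write the state as JSON next to a checkpoint."""
    with open(path, "w") as f:
        json.dump(asdict(state), f, indent=2)


def load_state(path: str) -> TrainerState:
    """Rebuild the state from JSON, e.g. to resume training."""
    with open(path) as f:
        return TrainerState(**json.load(f))
```

Keeping `global_step` separate from `epoch_step` lets a resumed run know both where it is overall and how far into the current epoch it got; storing `best_checkpoint` in the state means the best model can still be located after a restart.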