hezar.trainer.trainer_utils module

class hezar.trainer.trainer_utils.AverageMeter(name, avg=None, sum=None, count=None, fmt=':f')[source]

Bases: object

Compute and store the running average and the current value of a quantity (e.g., the training loss).

reset()[source]
update(val, n=1)[source]
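
Example (a minimal usage sketch; it assumes update(val, n) accumulates val weighted by n, the conventional behavior for this meter pattern):

    from hezar.trainer.trainer_utils import AverageMeter

    # Track the running average of a per-batch loss
    loss_meter = AverageMeter("loss")
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        loss_meter.update(batch_loss, n=batch_size)  # weight each value by its batch size
    print(loss_meter.avg)  # running mean over all 80 samples
    loss_meter.reset()     # clear the sum/count/average before the next epoch
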
class hezar.trainer.trainer_utils.CSVLogger(logs_dir: str, csv_filename: str)[source]

Bases: object

write(logs: dict, step: int)[source]
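
Example (a minimal usage sketch; it assumes the keys of logs become CSV columns, and the directory and filename below are hypothetical):

    from hezar.trainer.trainer_utils import CSVLogger

    csv_logger = CSVLogger(logs_dir="logs", csv_filename="training_logs.csv")
    # Append one row of metric values for the given global step
    csv_logger.write({"train.loss": 0.42, "train.accuracy": 0.87}, step=100)
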
class hezar.trainer.trainer_utils.MetricsTracker(metrics)[source]

Bases: object

avg()[source]
reset()[source]
update(results)[source]
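
Example (a minimal usage sketch; the shapes of metrics and results are assumptions inferred from the method names):

    from hezar.trainer.trainer_utils import MetricsTracker

    tracker = MetricsTracker(metrics=["f1"])  # hypothetical metrics argument
    tracker.reset()
    for batch_results in [{"f1": 0.80}, {"f1": 0.90}]:
        tracker.update(batch_results)  # accumulate per-batch metric results
    epoch_averages = tracker.avg()     # running means, e.g. {"f1": 0.85} (assumption)
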
class hezar.trainer.trainer_utils.TrainerState(epoch: int = 1, total_epochs: int | None = None, global_step: int = 0, epoch_step: int = 0, loss_tracker_sum: float = 0.0, loss_tracker_avg: float = 0.0, metric_for_best_checkpoint: str | None = None, best_metric_value: float | None = None, best_checkpoint: str | None = None, logs_dir: str | None = None)[source]

Bases: object

A TrainerState is a container that holds values updated over the course of training; it is saved when checkpointing.

Parameters:
  • epoch – Current epoch number

  • total_epochs – Total epochs to train the model

  • global_step – Number of update steps so far; one step is one full training step on a single batch

  • epoch_step – Number of update steps in the current epoch

  • loss_tracker_sum – Running sum value of the loss tracker

  • loss_tracker_avg – Running mean value of the loss tracker

  • metric_for_best_checkpoint – The metric key used to choose the best checkpoint (also set in the TrainerConfig)

  • best_metric_value – The metric value of the best checkpoint saved so far

  • best_checkpoint – Path to the best model checkpoint so far

  • logs_dir – Path to the logs directory

best_checkpoint: str | None = None
best_metric_value: float | None = None
epoch: int = 1
epoch_step: int = 0
global_step: int = 0
classmethod load(path)[source]

Load a trainer state from path

logs_dir: str | None = None
loss_tracker_avg: float = 0.0
loss_tracker_sum: float = 0.0
metric_for_best_checkpoint: str | None = None
save(path, drop_none: bool = False)[source]

Save the state to a .yaml file at path

total_epochs: int | None = None
update(items: dict, **kwargs)[source]
update_best_results(metric_value, objective, step)[source]
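
Example (a minimal round-trip sketch; the paths are hypothetical and the "minimize" objective value for update_best_results is an assumption):

    from hezar.trainer.trainer_utils import TrainerState

    state = TrainerState(total_epochs=10, metric_for_best_checkpoint="loss")
    state.update({"epoch": 2, "global_step": 500})  # bulk-update state fields
    state.update_best_results(metric_value=0.35, objective="minimize", step=500)
    state.save("checkpoints/trainer_state.yaml")    # writes a .yaml file
    restored = TrainerState.load("checkpoints/trainer_state.yaml")
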
hezar.trainer.trainer_utils.get_distributed_logger(name: str, level: str | None = None, fmt: str | None = None)[source]

The distributed logger is responsible for handling logging across multiple processes/machines.
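
Example (a minimal usage sketch; level and fmt are optional per the signature):

    from hezar.trainer.trainer_utils import get_distributed_logger

    logger = get_distributed_logger(__name__, level="INFO")
    logger.info("Starting training ...")  # emitted according to the distributed logging policy
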

hezar.trainer.trainer_utils.get_lr_scheduler_type(lr_scheduler, schedulers_mapping: dict)[source]
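
Example (a plausible sketch inferred from the signature alone; it assumes the function reverse-maps the scheduler instance's class to its key in schedulers_mapping, and the mapping below is hypothetical):

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

    from hezar.trainer.trainer_utils import get_lr_scheduler_type

    schedulers_mapping = {"step_lr": StepLR, "reduce_on_plateau": ReduceLROnPlateau}
    optimizer = SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    scheduler = StepLR(optimizer, step_size=10)
    scheduler_type = get_lr_scheduler_type(scheduler, schedulers_mapping)  # expected "step_lr" (assumption)
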
hezar.trainer.trainer_utils.resolve_logdir(log_dir) → str[source]
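
Only the signature is documented here; a minimal call sketch (the exact resolution behavior, e.g., making the directory unique per run, is not specified):

    from hezar.trainer.trainer_utils import resolve_logdir

    logs_dir = resolve_logdir("logs")  # returns the resolved log directory path as a string
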
hezar.trainer.trainer_utils.write_to_tensorboard(writer: SummaryWriter, logs: dict, step: int)[source]
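
Example (a minimal usage sketch; it assumes each scalar in logs is written under its key as the tag, and the log_dir below is hypothetical):

    from torch.utils.tensorboard import SummaryWriter

    from hezar.trainer.trainer_utils import write_to_tensorboard

    writer = SummaryWriter(log_dir="logs/tensorboard")
    write_to_tensorboard(writer, {"train.loss": 0.42, "train.accuracy": 0.87}, step=100)
    writer.flush()  # make sure pending events are written to disk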