hezar.trainer.metrics_handlers module

class hezar.trainer.metrics_handlers.Image2TextMetricHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)

Bases: MetricsHandler

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = [MetricType.CER, MetricType.WER]
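CER (character error rate) is the Levenshtein edit distance between the predicted and the reference text, divided by the number of reference characters. The sketch below shows that computation in plain Python. It does not call hezar, and the function names are hypothetical. Pooling edits and reference lengths over the whole batch, rather than averaging per-sample rates, is an assumption that may differ from how Image2TextMetricHandler aggregates its scores.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance where substitutions, insertions and deletions each cost 1."""
    prev = list(range(len(hyp) + 1))  # distances from an empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete r
                curr[j - 1] + 1,         # insert h
                prev[j - 1] + (r != h),  # substitute (cost 1) or match (cost 0)
            )
        prev = curr
    return prev[-1]


def corpus_cer(predictions: List[str], references: List[str]) -> float:
    # Assumption: total character edits over total reference characters
    edits = sum(edit_distance(ref, hyp) for hyp, ref in zip(predictions, references))
    total_chars = sum(len(ref) for ref in references)
    return edits / total_chars


print(corpus_cer(["helo wrld"], ["hello world"]))
```

Operating on characters suits OCR-style image-to-text output, where a single wrong glyph should cost less than a whole wrong word.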

class hezar.trainer.metrics_handlers.MetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None, **kwargs)

Bases: object

Base class for task-specific metrics handlers. Subclasses must implement the compute_metrics method for their specific task, and each subclass lists the metrics it supports in valid_metrics.

Parameters:
  • metrics – A list of metrics, each given as a raw metric name (str), a MetricType, a Metric object, or a MetricConfig

  • trainer – Optional trainer instance

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = []
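The base-class contract (a per-task list of valid metrics plus a task-specific compute_metrics) can be illustrated with a minimal, library-independent sketch. It does not import hezar: MetricName, SketchMetricsHandler and SketchAccuracyHandler are hypothetical stand-ins for MetricType, MetricsHandler and a subclass, and rejecting metrics that are missing from valid_metrics at construction time is an assumption about how such a list can be used, not a description of hezar's internals.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Sequence, Union


class MetricName(str, Enum):
    """Hypothetical stand-in for hezar's MetricType enum."""
    ACCURACY = "accuracy"
    F1 = "f1"
    CER = "cer"
    WER = "wer"


class SketchMetricsHandler(ABC):
    """Hypothetical base class mirroring the MetricsHandler pattern."""

    # Subclasses override this with the metrics their task supports
    valid_metrics: List[MetricName] = []

    def __init__(self, metrics: Sequence[Union[str, MetricName]]):
        # Normalize raw names to enum members (unknown names raise ValueError)
        self.metrics = [MetricName(m) for m in metrics]
        # Reject known metrics that this task does not support
        for metric in self.metrics:
            if metric not in self.valid_metrics:
                raise ValueError(f"{metric.value!r} is not supported by {type(self).__name__}")

    @abstractmethod
    def compute_metrics(self, predictions, labels) -> Dict[str, float]:
        """Return a mapping from metric name to value for one batch."""


class SketchAccuracyHandler(SketchMetricsHandler):
    valid_metrics = [MetricName.ACCURACY]

    def compute_metrics(self, predictions, labels) -> Dict[str, float]:
        # Each prediction row holds per-class logits; argmax picks the class id
        pred_ids = [max(range(len(row)), key=row.__getitem__) for row in predictions]
        correct = sum(p == y for p, y in zip(pred_ids, labels))
        return {"accuracy": correct / len(labels)}


handler = SketchAccuracyHandler(["accuracy"])
print(handler.compute_metrics([[0.1, 0.9], [2.0, -1.0], [0.3, 0.2]], [1, 0, 1]))
```

In this sketch, SketchAccuracyHandler(["wer"]) raises ValueError because "wer" is not in its valid_metrics, so a metric that does not fit the task fails when the handler is built instead of partway through evaluation.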

class hezar.trainer.metrics_handlers.SequenceLabelingMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)

Bases: MetricsHandler

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = [MetricType.SEQEVAL]
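The SEQEVAL metric scores sequence labeling at the entity level: a predicted entity counts only if its span and type both match the reference. The following sketch shows that idea for IOB2 tags (B-X, I-X, O) in plain Python. It is a simplified, strict reimplementation with hypothetical names, not the seqeval library itself; in particular, it ignores an I- tag that does not continue an entity of the same type, whereas seqeval's default mode is more lenient.

```python
from typing import Dict, List, Set, Tuple

Span = Tuple[int, int, str]  # (start, end_exclusive, entity_type)


def extract_spans(tags: List[str]) -> Set[Span]:
    """Collect entity spans from an IOB2 tag sequence such as B-PER, I-PER, O."""
    spans: Set[Span] = set()
    start, ent = 0, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        prefix, _, label = tag.partition("-")
        # An open entity ends on O, on a new B-, or on an I- of another type
        if ent is not None and (prefix != "I" or label != ent):
            spans.add((start, i, ent))
            ent = None
        if prefix == "B":
            start, ent = i, label
    return spans


def entity_scores(true_seqs: List[List[str]], pred_seqs: List[List[str]]) -> Dict[str, float]:
    true_spans, pred_spans = set(), set()
    for k, (t, p) in enumerate(zip(true_seqs, pred_seqs)):
        # Prefix spans with the sentence index so different sentences never collide
        true_spans |= {(k,) + s for s in extract_spans(t)}
        pred_spans |= {(k,) + s for s in extract_spans(p)}
    tp = len(true_spans & pred_spans)
    precision = tp / len(pred_spans) if pred_spans else 0.0
    recall = tp / len(true_spans) if true_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


print(entity_scores([["B-PER", "I-PER", "O", "B-LOC"]], [["B-PER", "I-PER", "O", "B-ORG"]]))
```

Here the PER span matches exactly, but LOC was predicted as ORG, so only one of two entities is correct on each side and precision, recall and F1 are all 0.5.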

class hezar.trainer.metrics_handlers.SpeechRecognitionMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)

Bases: MetricsHandler

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = [MetricType.CER, MetricType.WER]
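WER (word error rate) applies the same edit-distance idea as CER, but to whitespace-separated words: the minimum number of word substitutions, deletions and insertions, divided by the number of reference words. The sketch below is plain Python with hypothetical function names; whitespace tokenization and pooling over the batch are assumptions and may not match how SpeechRecognitionMetricsHandler normalizes or aggregates transcripts.

```python
from typing import List


def word_errors(reference: str, hypothesis: str) -> int:
    """Minimum word-level substitutions + deletions + insertions (Levenshtein on words)."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[j] = distance between ref[:i] and hyp[:j] for the row i being filled
    dp = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        diag, dp[0] = dp[0], i  # diag holds the previous row's dp[j - 1]
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            diag, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, diag + cost)
    return dp[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    # Assumption: total word errors over total reference words
    errors = sum(word_errors(ref, hyp) for hyp, ref in zip(predictions, references))
    total_words = sum(len(ref.split()) for ref in references)
    return errors / total_words


print(corpus_wer(["the cat sit on mat"], ["the cat sat on the mat"]))
```

The example has one substitution ("sat" to "sit") and one deletion ("the"), giving 2 errors over 6 reference words.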

class hezar.trainer.metrics_handlers.TextClassificationMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)

Bases: MetricsHandler

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = [MetricType.ACCURACY, MetricType.RECALL, MetricType.PRECISION, MetricType.F1]
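For classification, the logits in predictions are first reduced to class ids with argmax, and the four supported metrics follow from per-class true/false positive counts. The following plain-Python sketch, with a hypothetical function name, shows one way to do this. Macro averaging (an unweighted mean over classes) is an assumption: the averaging strategy TextClassificationMetricsHandler actually uses may differ or be configurable.

```python
from typing import Dict, List


def classification_metrics(logits: List[List[float]], labels: List[int]) -> Dict[str, float]:
    # Argmax over each row of logits gives the predicted class id
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    classes = sorted(set(labels) | set(preds))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(p == c and y == c for p, y in zip(preds, labels))
        fp = sum(p == c and y != c for p, y in zip(preds, labels))
        fn = sum(p != c and y == c for p, y in zip(preds, labels))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    n = len(classes)
    # Assumption: macro averaging, i.e. every class weighs the same
    return {
        "accuracy": sum(p == y for p, y in zip(preds, labels)) / len(labels),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


print(classification_metrics([[2.0, 0.1], [0.2, 1.5], [1.0, 0.0], [0.3, 0.9]], [0, 1, 1, 1]))
```

With macro averaging, a rare class counts as much as a frequent one; a weighted or micro average would instead let the majority class dominate.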

class hezar.trainer.metrics_handlers.TextGenerationMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)

Bases: MetricsHandler

compute_metrics(predictions, labels, **kwargs)

Given a batch of predictions and a batch of labels, compute all metrics.

Parameters:
  • predictions – Batch of model predictions, usually containing logits

  • labels – Batch of ground-truth labels

valid_metrics: List[MetricType] = [MetricType.ROUGE, MetricType.BLEU]
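ROUGE and BLEU both score generated text by n-gram overlap with a reference. The sketch below covers only ROUGE-1 (unigram overlap) in plain Python with a hypothetical function name and whitespace tokenization as an assumption; BLEU additionally combines clipped n-gram precisions for several n with a brevity penalty and is not shown. It is not the implementation TextGenerationMetricsHandler uses.

```python
from collections import Counter
from typing import Dict


def rouge1(prediction: str, reference: str) -> Dict[str, float]:
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    # Clipped overlap: a token matches at most min(count in prediction, count in reference) times
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    precision = overlap / len(pred_tokens) if pred_tokens else 0.0
    recall = overlap / len(ref_tokens) if ref_tokens else 0.0
    fmeasure = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}


print(rouge1("the cat sat", "the cat sat down"))
```

Clipping matters for degenerate generations: "the the the" against "the cat" gets credit for only one "the".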