hezar.trainer.metrics_handlers module
- class hezar.trainer.metrics_handlers.Image2TextMetricHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)
Bases: MetricsHandler
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = [MetricType.CER, MetricType.WER]
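For the image-to-text (e.g. OCR) case, predictions typically arrive as logits and the handler reports CER and WER. The following is a hypothetical, pure-Python illustration of one common route from logits to CER: greedy argmax decoding followed by a Levenshtein-based character error rate. The helper names (greedy_decode, edit_distance, cer) and the id2char vocabulary are assumptions, not part of hezar's API; the handler's actual decoding, text normalization, and aggregation may differ.

```python
from typing import Dict, List, Sequence


def greedy_decode(logits: Sequence[Sequence[float]], id2char: Dict[int, str]) -> str:
    """Pick the highest-scoring id at every step and map it to a character."""
    ids = [max(range(len(step)), key=step.__getitem__) for step in logits]
    return "".join(id2char[i] for i in ids)


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance; insertions, deletions and substitutions each cost 1."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete r
                curr[j - 1] + 1,         # insert h
                prev[j - 1] + (r != h),  # substitute (cost 0 on a match)
            )
        prev = curr
    return prev[-1]


def cer(references: List[str], hypotheses: List[str]) -> float:
    """Corpus-level CER: total character edits / total reference characters."""
    errors = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    return errors / sum(len(r) for r in references)


# Hypothetical 3-symbol vocabulary and a 3-step logits sequence.
id2char = {0: "a", 1: "b", 2: "c"}
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0]]
prediction = greedy_decode(logits, id2char)  # "bac"
print(prediction, cer(["abc"], [prediction]))
```

Here "bac" needs two substitutions to become "abc", so the CER is 2 edits over 3 reference characters.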
- class hezar.trainer.metrics_handlers.MetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None, **kwargs)
Bases: object
Base metrics handler class for computing metrics. Subclasses must implement the compute_metrics method for their specific task.
- Parameters:
metrics – A list of metrics, each given as a raw metric name (str), a MetricType, a Metric object, or a MetricConfig
trainer – Optional trainer instance
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = []
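MetricsHandler itself only fixes the contract: subclasses declare a class-level valid_metrics list and implement compute_metrics(predictions, labels, **kwargs). Below is a self-contained, hypothetical mimic of that contract so it runs without hezar installed. MetricType here is a stand-in enum with assumed values, SketchMetricsHandler and AccuracyHandler are illustrative names, and rejecting metrics outside valid_metrics with a ValueError is an assumption; the real handler may warn or behave differently, and it also accepts Metric and MetricConfig objects, which this sketch omits.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Union


class MetricType(str, Enum):
    """Hypothetical stand-in for hezar's MetricType (member values are assumptions)."""
    ACCURACY = "accuracy"
    CER = "cer"
    WER = "wer"


class SketchMetricsHandler(ABC):
    """Mimics the documented contract: valid_metrics plus an abstract compute_metrics."""

    valid_metrics: List[MetricType] = []

    def __init__(self, metrics: List[Union[str, MetricType]], trainer=None):
        # MetricType("accuracy") looks a member up by value; passing a member returns it unchanged.
        self.metrics = [MetricType(m) for m in metrics]
        # Assumption: unsupported metrics are rejected outright.
        invalid = [m for m in self.metrics if m not in self.valid_metrics]
        if invalid:
            raise ValueError(f"Unsupported metrics for {type(self).__name__}: {invalid}")
        self.trainer = trainer

    @abstractmethod
    def compute_metrics(self, predictions, labels, **kwargs) -> Dict[str, float]:
        """Return a mapping of metric name to value for one batch."""


class AccuracyHandler(SketchMetricsHandler):
    valid_metrics = [MetricType.ACCURACY]

    def compute_metrics(self, predictions, labels, **kwargs) -> Dict[str, float]:
        correct = sum(int(p == y) for p, y in zip(predictions, labels))
        return {"accuracy": correct / len(labels)}


handler = AccuracyHandler(["accuracy"])
print(handler.compute_metrics([1, 0, 1, 1], [1, 1, 1, 1]))  # {'accuracy': 0.75}
```

Declaring valid_metrics as a class attribute lets each task-specific subclass advertise what it supports without overriding the constructor.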
- class hezar.trainer.metrics_handlers.SequenceLabelingMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)
Bases: MetricsHandler
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = [MetricType.SEQEVAL]
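The only supported metric here is seqeval, which scores sequence labeling at the entity level: a predicted entity counts only if both its type and its span match the reference. The sketch below is a simplified pure-Python approximation of that idea for BIO tags. It is not seqeval itself: extract_spans and span_f1 are hypothetical helpers, and they use strict BIO decoding (an I- tag that does not continue a same-type span is dropped), whereas seqeval's default mode is more lenient with such tags.

```python
from typing import Dict, List, Set, Tuple


def extract_spans(tags: List[str]) -> Set[Tuple[str, int, int]]:
    """Collect (type, start, end) entity spans from a BIO sequence; end is exclusive."""
    spans = set()
    start, ent_type = None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing span
        if tag.startswith("I-") and ent_type == tag[2:]:
            continue  # the current span keeps growing
        if ent_type is not None:
            spans.add((ent_type, start, i))
            start, ent_type = None, None
        if tag.startswith("B-"):
            start, ent_type = i, tag[2:]
        # Any other tag ("O" or a stray I- tag) leaves no open span.
    return spans


def span_f1(y_true: List[List[str]], y_pred: List[List[str]]) -> Dict[str, float]:
    """Micro-averaged entity-level precision, recall and F1 over all sentences."""
    true_spans, pred_spans = set(), set()
    for k, (t, p) in enumerate(zip(y_true, y_pred)):
        # Prefix each span with its sentence index so spans from different sentences differ.
        true_spans |= {(k,) + s for s in extract_spans(t)}
        pred_spans |= {(k,) + s for s in extract_spans(p)}
    tp = len(true_spans & pred_spans)
    precision = tp / len(pred_spans) if pred_spans else 0.0
    recall = tp / len(true_spans) if true_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


y_true = [["B-PER", "I-PER", "O", "B-LOC"]]
y_pred = [["B-PER", "I-PER", "O", "O"]]
print(span_f1(y_true, y_pred))  # PER matched, LOC missed
```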
- class hezar.trainer.metrics_handlers.SpeechRecognitionMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)
Bases: MetricsHandler
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = [MetricType.CER, MetricType.WER]
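For speech recognition the supported metrics are CER and WER. WER is the word-level analogue of CER: the edit distance between the reference and hypothesis word sequences, divided by the number of reference words. A minimal sketch, assuming already-decoded transcripts, whitespace tokenization, and corpus-level aggregation (the handler's own text normalization and aggregation are not documented here); edit_distance and wer are hypothetical helper names.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over tokens; each insertion, deletion or substitution costs 1."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            # deletion, insertion, substitution (free when the words match)
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (r != h))
        prev = curr
    return prev[-1]


def wer(references: List[str], hypotheses: List[str]) -> float:
    """Corpus-level WER: total word edits / total reference words."""
    errors = total = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words, hyp_words = ref.split(), hyp.split()
        errors += edit_distance(ref_words, hyp_words)
        total += len(ref_words)
    return errors / total


# One substitution (sat -> sit) and one deletion (the) over 6 reference words.
print(wer(["the cat sat on the mat"], ["the cat sit on mat"]))
```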
- class hezar.trainer.metrics_handlers.TextClassificationMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)
Bases: MetricsHandler
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = [MetricType.ACCURACY, MetricType.RECALL, MetricType.PRECISION, MetricType.F1]
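Text classification predictions are usually logits, so a typical pipeline takes the argmax of each row before scoring. The sketch below computes the four supported metrics in plain Python. Macro averaging over the classes that appear in either the labels or the predictions is an assumption (it mirrors scikit-learn's average="macro"); the handler may use a different averaging scheme, and classification_metrics is a hypothetical name.

```python
from typing import Dict, List, Sequence


def argmax(row: Sequence[float]) -> int:
    return max(range(len(row)), key=row.__getitem__)


def classification_metrics(logits: List[Sequence[float]], labels: List[int]) -> Dict[str, float]:
    """Accuracy plus macro-averaged precision, recall and F1."""
    preds = [argmax(row) for row in logits]
    classes = sorted(set(labels) | set(preds))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for p, y in zip(preds, labels) if p == c and y == c)
        fp = sum(1 for p, y in zip(preds, labels) if p == c and y != c)
        fn = sum(1 for p, y in zip(preds, labels) if p != c and y == c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    n = len(classes)
    return {
        "accuracy": sum(p == y for p, y in zip(preds, labels)) / len(labels),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


logits = [[2.0, 0.1], [0.3, 1.2], [1.0, 0.5], [0.2, 0.9]]
labels = [0, 1, 1, 1]
print(classification_metrics(logits, labels))
```

Macro averaging weights every class equally, so a rare class with poor recall pulls the scores down as much as a frequent one would.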
- class hezar.trainer.metrics_handlers.TextGenerationMetricsHandler(metrics: List[str | MetricType | Metric | MetricConfig], trainer=None)
Bases: MetricsHandler
- compute_metrics(predictions, labels, **kwargs)
Given a batch of predictions and a batch of labels, compute all metrics configured for this handler.
- Parameters:
predictions – Batch of predictions, usually containing logits.
labels – Batch of ground-truth labels.
- valid_metrics: List[MetricType] = [MetricType.ROUGE, MetricType.BLEU]
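Text generation is scored with ROUGE and BLEU, both n-gram overlap metrics. As a small illustration, the sketch below computes ROUGE-1 (clipped unigram overlap) for a single reference/candidate pair. Lowercasing and whitespace tokenization are assumptions, rouge1 is a hypothetical helper, and BLEU (modified n-gram precision with a brevity penalty) is not shown; which ROUGE variants and tokenization the handler actually uses is not documented here.

```python
from collections import Counter
from typing import Dict


def rouge1(reference: str, candidate: str) -> Dict[str, float]:
    """ROUGE-1 precision, recall and F1 from clipped unigram overlap."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Counter & Counter keeps the minimum count per token, i.e. clipped matches.
    overlap = sum((ref_counts & cand_counts).values())
    precision = overlap / sum(cand_counts.values()) if cand_counts else 0.0
    recall = overlap / sum(ref_counts.values()) if ref_counts else 0.0
    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


print(rouge1("the cat sat on the mat", "the cat is on the mat"))
```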