hezar.models.text_classification.bert.bert_text_classification module

A BERT model for text classification, built using HuggingFace Transformers.

class hezar.models.text_classification.bert.bert_text_classification.BertTextClassification(config: BertTextClassificationConfig, **kwargs)

Bases: Model

A standard 🤗Transformers BERT model for text classification.

Parameters:

config – The full model config, including the arguments needed by the inner 🤗Transformers model.

compute_loss(logits: Tensor, labels: Tensor) → Tensor

Compute the loss of the model outputs against the given labels.

Parameters:
  • logits – Logits tensor to compute loss on

  • labels – Labels tensor

Note: Subclasses can override this method and add arguments other than logits and labels.

Returns:

Loss tensor
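The loss function itself is not named on this page. For single-label classification a common choice is cross-entropy between the logits and integer class labels (what `torch.nn.CrossEntropyLoss` computes), so the sketch below assumes that. It is a plain-Python illustration with lists standing in for tensors; `cross_entropy` is a hypothetical helper, not part of Hezar.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross-entropy over a batch, computed from raw (unnormalized) logits.

    Assumption: single-label classification with integer class ids;
    this is an illustration, not Hezar's compute_loss.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the row max subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # negative log-probability of the true class
        total += log_z - row[label]
    return total / len(labels)


loss = cross_entropy([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 2])
```

Subtracting the row maximum before exponentiating leaves the result unchanged mathematically but avoids overflow for large logits.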

forward(token_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, **kwargs) → Dict

Forward inputs through the model and return logits, etc.

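Parameters:
  • token_ids – Tensor of input token IDs

  • attention_mask, token_type_ids, position_ids and the remaining optional arguments – Arguments mirroring those of the inner 🤗Transformers BERT model's forward()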

Returns:

A dict of outputs like logits, loss, etc.
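The body of forward() is not shown here. Assuming the usual design for BERT classifiers (encoder, then the pooled [CLS] vector, then dropout, then a linear layer), the final step reduces to an affine map from pooled vectors to logits. The sketch below shows only that step in plain Python, with lists standing in for tensors; `classification_head` and its arguments are hypothetical, and the `"logits"` key follows the Returns description above.

```python
from typing import Dict, List


def classification_head(
    pooled: List[List[float]],
    weight: List[List[float]],
    bias: List[float],
) -> Dict[str, List[List[float]]]:
    """Map pooled vectors (batch x hidden) to logits (batch x num_labels).

    Assumption: a single linear layer on BERT's pooled output. Dropout is
    omitted because it acts as the identity in eval mode.
    weight has shape num_labels x hidden; bias has length num_labels.
    """
    logits = [
        [sum(w * h for w, h in zip(w_row, vec)) + b for w_row, b in zip(weight, bias)]
        for vec in pooled
    ]
    return {"logits": logits}


out = classification_head(
    [[1.0, 2.0]],                            # one pooled vector, hidden size 2
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],    # 3 labels x hidden size 2
    [0.0, 0.5, -1.0],
)
```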

post_process(model_outputs: dict, top_k=1)

Process model outputs and return human-readable results. Called in self.predict().

Parameters:
  • model_outputs – Model outputs to process (the dict returned by forward())

  • top_k – Number of highest-scoring labels to return for each input (default: 1)

Returns:

The model outputs, processed and converted to human-readable results
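A plausible reading of top_k is: apply softmax to each row of logits and keep the k most probable labels. The sketch below assumes an `id2label` mapping like the one 🤗Transformers configs carry, and uses `{"label", "score"}` dicts purely for illustration; `top_k_labels` is a hypothetical helper and its output format is not necessarily what Hezar returns.

```python
import math
from typing import Dict, List


def top_k_labels(
    logits: List[List[float]],
    id2label: Dict[int, str],
    top_k: int = 1,
) -> List[List[Dict[str, object]]]:
    """Turn a batch of logits into the top_k (label, score) pairs per input."""
    results = []
    for row in logits:
        # softmax with max-subtraction for numerical stability
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        z = sum(exps)
        probs = [e / z for e in exps]
        # class indices sorted by descending probability
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        results.append(
            [{"label": id2label[i], "score": probs[i]} for i in ranked[:top_k]]
        )
    return results


preds = top_k_labels([[0.0, 2.0, 1.0]], {0: "negative", 1: "positive", 2: "neutral"}, top_k=2)
```

Here `preds[0]` lists "positive" first and "neutral" second, each with its softmax probability.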

preprocess(inputs: str | List[str], **kwargs)

Preprocess raw inputs and prepare them for the model's forward().

Parameters:
  • inputs – A raw text or a list of raw texts

  • **kwargs – Extra keyword arguments specific to the model. See the model's specific class for more info.

Returns:

A dict of inputs for the model's forward()
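Because inputs may be a single string or a list, preprocessing has to normalize it to a batch before tokenizing and padding. The sketch below shows that flow with a hypothetical toy whitespace tokenizer and vocabulary (`VOCAB`, `toy_preprocess`); the real model uses a trained tokenizer from the 🤗Tokenizers backend. The output keys are chosen to match forward()'s `token_ids` and `attention_mask` parameters.

```python
from typing import Dict, List, Union

# Hypothetical toy vocabulary; the real model uses a trained BERT tokenizer.
VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "good": 4, "bad": 5, "movie": 6}


def toy_preprocess(inputs: Union[str, List[str]]) -> Dict[str, List[List[int]]]:
    """Accept one text or a list of texts; return padded token ids and an attention mask."""
    if isinstance(inputs, str):
        inputs = [inputs]  # a single string becomes a batch of one
    batch = []
    for text in inputs:
        ids = [VOCAB["[CLS]"]]
        ids += [VOCAB.get(tok, VOCAB["[UNK]"]) for tok in text.lower().split()]
        ids.append(VOCAB["[SEP]"])
        batch.append(ids)
    max_len = max(len(ids) for ids in batch)
    # right-pad every sequence to the longest one in the batch
    token_ids = [ids + [VOCAB["[PAD]"]] * (max_len - len(ids)) for ids in batch]
    # 1 marks real tokens, 0 marks padding
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return {"token_ids": token_ids, "attention_mask": attention_mask}


batch = toy_preprocess(["good", "bad movie"])
```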

required_backends: List[Backends | str] = [Backends.TRANSFORMERS, Backends.TOKENIZERS]
skip_keys_on_load = ['model.embeddings.position_ids', 'bert.embeddings.position_ids', 'model.bert.embeddings.position_ids']
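Every entry in skip_keys_on_load ends in `embeddings.position_ids`, which is a buffer rather than a learned weight; some 🤗Transformers versions save it in checkpoints and others do not, so skipping it presumably avoids key-mismatch errors when loading. Hezar's actual loading code is not shown here. The sketch below only illustrates how such a list can be applied, using a plain dict as a stand-in for a state dict; `drop_skipped_keys` is a hypothetical helper.

```python
from typing import Dict, Iterable

SKIP_KEYS_ON_LOAD = [
    "model.embeddings.position_ids",
    "bert.embeddings.position_ids",
    "model.bert.embeddings.position_ids",
]


def drop_skipped_keys(state_dict: Dict[str, object], skip_keys: Iterable[str]) -> Dict[str, object]:
    """Return a copy of state_dict without the keys listed in skip_keys."""
    skip = set(skip_keys)
    return {k: v for k, v in state_dict.items() if k not in skip}


# Toy "checkpoint": strings stand in for tensors.
checkpoint = {
    "bert.embeddings.position_ids": "buffer",
    "bert.embeddings.word_embeddings.weight": "W",
    "classifier.weight": "C",
}
clean = drop_skipped_keys(checkpoint, SKIP_KEYS_ON_LOAD)
```

Building a new dict leaves the original checkpoint untouched, which keeps the filtering step free of side effects.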