hezar.models.text_classification.distilbert.distilbert_text_classification module
A DistilBERT model for text classification, built with HuggingFace Transformers.
- class hezar.models.text_classification.distilbert.distilbert_text_classification.DistilBertTextClassification(config: DistilBertTextClassificationConfig, **kwargs)
Bases: Model
A standard 🤗Transformers DistilBERT model for text classification.
- Parameters:
config – The full model config, including the arguments needed by the inner 🤗Transformers model.
- compute_loss(logits: Tensor, labels: Tensor) → Tensor
Compute the loss of the model outputs against the given labels.
- Parameters:
logits – Logits tensor to compute the loss on
labels – Labels tensor
Note: Subclasses can override this method and accept additional arguments besides logits and labels.
- Returns:
Loss tensor
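For single-label classification the loss is usually cross-entropy. This page does not say which loss the class uses, so the sketch below assumes cross-entropy with mean reduction (what torch.nn.functional.cross_entropy computes by default) and spells it out in plain Python for a batch of logits and integer labels. `cross_entropy_loss` is a hypothetical helper, not part of hezar.

```python
import math
from typing import Sequence


def cross_entropy_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Assumed loss: mean over the batch of -log(softmax(row)[label])."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    if not labels:
        raise ValueError("batch must not be empty")
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the row maximum subtracted, for numerical stability
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        # -log softmax(row)[label] == log_sum_exp - row[label]
        total += log_sum_exp - row[label]
    return total / len(labels)


# Two classes with equal logits: the loss is log(2), about 0.693.
print(cross_entropy_loss([[0.0, 0.0]], [0]))
```

A confident, correct prediction (e.g. logits `[10.0, -10.0]` with label 0) drives the loss toward zero, while the same logits with label 1 give a loss of roughly 20.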
- forward(token_ids, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, **kwargs) → Dict
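Forward the inputs through the model and return the logits and other outputs.
- Parameters:
token_ids – Input token IDs
attention_mask – Optional mask marking real tokens (1) versus padding (0)
head_mask, inputs_embeds, output_attentions, output_hidden_states – Optional arguments for the inner 🤗Transformers DistilBERT model
- Returns:
A dict of outputs such as logits and loss.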
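The page does not describe the classification head itself. The sketch below assumes it mirrors 🤗Transformers' `DistilBertForSequenceClassification`, which takes the first token's hidden state, applies a `pre_classifier` linear layer, a ReLU and dropout, and then a `classifier` linear layer that produces one logit per label. It is written in plain Python for a single forward pass at inference time. `linear`, `classification_head` and the weight arguments are hypothetical names used only for illustration.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # weight has shape (out_features, in_features), like torch.nn.Linear.weight
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def classification_head(
    hidden_states: List[Matrix],  # (batch, seq_len, dim) from the DistilBERT encoder
    pre_w: Matrix, pre_b: Vector,  # hypothetical pre_classifier weights (dim -> dim)
    cls_w: Matrix, cls_b: Vector,  # hypothetical classifier weights (dim -> num_labels)
) -> Dict[str, Matrix]:
    logits = []
    for sequence in hidden_states:
        pooled = sequence[0]  # hidden state of the first ([CLS]) token
        hidden = [max(0.0, v) for v in linear(pooled, pre_w, pre_b)]  # pre_classifier + ReLU
        # Dropout would be applied here during training; it is a no-op at inference.
        logits.append(linear(hidden, cls_w, cls_b))
    return {"logits": logits}


out = classification_head(
    [[[1.0, -2.0], [5.0, 5.0]]],  # batch of 1, sequence of 2 tokens, dim 2
    [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0],
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.5],
)
print(out)  # {'logits': [[1.0, 0.0, 1.5]]}
```

Only the first token feeds the head, so the second token's hidden state `[5.0, 5.0]` has no effect on the logits.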
- post_process(model_outputs: dict, top_k=1)
Process model outputs and return human-readable results. Called by self.predict().
- Parameters:
model_outputs – Model outputs to process
top_k – Number of highest-scoring labels to return for each input (default: 1)
- Returns:
The model outputs converted to human-readable results
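A plausible reading of `top_k` is: turn each row of logits into probabilities and keep the `top_k` most likely labels. The sketch below assumes a softmax over the logits and an `id2label` mapping from class index to label name; the result layout (a list of `{"label", "score"}` dicts per input) is an illustrative assumption, not necessarily the exact structure hezar returns. `softmax` and `top_k_labels` are hypothetical helpers.

```python
import math
from typing import Dict, List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    row_max = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(
    logits: Sequence[Sequence[float]], id2label: Dict[int, str], top_k: int = 1
) -> List[List[Dict[str, object]]]:
    results = []
    for row in logits:
        probs = softmax(row)
        # class indices ordered by descending probability, truncated to top_k
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
        results.append([{"label": id2label[i], "score": probs[i]} for i in ranked])
    return results


id2label = {0: "negative", 1: "positive", 2: "neutral"}  # hypothetical label map
best = top_k_labels([[0.1, 2.3, 0.4]], id2label, top_k=2)
print([r["label"] for r in best[0]])  # ['positive', 'neutral']
```

Softmax preserves the ordering of the logits, so ranking by probability gives the same labels as ranking by raw logits; the softmax only makes the scores interpretable.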
- preprocess(inputs: str | List[str], **kwargs)
Preprocess raw inputs and prepare them for the model’s forward().
- Parameters:
inputs – Raw model inputs: a single string or a list of strings
**kwargs – Extra kwargs specific to the model. See the model’s specific class for more info.
- Returns:
A dict of inputs for the model’s forward()
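To show the shape of what preprocessing hands to `forward()`, here is a toy sketch. It assumes a whitespace tokenizer and a small hypothetical vocabulary in place of DistilBERT's real WordPiece tokenizer (which also adds special tokens such as [CLS] and [SEP]), and it names the output keys after `forward()`'s `token_ids` and `attention_mask` parameters. `toy_preprocess`, `vocab`, `pad_id` and `unk_id` are all hypothetical.

```python
from typing import Dict, List, Union


def toy_preprocess(
    inputs: Union[str, List[str]],
    vocab: Dict[str, int],  # hypothetical token -> id map
    pad_id: int = 0,
    unk_id: int = 1,
) -> Dict[str, List[List[int]]]:
    # Accept a single string or a list of strings, as preprocess() does.
    if isinstance(inputs, str):
        inputs = [inputs]
    # Whitespace "tokenizer": a stand-in for DistilBERT's real WordPiece tokenizer.
    encoded = [[vocab.get(token, unk_id) for token in text.split()] for text in inputs]
    max_len = max((len(ids) for ids in encoded), default=0)
    token_ids, attention_mask = [], []
    for ids in encoded:
        n_pad = max_len - len(ids)
        token_ids.append(ids + [pad_id] * n_pad)  # right-pad to the longest input
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"token_ids": token_ids, "attention_mask": attention_mask}


vocab = {"good": 2, "movie": 3, "bad": 4}
print(toy_preprocess(["good movie", "bad"], vocab))
# {'token_ids': [[2, 3], [4, 0]], 'attention_mask': [[1, 1], [1, 0]]}
```

Padding every input to the longest one in the batch lets the batch form a rectangular tensor, and the attention mask tells the model which positions to ignore.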