hezar.models.speech_recognition.whisper.whisper_speech_recognition module

class hezar.models.speech_recognition.whisper.whisper_speech_recognition.WhisperSpeechRecognition(config: WhisperSpeechRecognitionConfig, **kwargs)

Bases: Model

Whisper model for automatic speech recognition

compute_loss(logits: Tensor, labels: Tensor, attention_mask: Tensor | None = None)

Compute loss on the model outputs against the given labels

Parameters:
  • logits – Logits tensor to compute loss on

  • labels – Labels tensor

  • attention_mask – Optional attention mask tensor (defaults to None)

Note: Subclasses can also override this method and add other arguments besides logits and labels

Returns:

Loss tensor
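For intuition, the cross-entropy loss named by loss_func_name can be sketched in plain Python over a flattened sequence of per-token logits. This is a minimal sketch, not hezar's implementation. The ignore index of -100 for padded label positions is an assumption borrowed from the PyTorch and Hugging Face convention.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # assumption: PyTorch/Hugging Face convention for padded label positions


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                  ignore_index: int = IGNORE_INDEX) -> float:
    """Mean token-level cross-entropy over positions whose label is not ignored."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded positions do not contribute to the loss
        # log-sum-exp computed stably by subtracting the row maximum
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]  # negative log-probability of the target token
        count += 1
    return total / count if count else 0.0


logits = [[2.0, 0.5, 0.1], [0.1, 0.2, 3.0], [1.0, 1.0, 1.0]]
labels = [0, 2, IGNORE_INDEX]  # the last position is padding and is skipped
print(cross_entropy(logits, labels))
```

Averaging only over non-ignored positions keeps the loss independent of how much padding a batch contains.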

forward(input_features, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, decoder_inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, **kwargs)

Forward inputs through the model and return logits, etc.

Parameters:

model_inputs – The inputs for the model forward, passed as the keyword arguments in the signature above (input_features is the only one without a default)

Returns:

A dict of outputs like logits, loss, etc.
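The forward signature mirrors that of Hugging Face's WhisperForConditionalGeneration.forward. In that implementation, when labels are given without decoder_input_ids, the decoder inputs are built by shifting the labels one position to the right. Assuming this wrapper delegates to that behavior (not confirmed here), the transformation looks like the plain-Python sketch below; the token ids are hypothetical.

```python
from typing import List


def shift_tokens_right(labels: List[List[int]], pad_token_id: int,
                       decoder_start_token_id: int) -> List[List[int]]:
    """Build decoder inputs from labels: prepend the start token, drop the last label,
    and replace the ignore index (-100) with the pad token."""
    shifted = []
    for row in labels:
        new_row = [decoder_start_token_id] + row[:-1]
        shifted.append([pad_token_id if t == -100 else t for t in new_row])
    return shifted


# Hypothetical ids: 1 = decoder start token, 0 = pad token
print(shift_tokens_right([[5, 6, 7, -100]], pad_token_id=0, decoder_start_token_id=1))
# [[1, 5, 6, 7]]
```

With this shift, the decoder at position i sees labels up to i - 1 and is trained to predict label i (teacher forcing).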

freeze_encoder()
generate(input_features, attention_mask=None, forced_decoder_ids=None, generation_config=None, logits_processor=None, stopping_criteria=None, prefix_allowed_tokens_fn=None, synced_gpus=None, return_timestamps=None, task=None, language=None, is_multilingual=None, prompt_ids=None, **kwargs)

Generation method for all generative models. Generative models have the is_generative attribute set to True. The behavior of this method is usually controlled by the generation part of the model’s config.

Parameters:
  • model_inputs – Model inputs for generation, usually the same as forward’s model_inputs

  • **kwargs – Generation kwargs

Returns:

Generated output tensor
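The forced_decoder_ids argument appears to follow the Hugging Face convention of (position, token_id) pairs that pin specific tokens, such as language and task tokens, at given decoding steps. The toy greedy loop below is a hypothetical sketch of that idea only; greedy_generate and toy_scores are illustrative names, and toy_scores stands in for a real decoder.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple


def greedy_generate(
    next_token_scores: Callable[[List[int]], Sequence[float]],
    decoder_start_token_id: int,
    eos_token_id: int,
    max_length: int,
    forced_decoder_ids: Optional[List[Tuple[int, int]]] = None,
) -> List[int]:
    """Greedy decoding where forced_decoder_ids maps a sequence position to a fixed token."""
    forced: Dict[int, int] = dict(forced_decoder_ids or [])
    sequence = [decoder_start_token_id]  # position 0 holds the start token
    while len(sequence) < max_length:
        position = len(sequence)
        if position in forced:
            token = forced[position]  # the forced token overrides the model's choice
        else:
            scores = next_token_scores(sequence)
            token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        sequence.append(token)
        if token == eos_token_id:
            break
    return sequence


# Hypothetical toy "model": prefers token 3 until the sequence has 4 tokens, then EOS (id 2)
def toy_scores(seq: List[int]) -> List[float]:
    return [0.0, 0.0, 1.0, 0.0] if len(seq) >= 4 else [0.0, 0.0, 0.0, 1.0]


print(greedy_generate(toy_scores, decoder_start_token_id=1, eos_token_id=2,
                      max_length=10, forced_decoder_ids=[(1, 0)]))
# [1, 0, 3, 3, 2]
```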

get_decoder()
get_encoder()
get_input_embeddings() → Module
get_output_embeddings()
is_generative: bool = True
loss_func_name: str | LossType = 'cross_entropy'
post_process(model_outputs, skip_special_tokens=True, decode_with_timestamps=False, output_offsets=False, **kwargs)

Process model outputs and return human-readable results. Called by self.predict().

Parameters:
  • model_outputs – model outputs to process

  • **kwargs – extra arguments specific to the derived class

Returns:

Model outputs processed and converted to human-readable results
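To illustrate what skip_special_tokens and decode_with_timestamps control, the sketch below decodes token ids with a tiny hypothetical vocabulary. It assumes Whisper's convention that timestamp tokens form a contiguous id range starting at <|0.00|> with a 0.02-second step. The decode helper, ids, and vocabulary are made up for illustration, and the real tokenizer's handling may differ in detail.

```python
from typing import Dict, List, Set

TIME_PRECISION = 0.02  # seconds per timestamp-token step in Whisper


def decode(token_ids: List[int], vocab: Dict[int, str], special_ids: Set[int],
           timestamp_begin: int, skip_special_tokens: bool = True,
           decode_with_timestamps: bool = False) -> str:
    """Turn token ids into text; ids >= timestamp_begin are treated as timestamp tokens."""
    pieces = []
    for tid in token_ids:
        if tid >= timestamp_begin:
            if decode_with_timestamps:
                seconds = (tid - timestamp_begin) * TIME_PRECISION
                pieces.append(f"<|{seconds:.2f}|>")
            continue  # in this sketch, timestamps are dropped unless explicitly requested
        if tid in special_ids and skip_special_tokens:
            continue
        pieces.append(vocab[tid])
    return "".join(pieces)


vocab = {0: "<|startoftranscript|>", 1: "<|endoftext|>", 2: " hello", 3: " world"}
ids = [0, 100, 2, 3, 150, 1]  # hypothetical: 100 = <|0.00|>, 150 = <|1.00|>
print(decode(ids, vocab, special_ids={0, 1}, timestamp_begin=100))
# " hello world"
print(decode(ids, vocab, special_ids={0, 1}, timestamp_begin=100, decode_with_timestamps=True))
# "<|0.00|> hello world<|1.00|>"
```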

prepare_inputs_for_generation(decoder_input_ids, past_key_values=None, use_cache=None, encoder_outputs=None, attention_mask=None, **kwargs)
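With key/value caching, decoder states for earlier positions are already stored in past_key_values, so Hugging Face-style implementations typically feed only the uncached suffix of decoder_input_ids at each step. The following is a hypothetical sketch of that trimming; trim_decoder_inputs and its past_length argument are illustrative and not part of hezar.

```python
from typing import List, Optional


def trim_decoder_inputs(decoder_input_ids: List[List[int]],
                        past_length: Optional[int]) -> List[List[int]]:
    """Drop the prefix already covered by the cache (past_length tokens)."""
    if past_length is None:
        return decoder_input_ids  # no cache: the full sequence is fed
    seq_len = len(decoder_input_ids[0])
    # always keep at least the last token, even if the cache covers the whole sequence
    remove = past_length if seq_len > past_length else seq_len - 1
    return [row[remove:] for row in decoder_input_ids]


print(trim_decoder_inputs([[1, 5, 6, 7]], past_length=3))     # [[7]]
print(trim_decoder_inputs([[1, 5, 6, 7]], past_length=None))  # [[1, 5, 6, 7]]
```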
preprocess(inputs: str | ndarray | List[ndarray] | List[str], language=None, **kwargs)

Given raw inputs, preprocess them and prepare them for the model’s forward().

Parameters:
  • inputs – Raw audio inputs: a str or ndarray, or a list of either

  • **kwargs – Extra kwargs specific to the model. See the model’s specific class for more info

Returns:

A dict of inputs for model forward
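Whisper's feature extractor works on fixed 30-second windows of 16 kHz audio, so each waveform is padded or trimmed to 480,000 samples before log-Mel features are computed. The sketch below shows only that padding step. It assumes the waveform is already loaded and resampled to 16 kHz, and it is not hezar's preprocess code.

```python
from typing import List, Sequence

SAMPLING_RATE = 16_000  # Hz expected by Whisper
CHUNK_LENGTH = 30  # seconds per input window
N_SAMPLES = SAMPLING_RATE * CHUNK_LENGTH  # 480,000 samples


def pad_or_trim(waveform: Sequence[float], length: int = N_SAMPLES) -> List[float]:
    """Zero-pad or truncate a waveform to exactly `length` samples."""
    samples = list(waveform[:length])
    return samples + [0.0] * (length - len(samples))


one_second = [0.1] * SAMPLING_RATE
padded = pad_or_trim(one_second)
print(len(padded), padded[SAMPLING_RATE])  # 480000 0.0
```

Because every input ends up the same length, a batch of waveforms can be stacked into a single feature tensor without a separate padding mask.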

required_backends: List[Backends | str] = [Backends.TRANSFORMERS, Backends.TOKENIZERS, Backends.LIBROSA]
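The model can only run if the packages listed in required_backends are importable. As a sketch under stated assumptions (hezar performs its own checks, which may differ), missing backends can be detected with the standard library's importlib.util.find_spec. The module names below are assumed to be the import names of the three listed backends, and missing_backends is a hypothetical helper.

```python
import importlib.util
from typing import Iterable, List

# Assumed import names for Backends.TRANSFORMERS, Backends.TOKENIZERS and Backends.LIBROSA
WHISPER_BACKENDS = ["transformers", "tokenizers", "librosa"]


def missing_backends(module_names: Iterable[str]) -> List[str]:
    """Return the names of top-level modules that cannot be found on the current path."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


missing = missing_backends(WHISPER_BACKENDS)
if missing:
    print("Install these packages first:", ", ".join(missing))
```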
resize_token_embeddings(new_num_tokens: int) → Embedding
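Resizing keeps the learned vectors for existing token ids and adds or drops rows at the end of the embedding matrix. This is what is needed after new tokens are added to the tokenizer. The plain-Python sketch below shows only that row bookkeeping on a nested list. Filling new rows with a constant (zero by default) is an assumption for illustration; real implementations initialize them with their own scheme.

```python
from typing import List

Matrix = List[List[float]]


def resize_rows(weight: Matrix, new_num_tokens: int, fill: float = 0.0) -> Matrix:
    """Return a (new_num_tokens x dim) matrix that reuses the first rows of `weight`."""
    dim = len(weight[0])
    kept = [row[:] for row in weight[:new_num_tokens]]  # copy the existing vectors that fit
    extra = [[fill] * dim for _ in range(new_num_tokens - len(kept))]  # new rows (assumed fill)
    return kept + extra


emb = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
print(resize_rows(emb, 4))  # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]]
print(resize_rows(emb, 2))  # [[0.1, 0.2], [0.3, 0.4]]
```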
set_output_embeddings(new_embeddings)