hezar.models.image2text.crnn.crnn_image2text module

class hezar.models.image2text.crnn.crnn_image2text.CRNNImage2Text(config: CRNNImage2TextConfig, **kwargs)

Bases: Model

A CRNN (Convolutional Recurrent Neural Network) model for character-level OCR, based on the original CRNN paper.
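The CRNN design stacks convolutional layers, turns the resulting feature map into a left-to-right sequence of column vectors, and runs recurrent layers over that sequence to predict one character distribution per column. The map-to-sequence step can be sketched in plain Python as below. The nested-list layout and the name map_to_sequence are hypothetical, and hezar's actual reshaping (done on tensors) may order the features differently.

```python
from typing import List

# Hypothetical layout: feature_map[c][h][w] as nested lists
FeatureMap = List[List[List[float]]]


def map_to_sequence(feature_map: FeatureMap) -> List[List[float]]:
    """Turn a (C, H, W) feature map into W frames of C*H features each.

    Each column becomes one time step, so the recurrent layers can read
    the text line from left to right.
    """
    channels = len(feature_map)
    height = len(feature_map[0])
    width = len(feature_map[0][0])
    sequence = []
    for w in range(width):
        # Stack every channel's values in column w into one feature vector
        frame = [feature_map[c][h][w] for c in range(channels) for h in range(height)]
        sequence.append(frame)
    return sequence


# 2 channels, height 1, width 3 -> 3 frames with 2 features each
fmap = [[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]
print(map_to_sequence(fmap))  # [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```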

compute_loss(logits: Tensor, labels: Tensor)

Compute the loss of the model outputs against the given labels.

Parameters:
  • logits – Logits tensor to compute loss on

  • labels – Labels tensor

Note: Subclasses can also override this method and add arguments other than logits and labels.

Returns:

Loss tensor
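Because loss_func_name is 'ctc', the loss for this model is a Connectionist Temporal Classification (CTC) loss, which sums the probability of every frame-level alignment that collapses to the target label sequence. The sketch below computes it for one sample with the standard forward (alpha) recursion in plain Python, assuming at least one frame and a blank id of 0. It illustrates the technique only and is not hezar's code, which presumably delegates to a tensor implementation such as PyTorch's torch.nn.CTCLoss (an assumption). It works in probability space for readability, so it can underflow on long sequences, where real implementations use log-space arithmetic.

```python
import math
from typing import List


def ctc_neg_log_likelihood(probs: List[List[float]], target: List[int], blank: int = 0) -> float:
    """CTC loss for one sample via the forward (alpha) recursion.

    probs[t][k] is the softmax probability of class k at frame t;
    target holds label ids and must not contain the blank id.
    """
    # Extended target: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S, T = len(ext), len(probs)

    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]
    for t in range(1, T):
        new = [0.0] * S
        for s in range(S):
            total = alpha[s]
            if s >= 1:
                total += alpha[s - 1]
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[s - 2]
            new[s] = total * probs[t][ext[s]]
        alpha = new

    # Valid paths end on the last label or on the trailing blank
    p = alpha[S - 1] + (alpha[S - 2] if S > 1 else 0.0)
    return -math.log(p) if p > 0 else math.inf


# Two frames, classes {0: blank, 1: "a"}; alignments "aa", "a-", "-a" each have p = 0.25
probs = [[0.5, 0.5], [0.5, 0.5]]
print(ctc_neg_log_likelihood(probs, [1]))  # equals -log(0.75)
```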

forward(pixel_values, **kwargs)

Forward inputs through the model and return logits, etc.

Parameters:
  • pixel_values – Preprocessed image tensor to run through the model

  • **kwargs – Extra keyword arguments

Returns:

A dict of outputs like logits, loss, etc.
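CTC losses are usually computed from per-frame log-probabilities rather than raw logits (PyTorch's CTCLoss, for instance, expects log_softmax outputs). Whether this conversion happens inside forward() or compute_loss() is not documented here, so the sketch below only shows a numerically stable log-softmax over one frame of logits; the function name is hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax for one frame of class scores."""
    # Subtracting the max before exponentiating avoids overflow
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


frame = [2.0, 1.0, 0.1]
log_probs = log_softmax(frame)
print(sum(math.exp(lp) for lp in log_probs))  # close to 1.0
```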

generate(pixel_values)

Generation method for generative models, i.e. models with the is_generative attribute set to True, as in this class. The behavior of this method is usually controlled by the generation section of the model’s config.

Parameters:

pixel_values – Preprocessed image tensor, usually the same input passed to forward()

Returns:

Generated output tensor
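For a CTC-trained model like this one, the simplest generation strategy is best-path (greedy) decoding: take the highest-scoring class at every frame. The sketch below assumes per-frame scores laid out as scores[t][k]; whether hezar's generate() returns exactly these argmax ids is an assumption, and collapsing repeats and blanks is left to a separate post-processing step.

```python
from typing import List, Tuple


def best_path(scores: List[List[float]]) -> Tuple[List[int], List[float]]:
    """Pick the top-scoring class id at every frame, plus its score."""
    ids, best_scores = [], []
    for frame in scores:
        # Index of the largest score in this frame (first one on ties)
        k = max(range(len(frame)), key=frame.__getitem__)
        ids.append(k)
        best_scores.append(frame[k])
    return ids, best_scores


scores = [[0.1, 0.9, 0.0], [0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
print(best_path(scores))  # ([1, 1, 0, 2], [0.9, 0.7, 0.8, 0.8])
```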

image_processor = 'image_processor'
is_generative: bool = True
loss_func_kwargs: Dict[str, Any] = {'zero_infinity': True}
loss_func_name: str | LossType = 'ctc'
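The zero_infinity=True entry in loss_func_kwargs matches the option of the same name on PyTorch's torch.nn.CTCLoss, which replaces infinite losses (and their gradients) with zero. A CTC loss becomes infinite when no alignment exists, which happens when the input has too few frames: CTC needs one frame per label plus a blank frame between each pair of identical adjacent labels. The plain-Python sketch below illustrates that rule; the function names and label ids are hypothetical.

```python
import math
from typing import List


def min_frames_needed(target: List[int]) -> int:
    """Fewest frames CTC needs to emit `target`."""
    # Identical neighbours (e.g. the "ll" in "hello") need a blank in between
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


def apply_zero_infinity(loss: float) -> float:
    """Mimic zero_infinity=True: infinite losses become 0.0."""
    return 0.0 if math.isinf(loss) else loss


target = [8, 5, 12, 12, 15]  # "hello" as hypothetical label ids
num_frames = 5
print(min_frames_needed(target))  # 6
print(num_frames >= min_frames_needed(target))  # False, so the raw CTC loss is inf
print(apply_zero_infinity(math.inf))  # 0.0
```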
post_process(generation_outputs, return_scores=False)

Process the model outputs and return human-readable results. Called in self.predict().

Parameters:
  • generation_outputs – Outputs of generate() to process

  • return_scores – Whether to return scores along with the results (default: False)

Returns:

Model outputs converted to human-readable results
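For CTC outputs, turning frame-level ids into text typically means merging consecutive repeats, dropping the blank id, and mapping the remaining ids to characters. The sketch below does that in plain Python. The blank id of 0, the id2label mapping, and the use of the mean frame score as a confidence value when return_scores is set are all assumptions, not hezar's documented behavior.

```python
from typing import Dict, List, Optional, Tuple, Union


def ctc_collapse_to_text(
    ids: List[int],
    id2label: Dict[int, str],
    frame_scores: Optional[List[float]] = None,
    return_scores: bool = False,
    blank: int = 0,
) -> Union[str, Tuple[str, float]]:
    """Greedy CTC post-processing: merge repeats, drop blanks, map ids to text."""
    chars = []
    prev = None
    for i in ids:
        # Keep an id only if it is not blank and differs from the previous frame
        if i != blank and i != prev:
            chars.append(id2label[i])
        prev = i
    text = "".join(chars)
    if return_scores:
        # Hypothetical confidence: mean of the per-frame best scores
        score = sum(frame_scores) / len(frame_scores) if frame_scores else 0.0
        return text, score
    return text


id2label = {1: "l", 2: "o"}  # hypothetical vocabulary; id 0 is the blank
print(ctc_collapse_to_text([1, 1, 0, 1, 2, 2], id2label))  # llo
```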

preprocess(inputs, **kwargs)

Given raw inputs, preprocess them and prepare them for the model’s forward().

Parameters:
  • inputs – Raw model inputs

  • **kwargs – Extra kwargs specific to the model

Returns:

A dict of inputs for model forward
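Preprocessing for a line-level OCR model commonly resizes the image to a fixed height while keeping the text's aspect ratio and scales pixel values into a small range. The actual steps are defined by the model's image processor (see the image_processor attribute) and its config, which are not shown here. In the sketch below, the function names, the target height of 32, and the [0, 1] scaling are all hypothetical choices.

```python
from typing import List


def target_width(width: int, height: int, target_height: int = 32) -> int:
    """Width that preserves the aspect ratio at a fixed target height."""
    return max(1, round(width * target_height / height))


def normalize(pixels: List[List[int]]) -> List[List[float]]:
    """Scale 8-bit grayscale values from [0, 255] into [0.0, 1.0]."""
    return [[p / 255.0 for p in row] for row in pixels]


print(target_width(200, 64))  # 100
print(normalize([[0, 255], [51, 102]]))  # [[0.0, 1.0], [0.2, 0.4]]
```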

class hezar.models.image2text.crnn.crnn_image2text.ConvBlock(input_channel, output_channel, kernel_sizes, strides, paddings, batch_norm: bool = False)

Bases: Module
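ConvBlock takes kernel_sizes, strides, and paddings, which suggests (an assumption, since the block has no docstring) that it wraps a 2D convolution configured with per-dimension values, optionally followed by batch normalization when batch_norm=True. For a standard convolution with dilation 1, each spatial dimension n becomes floor((n + 2p - k) / s) + 1. The helper below applies that formula to an (H, W) pair; its name and the tuple convention are hypothetical.

```python
from typing import Tuple


def conv_output_size(
    size: Tuple[int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    padding: Tuple[int, int],
) -> Tuple[int, int]:
    """Output (H, W) of a 2D convolution with dilation 1."""
    return tuple(
        (n + 2 * p - k) // s + 1
        for n, k, s, p in zip(size, kernel, stride, padding)
    )


print(conv_output_size((32, 100), (3, 3), (1, 1), (1, 1)))  # (32, 100)
print(conv_output_size((32, 100), (2, 2), (2, 1), (0, 1)))  # (16, 101)
```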

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
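The note above concerns PyTorch's hook mechanism: calling a module instance (block(x)) goes through nn.Module.__call__, which runs registered hooks around forward(), whereas calling block.forward(x) directly skips them. The toy class below imitates that dispatch in plain Python to make the difference visible; it is a simplified stand-in, not PyTorch's implementation.

```python
from typing import Callable, List


class ToyModule:
    """Minimal imitation of how nn.Module.__call__ wraps forward()."""

    def __init__(self) -> None:
        self.forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> None:
        self.forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        # Hooks receive (module, input, output), loosely following PyTorch forward hooks
        for hook in self.forward_hooks:
            hook(self, x, out)
        return out


calls = []
block = ToyModule()
block.register_forward_hook(lambda m, inp, out: calls.append(out))
block(3)          # runs forward and the hook
block.forward(3)  # runs forward only
print(calls)  # [6]
```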