hezar.models.model module

Hezar models inherit from the base class Model. A Model is itself a PyTorch Module used to implement neural networks, but it adds Hezar-specific functionality and methods, e.g., pushing to and loading from the Hub.

Examples

>>> # Load from hub
>>> from hezar.models import Model
>>> model = Model.load("hezarai/bert-base-fa")
class hezar.models.model.Model(config: ModelConfig, *args, **kwargs)

Bases: Module

Base class for all neural network models in Hezar.

Parameters:

config – The model config (a ModelConfig dataclass)

compute_loss(logits: Tensor, labels: Tensor) → Tensor

Compute the loss of the model outputs against the given labels.

Parameters:
  • logits – Logits tensor to compute loss on

  • labels – Labels tensor

Note: Subclasses can override this method and add arguments besides logits and labels.

Returns:

Loss tensor
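Since loss_func_name defaults to 'cross_entropy', compute_loss on a classifier typically reduces to mean cross-entropy. The framework-free sketch below shows that computation, assuming logits is a list of per-class scores per sample and labels holds integer class indices. It is an illustration only, not Hezar's implementation, which operates on PyTorch tensors.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross-entropy of per-sample class logits against integer labels."""
    total = 0.0
    for scores, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        # Negative log-probability of the true class
        total += log_sum_exp - scores[label]
    return total / len(labels)


# Two samples, three classes
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
labels = [0, 2]
print(cross_entropy(logits, labels))
```

Uniform logits give a loss of log(num_classes), and the loss approaches zero as the true class dominates.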

config_filename = 'model_config.yaml'
property device

Get the model’s device. This is only reliable when all of the model’s weights are on the same device.

forward(*model_inputs, **kwargs) → Dict

Forward inputs through the model and return logits, etc.

Parameters:

model_inputs – The required inputs for the model forward

Returns:

A dict of outputs like logits, loss, etc.

generate(*model_inputs, **kwargs) → Tensor

Generation method for all generative models. Generative models have the is_generative attribute set to True. The behavior of this method is usually controlled by the generation section of the model’s config.

Parameters:
  • model_inputs – Model inputs for generation, usually the same as forward’s model_inputs

  • **kwargs – Generation kwargs

Returns:

Generated output tensor
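generate() calls forward() internally and reads its defaults from the generation part of the config. The toy sketch below illustrates that pattern with greedy decoding. ToyGenerativeModel, generation_config, eos_id and max_length are hypothetical names; real Hezar models implement their own, usually tensor-based, generation logic.

```python
from typing import Dict, List


class ToyGenerativeModel:
    """Hypothetical stand-in for a generative model; not a Hezar class."""

    is_generative = True

    def __init__(self, vocab_size: int, eos_id: int, max_length: int):
        self.vocab_size = vocab_size
        self.eos_id = eos_id
        # Plays the role of the "generation" section of a model config
        self.generation_config = {"max_length": max_length}

    def forward(self, token_ids: List[int]) -> Dict[str, List[float]]:
        # Dummy next-token logits: favor (last token + 1) modulo vocab size
        target = (token_ids[-1] + 1) % self.vocab_size
        logits = [0.0] * self.vocab_size
        logits[target] = 1.0
        return {"logits": logits}

    def generate(self, token_ids: List[int], **kwargs) -> List[int]:
        # Generation kwargs override the config defaults
        max_length = kwargs.get("max_length", self.generation_config["max_length"])
        output = list(token_ids)
        while len(output) < max_length:
            logits = self.forward(output)["logits"]
            # Greedy decoding: pick the highest-scoring token
            next_id = max(range(self.vocab_size), key=lambda i: logits[i])
            output.append(next_id)
            if next_id == self.eos_id:
                break
        return output


model = ToyGenerativeModel(vocab_size=5, eos_id=4, max_length=10)
print(model.generate([0]))                # [0, 1, 2, 3, 4] (stops at EOS)
print(model.generate([0], max_length=3))  # [0, 1, 2] (kwarg overrides config)
```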

is_generative: bool = False
classmethod load(hub_or_local_path: str | PathLike, load_locally: bool | None = False, load_preprocessor: bool | None = True, model_filename: str | None = None, config_filename: str | None = None, save_path: str | PathLike | None = None, cache_dir: str | PathLike | None = None, **kwargs) → Model

Load the model from a local path or the Hub.

It’s recommended to call this method on hezar.models.Model rather than on a specific model class, unless you know that class matches the one in the config: the returned model is always an instance of the class specified in the config.

Parameters:
  • hub_or_local_path – Hub repo ID or local path of the model.

  • load_locally – Force loading from a local path.

  • load_preprocessor – Whether to load the preprocessor(s).

  • model_filename – Optional model filename.

  • config_filename – Optional config filename.

  • save_path – If given, save the model to this path after loading.

  • cache_dir – Path to the cache directory. Defaults to ~/.cache/hezar.

Returns:

The fully loaded Hezar model
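The caveat above follows from load() choosing the class from the config rather than from the class it is called on. A minimal sketch of that kind of registry dispatch, in which MODEL_REGISTRY, register_model, BaseModel and the config's "name" key are all hypothetical stand-ins rather than Hezar's real internals:

```python
from typing import Dict, Type

# Hypothetical registry mapping config names to model classes
MODEL_REGISTRY: Dict[str, Type["BaseModel"]] = {}


def register_model(name: str):
    """Class decorator that records a model class under a config name."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


class BaseModel:
    def __init__(self, config: dict):
        self.config = config

    @classmethod
    def load(cls, config: dict) -> "BaseModel":
        # The returned class comes from the config, not from `cls`
        model_cls = MODEL_REGISTRY[config["name"]]
        return model_cls(config)


@register_model("text_classifier")
class TextClassifier(BaseModel):
    pass


@register_model("speech_recognizer")
class SpeechRecognizer(BaseModel):
    pass


model = TextClassifier.load({"name": "speech_recognizer"})
print(type(model).__name__)  # SpeechRecognizer
```

Calling load() on TextClassifier still yields a SpeechRecognizer, because the config decides the type.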

load_state_dict(state_dict: Mapping[str, Any], **kwargs)

Flexibly load the state dict to the model.

Incompatible or missing keys are skipped and the remaining layer weights are loaded. In that case, a warning with details is issued.

Parameters:

state_dict – Model state dict
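A sketch of the key-filtering step behind this kind of flexible loading, assuming "compatible" means same key and same shape. To stay framework-free it compares shape tuples instead of tensors; filter_state_dict is a hypothetical helper, not Hezar's code. With real tensors, the loadable entries could then be passed to torch.nn.Module.load_state_dict(..., strict=False), since strict=False tolerates missing and unexpected keys but still raises on shape mismatches.

```python
import warnings
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def filter_state_dict(
    model_shapes: Dict[str, Shape], state_shapes: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Split checkpoint keys into loadable and skipped groups by shape."""
    loadable, incompatible, unexpected = [], [], []
    for key, shape in state_shapes.items():
        if key not in model_shapes:
            unexpected.append(key)    # in checkpoint, not in model
        elif model_shapes[key] != shape:
            incompatible.append(key)  # same name, different shape
        else:
            loadable.append(key)
    missing = [k for k in model_shapes if k not in state_shapes]
    if incompatible or missing or unexpected:
        warnings.warn(
            f"Partially loaded weights. Incompatible: {incompatible}, "
            f"missing: {missing}, unexpected: {unexpected}"
        )
    return loadable, incompatible, missing, unexpected


model_shapes = {"encoder.weight": (768, 768), "classifier.weight": (3, 768), "classifier.bias": (3,)}
checkpoint = {"encoder.weight": (768, 768), "classifier.weight": (2, 768), "pooler.weight": (768, 768)}
loadable, incompatible, missing, unexpected = filter_state_dict(model_shapes, checkpoint)
```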

property loss_func
loss_func_kwargs: Dict[str, Any] = {}
loss_func_name: str | LossType = 'cross_entropy'
model_filename = 'model.pt'
property num_parameters
property num_trainable_parameters
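These two properties are undocumented here; by their names they presumably count all parameters and only those that require gradients. The usual PyTorch idiom is sum(p.numel() for p in model.parameters()). The sketch below mimics that idiom without torch, using a hypothetical FakeParam class in place of real tensors.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for a tensor parameter."""
    shape: Tuple[int, ...]
    requires_grad: bool = True

    def numel(self) -> int:
        # Number of elements is the product of the dimensions
        return prod(self.shape)


def num_parameters(params: List[FakeParam]) -> int:
    return sum(p.numel() for p in params)


def num_trainable_parameters(params: List[FakeParam]) -> int:
    # Frozen parameters (requires_grad=False) are excluded
    return sum(p.numel() for p in params if p.requires_grad)


params = [
    FakeParam((3, 768)),                           # trainable classifier weight
    FakeParam((3,)),                               # trainable classifier bias
    FakeParam((768, 768), requires_grad=False),    # frozen encoder weight
]
print(num_parameters(params), num_trainable_parameters(params))
```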
post_process(*model_outputs: Tensor | Any, **kwargs)

Process model outputs and return human-readable results. Called in self.predict().

Parameters:
  • model_outputs – model outputs to process

  • **kwargs – extra arguments specific to the derived class

Returns:

Model outputs, processed and converted into human-readable results

predict(inputs: Any | List[Any], device: str | device = None, preprocess: bool = True, unpack_forward_inputs: bool = True, post_process: bool = True, **kwargs) → List[Any] | Tensor

Perform an end-to-end prediction on raw inputs.

If the model is generative, it must also implement generate(), which is called instead of forward() (generate() calls forward() internally).

Parameters:
  • inputs – Raw inputs, e.g., a list of texts, paths to images, etc.

  • device – The device to run inference on

  • preprocess – Whether to call preprocess() before forward()

  • unpack_forward_inputs – Whether to unpack the forward inputs. Set to False to pass the preprocess outputs directly to forward()/generate() without unpacking them. This only applies when the output of preprocess() is a dict-like (mapping) object.

  • post_process – Whether to call post_process() after forward()

  • **kwargs – Other arguments for preprocess, forward, generate and post_process. Each one is passed to the correct method automatically.

Returns:

Prediction results, each model or task can have its own type and structure
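The predict() stages can be sketched end to end with a toy model. ToyClassifier and route_kwargs are hypothetical; in particular, routing extra kwargs by inspecting each method's signature is one plausible way to pass "each argument to the correct method", not necessarily what Hezar does. Device handling is omitted.

```python
import inspect
from typing import Any, Callable, Dict, List


def route_kwargs(method: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs that `method` explicitly accepts (hypothetical helper)."""
    params = inspect.signature(method).parameters
    return {k: v for k, v in kwargs.items() if k in params}


class ToyClassifier:
    """Hypothetical model illustrating the predict() stages; not a Hezar class."""

    is_generative = False

    def preprocess(self, texts: List[str], lowercase: bool = True) -> Dict[str, Any]:
        return {"texts": [t.lower() if lowercase else t for t in texts]}

    def forward(self, texts) -> Dict[str, Any]:
        # Accept either the unpacked list or the whole preprocess dict
        if isinstance(texts, dict):
            texts = texts["texts"]
        # Dummy "logits": 1.0 if the text contains "good", else 0.0
        return {"logits": [1.0 if "good" in t else 0.0 for t in texts]}

    def post_process(self, outputs: Dict[str, Any], threshold: float = 0.5) -> List[str]:
        return ["positive" if x > threshold else "negative" for x in outputs["logits"]]

    def predict(self, inputs, preprocess=True, unpack_forward_inputs=True,
                post_process=True, **kwargs):
        if not isinstance(inputs, list):
            inputs = [inputs]
        if preprocess:
            inputs = self.preprocess(inputs, **route_kwargs(self.preprocess, kwargs))
        # Generative models would go through generate() instead of forward()
        step = self.generate if self.is_generative else self.forward
        if unpack_forward_inputs and isinstance(inputs, dict):
            outputs = step(**inputs)
        else:
            outputs = step(inputs)
        if post_process:
            outputs = self.post_process(outputs, **route_kwargs(self.post_process, kwargs))
        return outputs


model = ToyClassifier()
print(model.predict(["A GOOD movie", "bad plot"]))  # ['positive', 'negative']
print(model.predict("good", threshold=1.5))         # ['negative'] (kwarg routed to post_process)
```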

preprocess(*raw_inputs: Any | List[Any], **kwargs)

Preprocess raw inputs and prepare them for the model’s forward().

Parameters:
  • raw_inputs – Raw model inputs

  • **kwargs – Extra kwargs specific to the model. See the model’s specific class for more info

Returns:

A dict of inputs for model forward

property preprocessor: PreprocessorsContainer
push_to_hub(repo_id: str, filename: str | None = None, config_filename: str | None = None, push_preprocessor: bool | None = True, commit_message: str | None = None, private: bool | None = False)

Push the model and its required files to the Hub

Parameters:
  • repo_id – The repo ID (or repo name) on the Hub

  • filename – Model file name

  • config_filename – Config file name

  • push_preprocessor – Whether to push preprocessor(s) or not

  • commit_message (str) – Commit message for this push

  • private (bool) – Whether to create a private repo or not

required_backends: List[Backends | str] = []
save(path: str | PathLike, filename: str | None = None, save_preprocessor: bool | None = True, config_filename: str | None = None)

Save model weights and config to a local path

Parameters:
  • path – A local directory to save the model, config, etc. to

  • filename – Model weights filename

  • save_preprocessor – Whether to save the preprocessor(s) along with the model

  • config_filename – Model config filename

Returns:

Path to the saved model
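When filename or config_filename is omitted, the class attributes model_filename ('model.pt') and config_filename ('model_config.yaml') serve as natural defaults. The sketch below shows that fallback, assuming both files are written directly under path; resolve_save_paths is a hypothetical helper, not part of Hezar.

```python
import os
from pathlib import Path
from typing import Optional, Tuple, Union

MODEL_FILENAME = "model.pt"             # mirrors Model.model_filename
CONFIG_FILENAME = "model_config.yaml"   # mirrors Model.config_filename


def resolve_save_paths(
    path: Union[str, os.PathLike],
    filename: Optional[str] = None,
    config_filename: Optional[str] = None,
) -> Tuple[Path, Path]:
    """Return (weights_path, config_path), falling back to the default names."""
    save_dir = Path(path)
    weights_path = save_dir / (filename or MODEL_FILENAME)
    config_path = save_dir / (config_filename or CONFIG_FILENAME)
    return weights_path, config_path


weights, config = resolve_save_paths("my-model")
print(weights, config)
```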

skip_keys_on_load = []