hezar.models.model module¶
Hezar models inherit from the base class Model. A Model is itself a PyTorch Module for implementing neural networks, but it adds some Hezar-specific functionality and methods, e.g., pushing to the Hub, loading from the Hub, etc.
Examples
>>> # Load from hub
>>> from hezar.models import Model
>>> model = Model.load("hezarai/bert-base-fa")
- class hezar.models.model.Model(config: ModelConfig, *args, **kwargs)[source]¶
Bases:
Module
Base class for all neural network models in Hezar.
- Parameters:
config – A dataclass model config
- compute_loss(logits: Tensor, labels: Tensor) Tensor [source]¶
Compute loss on the model outputs against the given labels
- Parameters:
logits – Logits tensor to compute loss on
labels – Labels tensor
Note: Subclasses can also override this method and add other arguments besides logits and labels
- Returns:
Loss tensor
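As a hedged sketch, a subclass might override compute_loss with an extra argument like this (the class name, the attention_mask argument, and the ignore_index convention are illustrative, not part of the base API):
>>> import torch.nn.functional as F
>>> from hezar.models import Model
>>> class MyTaggingModel(Model):
...     def compute_loss(self, logits, labels, attention_mask=None):
...         # Hypothetical override: flatten token-level logits/labels and
...         # skip padded positions via the common -100 label convention
...         return F.cross_entropy(
...             logits.view(-1, logits.size(-1)),
...             labels.view(-1),
...             ignore_index=-100,
...         )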
- config_filename = 'model_config.yaml'¶
- property device¶
Get the model’s device. This property is only safe when all of the model’s weights are on the same device.
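For example (the reported device depends on where the weights live):
>>> model = Model.load("hezarai/bert-base-fa")
>>> model.device  # e.g., device(type='cpu') on a CPU-only machine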
- forward(*model_inputs, **kwargs) Dict [source]¶
Forward inputs through the model and return logits, etc.
- Parameters:
model_inputs – The required inputs for the model forward
- Returns:
A dict of outputs like logits, loss, etc.
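A minimal sketch of calling forward directly with preprocessed inputs (assuming the loaded model accepts the keyword inputs its preprocessor produces):
>>> model = Model.load("hezarai/bert-base-fa")
>>> inputs = model.preprocess("سلام دنیا")
>>> outputs = model(**inputs)  # same as model.forward(**inputs)
>>> outputs["logits"]  # available keys vary by model/task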
- generate(*model_inputs, **kwargs) Tensor [source]¶
Generation method for all generative models. Generative models have the is_generative attribute set to True. The behavior of this method is usually controlled by the generation section of the model’s config.
- Parameters:
model_inputs – Model inputs for generation, usually the same as forward’s model_inputs
**kwargs – Generation kwargs
- Returns:
Generated output tensor
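Usage might look like the following sketch (the model ID here is hypothetical; any model with is_generative set to True follows the same pattern):
>>> model = Model.load("hezarai/some-generative-model")  # hypothetical repo ID
>>> inputs = model.preprocess("متن ورودی")
>>> generated = model.generate(**inputs)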
- is_generative: bool = False¶
- classmethod load(hub_or_local_path: str | PathLike, load_locally: bool | None = False, load_preprocessor: bool | None = True, model_filename: str | None = None, config_filename: str | None = None, save_path: str | PathLike | None = None, cache_dir: str | PathLike | None = None, **kwargs) Model [source]¶
Load the model from local path or hub.
It’s recommended to use this method with
hezar.models.Model
itself rather than any other model class, unless you know for sure that the class matches the one in the config, because the returned model will always be of the type specified in the config!
- Parameters:
hub_or_local_path – Path to the model living on the Hub or local disk.
load_locally – Force loading from local path
load_preprocessor – Whether to load the preprocessor(s) or not
model_filename – Optional model filename.
config_filename – Optional config filename
save_path – Save model to this path after loading
cache_dir – Path to cache directory, defaults to ~/.cache/hezar
- Returns:
The fully loaded Hezar model
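Typical usage, reusing the repo from the example at the top (all keyword arguments shown are optional):
>>> from hezar.models import Model
>>> model = Model.load(
...     "hezarai/bert-base-fa",
...     load_preprocessor=True,
...     cache_dir="~/.cache/hezar",
... )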
- load_state_dict(state_dict: Mapping[str, Any], **kwargs)[source]¶
Flexibly load the state dict to the model.
Any incompatible or missing key is ignored and the remaining layer weights are loaded. In that case, a warning with additional info is issued.
- Parameters:
state_dict – Model state dict
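A short sketch of the flexible loading behavior (the weights file path is illustrative):
>>> import torch
>>> state_dict = torch.load("path/to/model.pt", map_location="cpu")
>>> model.load_state_dict(state_dict)  # incompatible/missing keys are skipped with a warning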
- property loss_func¶
- loss_func_kwargs: Dict[str, Any] = {}¶
- model_filename = 'model.pt'¶
- property num_parameters¶
- property num_trainable_parameters¶
- post_process(*model_outputs: Tensor | Any, **kwargs)[source]¶
Process model outputs and return human-readable results. Called inside self.predict().
- Parameters:
model_outputs – model outputs to process
**kwargs – extra arguments specific to the derived class
- Returns:
Model outputs processed and converted into human-readable results
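Together with preprocess() and forward(), this forms the pipeline that predict() runs end to end; a manual sketch (the input text and output structure are illustrative):
>>> inputs = model.preprocess("نمونه متن")
>>> outputs = model(**inputs)
>>> results = model.post_process(outputs)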
- predict(inputs: Any | List[Any], device: str | device = None, preprocess: bool = True, unpack_forward_inputs: bool = True, post_process: bool = True, **kwargs) List[Any] | Tensor [source]¶
Perform an end-to-end prediction on raw inputs.
If the model is a generative model, it must also implement the generate() method, which will be called instead of forward() (forward() is still called internally within generate()). See the example after this entry.
- Parameters:
inputs – Raw inputs, e.g., a list of texts, paths to images, etc.
device – What device to perform inference on
preprocess – Whether to call preprocess() before forward()
unpack_forward_inputs – Whether to unpack forward inputs. Set to False if you want to send preprocess outputs directly to the forward/generate method without unpacking them. Note that this only applies when the preprocess method’s output is a dict-like/mapping object.
post_process – Whether to call post_process() after forward()
**kwargs – Other arguments for preprocess, forward, generate, and post_process. Each will be passed to the correct method automatically.
- Returns:
Prediction results, each model or task can have its own type and structure
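End-to-end usage might look like this (the model ID is hypothetical; the output structure depends on the task):
>>> from hezar.models import Model
>>> model = Model.load("hezarai/bert-fa-sentiment")  # hypothetical task-specific model ID
>>> model.predict(["هوا امروز عالیه!"], device="cpu")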
- preprocess(*raw_inputs: Any | List[Any], **kwargs)[source]¶
Given raw inputs, preprocess them and prepare them for the model’s forward().
- Parameters:
raw_inputs – Raw model inputs
**kwargs – Extra kwargs specific to the model. See the model’s specific class for more info
- Returns:
A dict of inputs for model forward
- property preprocessor: PreprocessorsContainer¶
- push_to_hub(repo_id: str, filename: str | None = None, config_filename: str | None = None, push_preprocessor: bool | None = True, commit_message: str | None = None, private: bool | None = False)[source]¶
Push the model and required files to the hub
- Parameters:
repo_id – The path (id or repo name) on the hub
filename – Model file name
config_filename – Config file name
push_preprocessor – Whether to push preprocessor(s) or not
commit_message (str) – Commit message for this push
private (bool) – Whether to create a private repo or not
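For example (the repo ID is illustrative and requires write access on the Hub):
>>> model.push_to_hub(
...     "my-username/my-fa-model",  # hypothetical repo ID
...     commit_message="Upload model",
...     private=False,
... )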
- save(path: str | PathLike, filename: str | None = None, save_preprocessor: bool | None = True, config_filename: str | None = None)[source]¶
Save model weights and config to a local path
- Parameters:
path – A local directory to save model, config, etc.
filename – Model weights filename
save_preprocessor – Whether to save the preprocessor(s) along with the model or not
config_filename – Model config filename
- Returns:
Path to the saved model
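For example (the target directory is illustrative):
>>> model.save("./saved_model", save_preprocessor=True)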
- skip_keys_on_load = []¶