hezar.utils.data_utils module

hezar.utils.data_utils.convert_batch_dict_dtype(batch_dict: dict, dtype: str = 'list', skip_keys: list | None = None) → dict

Convert the data types of the values in a batch dict.

Parameters:
  • batch_dict (dict) – The batched dictionary. Each key in the dict has a batch of data as its value.

  • dtype (str) – Target data type to convert to (“list”, “numpy”, “torch”).

  • skip_keys (list | None) – Names of keys whose values should be left unconverted.

Returns:

The same dict, with its values cast to the target dtype.

Return type:

dict
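As a rough illustration of the conversion, the sketch below is a hypothetical stand-in (convert_batch_dict_dtype_sketch), not hezar's implementation. It assumes values are converted in place (matching "the same dict" above), that “list” conversion uses .tolist() when the value provides it, and that NumPy and PyTorch are imported only when those targets are requested.

```python
from typing import Any, Dict, List, Optional


def convert_batch_dict_dtype_sketch(
    batch_dict: Dict[str, Any],
    dtype: str = "list",
    skip_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical re-implementation for illustration only."""
    skip_keys = skip_keys or []
    for key, value in list(batch_dict.items()):
        if key in skip_keys:
            continue  # leave this key's value untouched
        if dtype == "list":
            # NumPy arrays and torch tensors both expose .tolist()
            batch_dict[key] = value.tolist() if hasattr(value, "tolist") else list(value)
        elif dtype == "numpy":
            import numpy as np  # imported lazily; only needed for this target
            batch_dict[key] = np.asarray(value)
        elif dtype == "torch":
            import torch  # imported lazily; only needed for this target
            batch_dict[key] = torch.as_tensor(value)
        else:
            raise ValueError(f"Unsupported dtype: {dtype!r}")
    return batch_dict  # same object, values cast in place


batch = {"token_ids": (101, 2023, 102), "text": ("hello",)}
convert_batch_dict_dtype_sketch(batch, dtype="list", skip_keys=["text"])
# batch == {"token_ids": [101, 2023, 102], "text": ("hello",)}
```

Passing text fields through skip_keys is one way to keep them out of conversions such as “torch”, which cannot represent strings as tensors.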

hezar.utils.data_utils.dataloader_worker_init_fn(seed)

A dataloader worker init function that handles reproducibility by hard-setting the seed for all workers.
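PyTorch's DataLoader calls worker_init_fn(worker_id) once in each worker process. How hezar wires its seed argument into that call is not shown here; one common pattern is to bind the seed with functools.partial. The sketch below is hypothetical (seed_worker is not part of hezar) and, following the description above, gives every worker the same seed, seeding Python's random module and, if installed, NumPy.

```python
import functools
import random


def seed_worker(worker_id: int, seed: int) -> None:
    """Hypothetical worker_init_fn: hard-sets the same seed in every worker."""
    del worker_id  # ignored on purpose: all workers get the same seed
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; only Python's RNG is seeded


# DataLoader passes worker_id positionally, so the seed is bound by keyword.
worker_init_fn = functools.partial(seed_worker, seed=42)
# Usage (requires PyTorch):
# DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)
```

Giving every worker the same seed makes runs repeatable, but it also means the workers draw identical random streams; offsetting the seed by worker_id is a common alternative when per-worker randomness (e.g. data augmentation) should differ.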

hezar.utils.data_utils.flatten_dict(dict_config: Dict | DictConfig) → DictConfig

Flatten a nested Dict/DictConfig object.

Parameters:

dict_config – A Dict/DictConfig object

Returns:

The flattened version of the dict-like object
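The sketch below flattens a plain nested dict under one assumed convention: parent and child keys are joined with a "." separator. hezar's own key scheme and its DictConfig return type are not reproduced here, so treat flatten_dict_sketch and the separator as hypothetical.

```python
from typing import Any, Dict, Mapping


def flatten_dict_sketch(d: Mapping[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Hypothetical flattener: {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, Mapping):
            # Recurse into nested mappings and merge their flattened items
            flat.update(flatten_dict_sketch(value, new_key, sep))
        else:
            flat[new_key] = value
    return flat


config = {"model": {"hidden_size": 768, "dropout": {"attn": 0.1}}, "seed": 42}
print(flatten_dict_sketch(config))
# {'model.hidden_size': 768, 'model.dropout.attn': 0.1, 'seed': 42}
```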

hezar.utils.data_utils.get_non_numeric_keys(d: Dict, batched=True)

Get the keys in a dictionary whose values are strings.

Parameters:
  • d – The dict

  • batched – Whether the input dict values are batched.

Returns:

A list of string-valued keys
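A minimal sketch of the idea, assuming that for batched inputs it is enough to inspect the first item of each value. get_non_numeric_keys_sketch is hypothetical, and hezar's exact checks may differ.

```python
from typing import Any, Dict, List


def get_non_numeric_keys_sketch(d: Dict[str, Any], batched: bool = True) -> List[str]:
    """Hypothetical helper: return keys whose values are strings."""
    keys = []
    for key, value in d.items():
        # For batched dicts, look at the first sample; otherwise at the value itself
        sample = value[0] if batched and len(value) > 0 else value
        if isinstance(sample, str):
            keys.append(key)
    return keys


batch = {"token_ids": [[1, 2], [3]], "text": ["hi", "there"]}
print(get_non_numeric_keys_sketch(batch))  # ['text']
```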

hezar.utils.data_utils.pad_batch_items(inputs: List[List[int | float]], padding: str | PaddingType | None = None, padding_side: Literal['right', 'left'] = 'right', pad_id: int = 0, max_length: bool | int | None = None, truncation: bool | None = True)

Given a nested container of unequal-sized iterables (e.g., a batch of token ids), pad them to a common length according to the padding strategy.

Parameters:
  • inputs – A nested iterable of unequal-sized iterables (e.g., a list of lists).

  • padding – Padding strategy, either “max_length” or “longest”.

  • padding_side – Where to add the padding ids, “left” or “right”. Defaults to “right”.

  • pad_id – Pad token id. Defaults to 0.

  • max_length – Maximum input length after padding. Only applicable when padding == “max_length”.

  • truncation – Whether to truncate inputs in the batch that are longer than max_length.

Returns:

A list of equal-sized lists
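To make the padding strategies concrete, here is a hypothetical pad_batch_items_sketch that mirrors the parameters above. It folds the target-length logic (which this module exposes separately as resolve_inputs_length_for_padding) into the function, defaults padding to “longest”, and assumes that truncation keeps the leading items and that an over-long input with truncation disabled raises an error. hezar's own behavior on these edge cases may differ.

```python
from typing import List, Optional, Sequence, Union

Number = Union[int, float]


def pad_batch_items_sketch(
    inputs: Sequence[Sequence[Number]],
    padding: str = "longest",
    padding_side: str = "right",
    pad_id: Number = 0,
    max_length: Optional[int] = None,
    truncation: bool = True,
) -> List[List[Number]]:
    """Hypothetical re-implementation for illustration only."""
    # 1. Resolve the target length from the padding strategy
    if padding == "longest":
        target = max((len(item) for item in inputs), default=0)
    elif padding == "max_length":
        if not isinstance(max_length, int):
            raise ValueError("padding='max_length' requires an integer max_length")
        target = max_length
    else:
        raise ValueError(f"Unsupported padding strategy: {padding!r}")

    padded = []
    for item in inputs:
        item = list(item)
        # 2. Truncate items longer than the target (keeps the leading items)
        if len(item) > target:
            if not truncation:
                raise ValueError(f"Input of length {len(item)} exceeds max_length={target}")
            item = item[:target]
        # 3. Pad on the requested side up to the target length
        pad = [pad_id] * (target - len(item))
        padded.append(pad + item if padding_side == "left" else item + pad)
    return padded


batch = [[5, 6, 7], [8]]
print(pad_batch_items_sketch(batch))                                      # [[5, 6, 7], [8, 0, 0]]
print(pad_batch_items_sketch(batch, padding_side="left"))                 # [[5, 6, 7], [0, 0, 8]]
print(pad_batch_items_sketch(batch, padding="max_length", max_length=2))  # [[5, 6], [8, 0]]
```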

hezar.utils.data_utils.resolve_inputs_length_for_padding(inputs: List[List[Any]], padding: str | PaddingType | None = None, max_length: bool | int | None = None, truncation: bool | None = True)

Resolve the final length of the inputs based on the padding and max_length values.

hezar.utils.data_utils.set_seed(seed)

Set a global seed for all backends to handle reproducibility and determinism.
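A common way to seed "all backends" is sketched below. set_seed_sketch is hypothetical: the exact libraries and determinism flags hezar touches are not documented here, so this assumes Python's random module, NumPy, and PyTorch (CPU and CUDA), each seeded only if installed.

```python
import os
import random


def set_seed_sketch(seed: int) -> None:
    """Hypothetical global seeding helper."""
    random.seed(seed)
    # Only affects subprocesses; the running interpreter's hash seed is fixed at startup
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call without a GPU
        # Prefer deterministic cuDNN kernels over autotuned (faster) ones
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass
```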

hezar.utils.data_utils.shift_tokens_right(token_ids: list[list[int]] | 'torch.Tensor' | 'np.ndarray', pad_token_id: int, decoder_start_token_id: int)

Shift input ids one token to the right.
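For the plain-list case, shifting right means prepending decoder_start_token_id and dropping the last token of each sequence so the length is unchanged. The sketch below also replaces any -100 label-masking values with pad_token_id, as the Hugging Face Transformers function of the same name does; whether hezar does the same is an assumption, and tensor/array inputs are not handled by this hypothetical shift_tokens_right_sketch.

```python
from typing import List


def shift_tokens_right_sketch(
    token_ids: List[List[int]], pad_token_id: int, decoder_start_token_id: int
) -> List[List[int]]:
    """Hypothetical list-only version for illustration."""
    shifted = []
    for seq in token_ids:
        # Prepend the decoder start token and drop the last token to keep the length
        new_seq = [decoder_start_token_id] + list(seq[:-1])
        # Replace -100 (commonly used to mask labels from the loss) with the pad id
        shifted.append([pad_token_id if t == -100 else t for t in new_seq])
    return shifted


labels = [[42, 43, 44, 2], [50, 2, -100, -100]]
print(shift_tokens_right_sketch(labels, pad_token_id=1, decoder_start_token_id=0))
# [[0, 42, 43, 44], [0, 50, 2, 1]]
```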

hezar.utils.data_utils.torch2numpy(*args)

Cast tensors to NumPy arrays.

Parameters:

*args – Any number of torch.Tensor objects

Returns:

The same inputs, cast to NumPy arrays