hezar.utils.data_utils module
- hezar.utils.data_utils.convert_batch_dict_dtype(batch_dict: dict, dtype: str = 'list', skip_keys: list | None = None) → dict
Convert data dtypes of the values in a batch dict.
- Parameters:
batch_dict (dict) – The batched dictionary. Each key in the dict has a batch of data as its value.
dtype (str) – Target data type to convert to (“list”, “numpy”, “torch”).
skip_keys (list | None) – A list of key names whose values are left unconverted.
- Returns:
The same dict with cast values.
- Return type:
dict
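The conversion described above can be sketched without the library. The following is a minimal stand-in, not hezar's implementation: the names `convert_batch_values` and `_to_list` are hypothetical, only the "list" target is handled, and the assumption that the dict is updated in place comes from the "same dict" wording in the Returns field.

```python
from typing import Any, Dict, List, Optional


def _to_list(value: Any) -> Any:
    """Recursively turn arrays, tensors and tuples into plain Python lists."""
    if hasattr(value, "tolist"):  # NumPy arrays and torch tensors expose tolist()
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_list(v) for v in value]
    return value  # scalars and strings are returned unchanged


def convert_batch_values(
    batch_dict: Dict[str, Any],
    dtype: str = "list",
    skip_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch: cast every value in a batch dict, except skipped keys."""
    if dtype != "list":
        # The "numpy" and "torch" targets are deliberately left out of this sketch.
        raise ValueError(f"Unsupported dtype in this sketch: {dtype!r}")
    skip = set(skip_keys or [])
    for key in list(batch_dict):
        if key not in skip:
            batch_dict[key] = _to_list(batch_dict[key])
    return batch_dict


batch = {"input_ids": ((1, 2), (3, 4)), "text": ("a", "b")}
converted = convert_batch_values(batch, dtype="list", skip_keys=["text"])
```

Here `converted["input_ids"]` becomes a list of lists, while `converted["text"]` keeps its original tuple because its key is skipped.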
- hezar.utils.data_utils.dataloader_worker_init_fn(seed)
A dataloader worker init function that handles reproducibility by hard-setting the seed for all workers.
- hezar.utils.data_utils.flatten_dict(dict_config: Dict | DictConfig) → DictConfig
Flatten a nested Dict/DictConfig object.
- Parameters:
dict_config – A Dict/DictConfig object
- Returns:
The flattened version of the dict-like object
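The entry above does not say how nested keys are named after flattening, so the sketch below covers two plausible readings. It is an illustration on plain `dict` objects, not the library's code: the function name `flatten_nested` and the `sep` parameter are hypothetical, and the real function returns a `DictConfig`.

```python
from typing import Any, Dict, Optional


def flatten_nested(d: Dict[str, Any], sep: Optional[str] = None, _prefix: str = "") -> Dict[str, Any]:
    """Hypothetical sketch: flatten a nested dict into a single level.

    With sep=None, nested keys are hoisted to the top level as-is
    (later keys overwrite earlier ones on a name clash).
    With a separator such as ".", nested keys are joined to their parents.
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        name = f"{_prefix}{sep}{key}" if sep is not None and _prefix else key
        if isinstance(value, dict):
            flat.update(flatten_nested(value, sep=sep, _prefix=name))
        else:
            flat[name] = value
    return flat


nested = {"model": {"name": "bert", "layers": {"n": 2}}, "lr": 0.1}
hoisted = flatten_nested(nested)          # {"name": "bert", "n": 2, "lr": 0.1}
dotted = flatten_nested(nested, sep=".")  # {"model.name": "bert", "model.layers.n": 2, "lr": 0.1}
```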
- hezar.utils.data_utils.get_non_numeric_keys(d: Dict, batched=True)
Get keys that have string values in a dictionary.
- Parameters:
d – The dict
batched – Whether the values in the input dict are batched
- Returns:
A list of string-valued keys
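A minimal sketch of this kind of check, assuming that a batched value is a sequence whose first element is representative of the whole batch. The name `string_valued_keys` is hypothetical and the code is not taken from the library.

```python
from typing import Any, Dict, List


def string_valued_keys(d: Dict[str, Any], batched: bool = True) -> List[str]:
    """Hypothetical sketch: return the keys whose values are strings.

    When batched=True, each value is assumed to be a sequence and only its
    first element is inspected; empty batches are treated as non-string.
    """
    keys = []
    for key, value in d.items():
        sample = value[0] if batched and len(value) > 0 else value
        if isinstance(sample, str):
            keys.append(key)
    return keys


batch = {"text": ["hello", "world"], "input_ids": [[1, 2], [3, 4]], "label": [0, 1]}
single = {"text": "hello", "input_ids": [1, 2]}

batched_keys = string_valued_keys(batch)                # ["text"]
single_keys = string_valued_keys(single, batched=False)  # ["text"]
```

Such a list is convenient to pass as `skip_keys` when casting a batch to tensors, since string values cannot be converted.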
- hezar.utils.data_utils.pad_batch_items(inputs: List[List[int | float]], padding: str | PaddingType | None = None, padding_side: Literal['right', 'left'] = 'right', pad_id: int = 0, max_length: bool | int | None = None, truncation: bool | None = True)
Given a nested container of unequal-sized iterables (e.g., a batch of token ids), pad them according to the padding strategy.
- Parameters:
inputs – A nested iterable of unequal-sized iterables (e.g., a list of lists)
padding – Padding strategy, either “max_length” or “longest”
padding_side – Side on which to add padding ids, “left” or “right”; defaults to “right”
pad_id – Pad token id; defaults to 0
max_length – Maximum input length after padding; only applicable when padding == “max_length”
truncation – Whether to truncate an input in the batch that is longer than max_length
- Returns:
A list of equal-sized lists
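The padding behaviour described by these parameters can be illustrated in plain Python. This is a sketch under stated assumptions, not the library's implementation: the name `pad_sequences` is hypothetical, only the "longest" and "max_length" strategies are handled, and truncation is assumed to cut from the right.

```python
from typing import List, Optional, Union

Number = Union[int, float]


def pad_sequences(
    inputs: List[List[Number]],
    padding: str = "longest",
    padding_side: str = "right",
    pad_id: int = 0,
    max_length: Optional[int] = None,
    truncation: bool = True,
) -> List[List[Number]]:
    """Hypothetical sketch: pad a batch of sequences to a common length."""
    if padding == "longest":
        target = max((len(seq) for seq in inputs), default=0)
    elif padding == "max_length":
        if max_length is None:
            raise ValueError("max_length is required when padding == 'max_length'")
        target = max_length
    else:
        raise ValueError(f"Unknown padding strategy: {padding!r}")

    padded = []
    for seq in inputs:
        seq = list(seq)
        if truncation and len(seq) > target:
            seq = seq[:target]  # assumption: drop items from the right
        fill = [pad_id] * max(target - len(seq), 0)
        padded.append(seq + fill if padding_side == "right" else fill + seq)
    return padded


batch = [[1, 2, 3], [4]]
right = pad_sequences(batch)                                   # [[1, 2, 3], [4, 0, 0]]
left = pad_sequences(batch, padding_side="left")               # [[1, 2, 3], [0, 0, 4]]
fixed = pad_sequences(batch, padding="max_length", max_length=2)  # [[1, 2], [4, 0]]
```

Note that with `truncation=False` and a `max_length` shorter than the longest input, the result of this sketch is no longer equal-sized, which is why truncation defaults to true here.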
- hezar.utils.data_utils.resolve_inputs_length_for_padding(inputs: List[List[Any]], padding: str | PaddingType | None = None, max_length: bool | int | None = None, truncation: bool | None = True)
Resolve the final length of the inputs from the padding and max_length values.
- hezar.utils.data_utils.set_seed(seed)
Set a global seed for all backends to handle reproducibility and determinism.
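Which backends the library seeds is not listed here. The sketch below shows the common pattern, assuming the backends are Python's `random`, NumPy and PyTorch; the name `seed_everything` is hypothetical, and the optional libraries are seeded only if they are installed.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical sketch: seed every random number generator that is available."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is not installed, nothing to seed
    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call when CUDA is unavailable
    except ImportError:
        pass  # PyTorch is not installed, nothing to seed


# Re-seeding with the same value reproduces the same random draws.
seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
```

After the second call, `second` holds the same three values as `first`.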