hezar.utils.data_utils module
- hezar.utils.data_utils.convert_batch_dict_dtype(batch_dict: dict, dtype: str = 'list', skip_keys: list | None = None) -> dict
- Convert the data types of the values in a batch dict.
- Parameters:
- batch_dict (dict) – The batched dictionary. Each key in the dict has a batch of data as its value. 
- dtype (str) – Target data type to convert to (“list”, “numpy”, “torch”). 
- skip_keys (list | None) – A list of key names to exclude from conversion. 
 
- Returns:
- The same dict with cast values. 
- Return type:
- dict 
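
The conversion described above can be pictured with a minimal standalone sketch. It does not call hezar: the name `convert_batch_dict_dtype_sketch` is hypothetical, and returning a new dict (rather than modifying the input in place) is an assumption, not confirmed behaviour of the library.

```python
def convert_batch_dict_dtype_sketch(batch_dict, dtype="list", skip_keys=None):
    """Hypothetical stand-in: cast every value in a batch dict to `dtype`."""
    skip_keys = skip_keys or []
    converted = {}
    for key, value in batch_dict.items():
        if key in skip_keys:
            # Skipped keys are passed through untouched.
            converted[key] = value
        elif dtype == "list":
            # Arrays and tensors expose .tolist(); other iterables go through list().
            converted[key] = value.tolist() if hasattr(value, "tolist") else list(value)
        elif dtype == "numpy":
            import numpy as np  # imported lazily so the sketch runs without NumPy
            converted[key] = np.asarray(value)
        elif dtype == "torch":
            import torch  # imported lazily so the sketch runs without PyTorch
            converted[key] = torch.as_tensor(value)
        else:
            raise ValueError(f"Unsupported dtype: {dtype!r}")
    return converted


batch = {"token_ids": ((1, 2), (3, 4)), "text": ["a", "b"]}
out = convert_batch_dict_dtype_sketch(batch, dtype="list", skip_keys=["text"])
```

Here `out["token_ids"]` becomes a list, while `out["text"]` is left as it was because it appears in `skip_keys`.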
 
- hezar.utils.data_utils.dataloader_worker_init_fn(seed)
- A dataloader worker init function that handles reproducibility by hard-setting the seed for all workers. 
- hezar.utils.data_utils.flatten_dict(dict_config: Dict | DictConfig) -> DictConfig
- Flatten a nested Dict/DictConfig object.
- Parameters:
- dict_config – A Dict/DictConfig object 
- Returns:
- The flattened version of the dict-like object 
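
To illustrate what flattening a nested dict means, here is a minimal sketch on plain Python dicts. It is not the library's code: the name `flatten_dict_sketch`, the `sep` parameter, and the choice of joining nested keys with a dot are all assumptions, and the real function returns a DictConfig rather than a dict.

```python
from typing import Any, Dict


def flatten_dict_sketch(nested: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Hypothetical stand-in: collapse nested dicts into a single-level dict."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        # Assumed naming scheme: parent and child keys joined by `sep`.
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into nested dicts and merge their flattened entries.
            flat.update(flatten_dict_sketch(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


config = {"model": {"name": "bert", "dims": {"hidden": 768}}, "lr": 2e-5}
flat = flatten_dict_sketch(config)
```

Under these assumptions, `flat` holds the keys `model.name`, `model.dims.hidden`, and `lr`.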
 
- hezar.utils.data_utils.get_non_numeric_keys(d: Dict, batched=True)
- Get the keys that have string values in a dictionary.
- Parameters:
- d – The input dict. 
- batched – Whether the values of the input dict are batched. 
 
- Returns:
- A list of string-valued keys 
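
The following sketch shows one way such a check could work. It is an illustration only: the name `get_non_numeric_keys_sketch` is hypothetical, and inspecting just the first item of each batch when `batched=True` is an assumption about how batched values are handled.

```python
def get_non_numeric_keys_sketch(d, batched=True):
    """Hypothetical stand-in: return the keys whose values are strings."""
    keys = []
    for key, value in d.items():
        if batched:
            # Assumption: a batched value is a sequence; its first item is representative.
            sample = value[0] if len(value) > 0 else None
        else:
            sample = value
        if isinstance(sample, str):
            keys.append(key)
    return keys


batch = {"text": ["hello", "world"], "token_ids": [[1, 2], [3]], "label": [0, 1]}
batched_keys = get_non_numeric_keys_sketch(batch)
single_keys = get_non_numeric_keys_sketch({"text": "hello", "label": 0}, batched=False)
```

Both calls return `["text"]`, since it is the only key with string data.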
 
- hezar.utils.data_utils.pad_batch_items(inputs: List[List[int | float]], padding: str | PaddingType | None = None, padding_side: Literal['right', 'left'] = 'right', pad_id: int = 0, max_length: bool | int | None = None, truncation: bool | None = True)
- Given a nested container of unequal-sized iterables (e.g., a batch of token ids), pad them according to the padding strategy.
- Parameters:
- inputs – A nested iterable of unequal-sized iterables (e.g., a list of lists). 
- padding – Padding strategy, either “max_length” or “longest”. 
- padding_side – Side on which to add padding ids, “left” or “right”; defaults to “right”. 
- pad_id – Pad token id; defaults to 0. 
- max_length – Maximum input length after padding; only applicable when padding == “max_length”. 
- truncation – Whether to truncate an input in the batch that is longer than max_length. 
 
- Returns:
- A list of equal-sized lists. 
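
The two padding strategies can be sketched in plain Python as follows. This is not the library's implementation: the name `pad_batch_items_sketch` is hypothetical, the sketch defaults to "longest" (the documented default is None, whose behaviour is not described here), and raising an error when `max_length` is missing is an assumption.

```python
def pad_batch_items_sketch(inputs, padding="longest", padding_side="right",
                           pad_id=0, max_length=None, truncation=True):
    """Hypothetical stand-in: pad a batch of sequences to a common length."""
    if padding == "longest":
        # Target length is the longest sequence in the batch.
        target = max((len(seq) for seq in inputs), default=0)
    elif padding == "max_length":
        if max_length is None:
            raise ValueError("max_length is required when padding == 'max_length'")
        target = max_length
    else:
        raise ValueError(f"Unsupported padding strategy: {padding!r}")

    padded = []
    for seq in inputs:
        seq = list(seq)
        if truncation and len(seq) > target:
            seq = seq[:target]
        # A negative count yields an empty list, so over-long inputs get no padding.
        pad = [pad_id] * (target - len(seq))
        padded.append(seq + pad if padding_side == "right" else pad + seq)
    return padded


batch = [[1, 2, 3], [4]]
right_padded = pad_batch_items_sketch(batch, padding="longest")
left_padded = pad_batch_items_sketch(batch, padding="max_length", max_length=2, padding_side="left")
```

With "longest", the short sequence is padded on the right to length 3. With "max_length" set to 2, the long sequence is truncated and the short one is padded on the left.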
 
- hezar.utils.data_utils.resolve_inputs_length_for_padding(inputs: List[List[Any]], padding: str | PaddingType | None = None, max_length: bool | int | None = None, truncation: bool | None = True)
- Resolve the final input length based on the padding and max_length values. 
- hezar.utils.data_utils.set_seed(seed)
- Set a global seed for all backends to handle reproducibility and determinism.
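
As an illustration of what seeding "all backends" typically involves, here is a minimal sketch. The name `set_seed_sketch` is hypothetical, and the set of backends covered (Python's `random`, plus NumPy and PyTorch when installed) is an assumption; the library's own function may seed more or fewer.

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in: seed the common sources of randomness."""
    random.seed(seed)
    try:
        import numpy as np  # optional: seeded only if NumPy is installed
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch  # optional: seeded only if PyTorch is installed
        torch.manual_seed(seed)
        # Silently ignored when CUDA is not available.
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


set_seed_sketch(42)
first = random.random()
set_seed_sketch(42)
second = random.random()
```

Re-seeding with the same value reproduces the same random draw, so `first` and `second` are equal.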