Source code for hezar.preprocessors.tokenizers.tokenizer

from __future__ import annotations

import os
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Optional, Self

import numpy as np
import torch
from huggingface_hub import create_repo, hf_hub_download, upload_file

from ...builders import build_preprocessor
from ...configs import PreprocessorConfig
from ...constants import (
    DEFAULT_TOKENIZER_CONFIG_FILE,
    DEFAULT_TOKENIZER_FILE,
    HEZAR_CACHE_DIR,
    Backends,
    PaddingType,
)
from ...utils import (
    Logger,
    convert_batch_dict_dtype,
    is_backend_available,
    pad_batch_items,
)
from ..preprocessor import Preprocessor


if is_backend_available(Backends.TOKENIZERS):
    from tokenizers import Tokenizer as HFTokenizer
    from tokenizers.decoders import Decoder
    from tokenizers.models import Model

logger = Logger(__name__)


@dataclass
class TokenizerConfig(PreprocessorConfig):
    """
    Configuration for the Tokenizer.

    Args:
        truncation_side (str): Truncation direction for tokenization.
        stride (int): Stride for tokenization.
        padding_side (str): Padding direction for tokenization.
        pad_to_multiple_of (int): Pad to a multiple of this value.
        pad_token_type_id (int): ID of the padding token type.
        bos_token (str): Beginning of sequence token.
        eos_token (str): End of sequence token.
        unk_token (str): Unknown token.
        sep_token (str): Separator token.
        pad_token (str): Padding token.
        cls_token (str): Classification token.
        mask_token (str): Mask token.
        additional_special_tokens (List[str]): Additional special tokens.
    """

    name = "tokenizer"
    max_length: int | str = "deprecated"
    truncation: str | bool = "deprecated"
    truncation_side: str | None = None
    padding: str | bool = "deprecated"
    padding_side: str | None = None
    stride: int | None = None
    pad_to_multiple_of: int | str = "deprecated"
    pad_token_type_id: int = 0
    bos_token: str | None = None
    eos_token: str | None = None
    unk_token: str | None = None
    sep_token: str | None = None
    pad_token: str | None = None
    cls_token: str | None = None
    mask_token: str | None = None
    additional_special_tokens: list[str] | None = None

    def __post_init__(self):
        super().__post_init__()
        if self.max_length != "deprecated":
            logger.warning(
                "Setting `max_length` in the tokenizer config is deprecated and will be removed in the future!"
            )
        if self.padding != "deprecated":
            logger.warning(
                "Setting `padding` in the tokenizer config is deprecated and will be removed in the future!"
            )
        if self.truncation != "deprecated":
            logger.warning(
                "Setting `truncation` in the tokenizer config is deprecated and will be removed in the future!"
            )
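# A minimal configuration sketch (illustrative only): the special-token values and
# side settings below are assumptions chosen for demonstration, not defaults shipped
# with any particular Hezar model.
#
#   config = TokenizerConfig(
#       unk_token="[UNK]",
#       sep_token="[SEP]",
#       cls_token="[CLS]",
#       pad_token="[PAD]",
#       mask_token="[MASK]",
#       padding_side="right",
#       truncation_side="right",
#   )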
class Tokenizer(Preprocessor):
    """
    Base tokenizer class. Mostly copied from :class:`~tokenizers.implementations.BaseTokenizer`.

    Args:
        config: A TokenizerConfig instance.
        tokenizer_file (str): A tokenizer.json file to load the whole tokenizer from.
        **kwargs: Extra config parameters that merge into the main config.
    """

    required_backends: list[str | Backends] = []

    tokenizer_filename = DEFAULT_TOKENIZER_FILE
    tokenizer_config_filename = DEFAULT_TOKENIZER_CONFIG_FILE
    token_ids_name = "token_ids"
    uncastable_keys = ["word_ids", "tokens", "offsets_mapping"]

    def __init__(self, config: TokenizerConfig, tokenizer_file=None, **kwargs):
        super().__init__(config, **kwargs)
        self._tokenizer = self.from_file(tokenizer_file) if tokenizer_file is not None else self.build()
        self.special_tokens = self._get_all_special_tokens()

    def _get_all_special_tokens(self):
        """
        Get a list of all special tokens.

        Returns:
            List[str]: List of special tokens.
        """
        _special_tokens = [
            self.config.bos_token,
            self.config.eos_token,
            self.config.unk_token,
            self.config.sep_token,
            self.config.pad_token,
            self.config.cls_token,
            self.config.mask_token,
        ]
        _special_tokens = [token for token in _special_tokens if token in self.vocab]
        if self.config.additional_special_tokens is not None:
            for token in self.config.additional_special_tokens:
                if token not in _special_tokens:
                    _special_tokens.append(token)
        valid_tokens = [token for token in _special_tokens if token is not None]
        return valid_tokens
    @staticmethod
    def from_file(path):
        """
        Create a tokenizer from a file.

        Args:
            path (str): Path to the tokenizer file.

        Returns:
            HFTokenizer: The created tokenizer.
        """
        tokenizer = HFTokenizer.from_file(path)
        return tokenizer
    def build(self):
        """
        Build the tokenizer.

        Returns:
            HFTokenizer: The built tokenizer.
        """
        raise NotImplementedError
    def encode(self, inputs, is_pretokenized: bool = False, add_special_tokens: bool = True, **kwargs):
        """
        Tokenize a list of inputs (could be raw or tokenized inputs).

        Args:
            inputs: List of inputs.
            is_pretokenized: Whether the inputs are already tokenized.
            add_special_tokens: Whether to add special tokens to the inputs. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            List[Encoding]: List of `Encoding` objects from the backend tokenizer.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, list) and is_pretokenized:
            if isinstance(inputs[0], str):
                inputs = [inputs]
        return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens)
    def decode(self, ids: list[int] | list[list[int]], skip_special_tokens: bool = True, **kwargs):
        """
        Decode a list of token IDs.

        Args:
            ids (List[int] | List[List[int]]): List of token IDs.
            skip_special_tokens (bool): Whether to skip special tokens during decoding.
            **kwargs: Additional keyword arguments.

        Returns:
            List[str]: List of decoded strings.
        """
        # Convert tensors/arrays to plain lists first so the single-sequence check below works
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()  # ty: ignore
        if isinstance(ids[0], int):
            ids = [ids]  # ty: ignore
        return self._tokenizer.decode_batch(ids, skip_special_tokens=skip_special_tokens)
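    # Usage sketch for `encode`/`decode` (illustrative): `tokenizer` stands for any concrete
    # subclass that implements `build()` or was loaded from a tokenizer file.
    #
    #   encodings = tokenizer.encode(["hello world"], add_special_tokens=True)
    #   ids = encodings[0].ids                      # token ids of the first sample
    #   texts = tokenizer.decode([ids], skip_special_tokens=True)
    #   # texts -> ["hello world"] (up to normalization/detokenization differences)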
    def pad_encoded_batch(
        self,
        inputs,
        padding: str | PaddingType | None = None,
        max_length: Optional[int] = None,
        truncation: bool = True,
        return_tensors: str | None = None,
        include_keys: Optional[list[str]] = None,
        exclude_keys: list | None = None,
    ):
        """
        Pad a batch of encoded inputs.

        Args:
            inputs: Input batch of encoded tokens.
            padding (str | PaddingType): Padding type.
            max_length (Optional[int]): Max input length (only if padding is set to "max_length").
            truncation (bool): Whether to allow truncation.
            return_tensors (str | None): The type of tensors to return.
            include_keys (Optional[List[str]]): Only pad the given set of keys.
            exclude_keys (List): A list of keys to exclude when padding.

        Returns:
            Dict: Padded inputs.
        """
        if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], Mapping):
            inputs = {key: [example[key] for example in inputs] for key in inputs[0].keys()}

        exclude_keys = exclude_keys or []
        exclude_keys += self.uncastable_keys  # avoid possible errors

        inputs = convert_batch_dict_dtype(inputs, dtype="list", skip_keys=exclude_keys)

        include_keys = include_keys or list(inputs.keys())

        for key, _batch in inputs.items():
            if key in exclude_keys:
                continue
            if key in include_keys:
                pad_id = 0 if key == "attention_mask" else self.pad_token_id
                padded_ids = pad_batch_items(
                    inputs[key],
                    padding=padding,
                    padding_side=self.config.padding_side,
                    pad_id=pad_id,
                    max_length=max_length,
                    truncation=truncation,
                )
                inputs[key] = padded_ids

        inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors, skip_keys=exclude_keys)

        return inputs
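    # Padding sketch (illustrative): pad a previously encoded batch and get PyTorch tensors
    # back. The "longest" strategy value is an assumption about what `pad_batch_items`
    # accepts; "max_length" (with `max_length=...`) is the other strategy referenced above.
    #
    #   batch = tokenizer(["a short one", "a slightly longer example"], return_tensors="list")
    #   padded = tokenizer.pad_encoded_batch(batch, padding="longest", return_tensors="torch")
    #   # padded["token_ids"] -> tensor of shape (2, longest_length)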
    def __call__(
        self,
        inputs: list[str] | list[tuple[str, str]],
        device: str | torch.device | None = None,
        add_special_tokens: bool = True,
        padding=None,
        truncation=None,
        max_length: int | None = None,
        return_tensors: str = "list",
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: int | None = None,
        return_tokens: bool | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        return_word_ids: bool = False,
        verbose: bool = True,
        **kwargs,
    ):
        """
        Tokenize a batch of string inputs and return the relevant properties, e.g. token ids, attention mask, etc.

        Args:
            inputs: A list of string inputs to tokenize
            device: If given and `return_tensors` is "torch", move the output tensors to this device
            add_special_tokens: Whether to add special tokens or not
            padding: Determines how to pad inputs
            truncation: Determines how to truncate inputs
            max_length: Max input length of the sequences
            return_tensors: The type of the returning tensors in the batch e.g. pt, np, list
            stride: Stride level
            is_split_into_words: Are inputs pre-tokenized or raw string inputs
            pad_to_multiple_of: Pad inputs by a factor of this value
            return_tokens: Whether to return tokens lists
            return_token_type_ids: Whether to return token type ids
            return_attention_mask: Whether to return attention masks
            return_overflowing_tokens: Whether to return overflowing tokens
            return_special_tokens_mask: Whether to return special tokens mask
            return_offsets_mapping: Whether to return offsets
            return_length: Whether to return input lengths
            return_word_ids: Whether to return word ids
            **kwargs: Extra arguments are ignored

        Returns:
            A dictionary of encoded inputs like
            {"token_ids": [batch_size x input_len], "attention_mask": [batch_size x input_len], ...}
        """
        if isinstance(inputs, list) and not len(inputs):
            raise ValueError("Tokenizer cannot process an empty list!")

        if "padding_strategy" in kwargs:
            logger.warning(
                "`padding_strategy` was deprecated in favor of `padding`!"
                " This warning will change to an error in the future!"
            )
        if "truncation_strategy" in kwargs:
            logger.warning(
                "`truncation_strategy` was deprecated in favor of `truncation`!"
                " This warning will change to an error in the future!"
            )

        return_tensors = return_tensors or "list"

        # Convert to batch if input is a single string or a list of words (is split into words for sequence labeling)
        if isinstance(inputs, str) or (is_split_into_words and not isinstance(inputs[0], list)):
            inputs = [inputs]  # type: ignore
            is_batch = False
        else:
            is_batch = True

        self.set_truncation_and_padding(
            padding=padding,
            truncation=truncation,
            padding_side=self.config.padding_side,
            truncation_side=self.config.truncation_side,
            max_length=max_length,
            stride=self.config.stride,
            pad_to_multiple_of=pad_to_multiple_of,
        )
        encodings = self.encode(
            inputs,
            add_special_tokens=add_special_tokens,
            is_pretokenized=is_split_into_words,
        )
        encodings_dict = [
            self._convert_encodings(
                encoding=encoding,
                original_text=inp if isinstance(inp, str) else None,
                return_tokens=return_tokens,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                return_word_ids=return_word_ids,
            )
            for encoding, inp in zip(encodings, inputs)
        ]
        # Permute output dict from [batch_0: Dict[key, value], ...] to Dict[key, [batch_0, batch_1, ...], ...]
        sanitized_outputs = {}
        for key in encodings_dict[0].keys():
            stack = [e for item in encodings_dict for e in item[key]]
            sanitized_outputs[key] = stack

        # If returning overflowing tokens, we need to return a mapping
        # from the batch idx to the original sample
        if return_overflowing_tokens:
            overflow_to_sample_mapping = []
            for i, encodings_ in enumerate(encodings_dict):
                overflow_to_sample_mapping += [i] * len(encodings_[self.token_ids_name])
            sanitized_outputs["overflow_to_sample_mapping"] = overflow_to_sample_mapping

        # Squeeze tensor if the original input is a single string and return_tensors is `list`
        if (return_tensors == "list" or return_tensors is None) and not is_batch:
            sanitized_outputs = {
                key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
                for key, value in sanitized_outputs.items()
            }
        outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)
        if device and return_tensors == "torch":
            outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
        return outputs
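    # End-to-end usage sketch of `__call__` (illustrative): `tokenizer` is any concrete
    # subclass, e.g. one obtained via `Tokenizer.load(...)`. "longest" as the padding
    # value is an assumption; any non-"no_padding" value other than max_length pads to
    # the longest sample in the batch.
    #
    #   outputs = tokenizer(
    #       ["This is a test", "Another sample"],
    #       padding="longest",
    #       return_tensors="torch",
    #   )
    #   # outputs["token_ids"]      -> tensor of shape (2, max_seq_len)
    #   # outputs["attention_mask"] -> tensor of the same shape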
    def set_truncation_and_padding(
        self,
        padding=None,
        truncation=None,
        padding_side=None,
        truncation_side=None,
        max_length: int | None = None,
        stride: int | None = None,
        pad_to_multiple_of: int | None = None,
    ):
        """
        Set truncation and padding on the backend tokenizer, only re-applying them when the
        requested settings differ from the currently enabled ones.
        """
        if truncation == "no_truncation" or truncation is None or max_length is None:
            if self.truncation is not None:
                self.no_truncation()
        else:
            if truncation is True:
                truncation = "longest_first"
            target = {
                "max_length": max_length,
                "stride": stride,
                "strategy": truncation,
                "direction": truncation_side,
            }
            if self.truncation is None:
                current = None
            else:
                current = {k: self.truncation.get(k, None) for k in target}

            if current != target:
                self.enable_truncation(**target)

        if padding == "no_padding" or padding is None:
            if self.padding is not None:
                self.no_padding()
        else:
            target = {
                "length": max_length if padding == PaddingType.MAX_LENGTH else None,
                "direction": padding_side,
                "pad_id": self.token_to_id(self.pad_token),
                "pad_token": self.pad_token,
                "pad_type_id": self.config.pad_token_type_id,
                "pad_to_multiple_of": pad_to_multiple_of,
            }
            if self.padding != target:
                self.enable_padding(**target)  # ty: ignore
    def _convert_encodings(
        self,
        encoding,
        original_text: str | None = None,
        return_tokens: bool | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        return_word_ids: bool = False,
    ):
        """
        Convert a single HF `Encoding` (plus its overflowing encodings, if requested) into a
        dict of lists keyed by output name.
        """
        if return_overflowing_tokens and encoding.overflowing is not None:
            encodings = [encoding] + encoding.overflowing
        else:
            encodings = [encoding]

        encoding_dict = defaultdict(list)
        for e in encodings:
            encoding_dict[self.token_ids_name].append(e.ids)
            if return_token_type_ids:
                encoding_dict["token_type_ids"].append(e.type_ids)
            if return_attention_mask:
                encoding_dict["attention_mask"].append(e.attention_mask)
            if return_special_tokens_mask:
                encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
            if return_offsets_mapping:
                encoding_dict["offsets_mapping"].append(e.offsets)
            if return_length:
                encoding_dict["length"].append(len(e.ids))
            if return_tokens:
                tokens = self.get_tokens_from_encoding(e, original_text=original_text)
                encoding_dict["tokens"].append(tokens)
            if return_word_ids:
                encoding_dict["word_ids"].append(e.word_ids)

        return encoding_dict
    def convert_tokens_to_ids(self, tokens: str | list[str]) -> list[int]:
        if isinstance(tokens, str):
            tokens = [tokens]
        return [self._tokenizer.token_to_id(token) for token in tokens]

    def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False):
        if isinstance(ids, int):
            ids = [ids]
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.special_ids:
                continue
            tokens.append(self._tokenizer.id_to_token(index))
        return tokens
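    # Conversion sketch (illustrative): the token strings below are placeholders and depend
    # on the vocabulary of the loaded tokenizer.
    #
    #   ids = tokenizer.convert_tokens_to_ids(["[CLS]", "hello"])
    #   tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
    #   # round-trips back to ["[CLS]", "hello"] if both tokens exist in the vocab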
    def num_special_tokens_to_add(self, is_pair: bool) -> int:
        return self._tokenizer.num_special_tokens_to_add(is_pair)

    def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]:
        return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)

    def get_vocab_size(self, with_added_tokens: bool = True) -> int:
        return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)

    def enable_padding(
        self,
        direction: str = "right",
        pad_to_multiple_of: int | None = None,
        pad_id: int = 0,
        pad_type_id: int = 0,
        pad_token: str | None = None,
        length: int | None = None,
    ):
        return self._tokenizer.enable_padding(
            direction=direction,
            pad_to_multiple_of=pad_to_multiple_of,
            pad_id=pad_id,
            pad_type_id=pad_type_id,
            pad_token=pad_token,
            length=length,
        )

    def no_padding(self):
        return self._tokenizer.no_padding()

    def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
        return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy, direction=direction)

    def no_truncation(self):
        return self._tokenizer.no_truncation()

    def add_tokens(self, tokens) -> int:
        return self._tokenizer.add_tokens(tokens)

    def add_special_tokens(self, special_tokens) -> int:
        return self._tokenizer.add_special_tokens(special_tokens)

    def token_to_id(self, token: str) -> int:
        return self._tokenizer.token_to_id(token)

    def id_to_token(self, id: int) -> str:
        return self._tokenizer.id_to_token(id)
    def get_added_vocab(self) -> dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.

        Returns:
            `Dict[str, int]`: The added tokens.
        """
        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
        added_vocab = {token: index for token, index in full_vocab.items() if token not in base_vocab}
        return added_vocab
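    # Vocabulary-extension sketch (illustrative): the new token strings are made up.
    #
    #   n_added = tokenizer.add_tokens(["<new_token>", "<another_one>"])
    #   tokenizer.get_added_vocab()
    #   # -> {"<new_token>": <id>, "<another_one>": <id>} for the n_added tokens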
    def __len__(self) -> int:
        """
        Size of the full vocabulary with the added tokens.
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=True)

    def _is_byte_level(self) -> bool:
        """
        Return True if the underlying tokenizer uses a byte-level decoder (e.g. BPE).

        Byte-level tokenizers encode each byte of a UTF-8 character as a separate symbol, so
        ``encoding.tokens`` contains unreadable byte-level strings like ``'اÙĪ'`` instead of the
        original Persian word ``'او'``. In that case we must reconstruct token strings by slicing
        the original input text with the offset mapping instead.
        """
        if self._tokenizer.decoder is None:
            return False
        return type(self._tokenizer.decoder).__name__ == "ByteLevel"
    def get_tokens_from_encoding(self, encoding, original_text: str | None = None):
        """
        Extract human-readable tokens from the HF ``Encoding`` object.

        Two strategies are used depending on the tokenizer type:

        * **Non-byte-level** (WordPiece, Unigram, …): ``encoding.tokens`` already holds the
          correct human-readable subword strings (e.g. ``'##زم'``), so we return them directly.
        * **Byte-level BPE**: ``encoding.tokens`` contains unreadable byte-encoded strings. Here
          we slice ``original_text`` with ``encoding.offsets``, which always map back to the
          original input correctly.

        .. note::
            Never use ``tokenizer.decode(ids)`` + offsets here. The decoder rewrites the text
            (removes zero-width non-joiners, re-spaces punctuation, …) so the character positions
            in the offsets no longer align.

        Args:
            encoding: An ``Encoding`` object returned by the HF tokenizers backend.
            original_text: The original input string, required for byte-level BPE.

        Returns:
            List[str]: A list of human-readable token strings.
        """
        if self._is_byte_level():
            if original_text is None:
                # Fallback: return raw tokens if original text is unavailable
                return list(encoding.tokens)
            return self.get_tokens_from_offsets(original_text, encoding.ids, encoding.offsets)
        return list(encoding.tokens)
    def get_tokens_from_offsets(
        self,
        text: str,
        ids: list[int],
        offsets_mapping: list[tuple[int, int]],
    ):
        """
        Extract human-readable tokens by slicing *text* with the provided offsets.

        .. warning::
            *text* **must** be the same string that was originally fed to the tokenizer (or its
            normalised form if a normalizer is configured). Never pass the output of
            ``tokenizer.decode(ids)`` here: the decoder rewrites the text (removes zero-width
            characters, re-spaces punctuation, …) so the character positions in *offsets_mapping*
            will no longer align correctly, resulting in garbled tokens.

        Args:
            text: The original (or normalised) input string.
            ids: Token ids corresponding to the offsets.
            offsets_mapping: A list of (start, end) character offset tuples into *text*.

        Returns:
            List[str]: A list of token strings.
        """
        if not isinstance(text, str):
            raise ValueError(f"Expected str type for `text`, got `{type(text)}({text})`")
        if isinstance(offsets_mapping, list) and not isinstance(offsets_mapping[0], tuple):
            raise ValueError(f"Expected a list of tuples for `offsets_mapping`, got List[{type(offsets_mapping[0])}]")

        tokens = []
        for offset in offsets_mapping:
            offset_start, offset_end = offset
            tokens.append(text[offset_start:offset_end])

        for i, _token in enumerate(tokens):
            if ids[i] in self.special_ids:
                tokens[i] = self._tokenizer.id_to_token(ids[i])

        return tokens
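    # Offsets sketch (illustrative): recover readable tokens for a byte-level BPE tokenizer
    # by asking `__call__` for the offset mapping and slicing the original text with it.
    #
    #   text = "او به مدرسه رفت"
    #   out = tokenizer(text, return_offsets_mapping=True)
    #   tokens = tokenizer.get_tokens_from_offsets(text, out["token_ids"], out["offsets_mapping"])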
    @classmethod
    def load(
        cls,
        hub_or_local_path,
        subfolder=None,
        cache_dir=None,
        config_filename=None,
        tokenizer_filename=None,
        force_return_dict: bool = False,
        **kwargs,
    ) -> Self:
        """
        Load a tokenizer from a specified path or Hub repository.

        Args:
            cls: Class reference.
            hub_or_local_path: Path or Hub repository ID.
            subfolder: Subfolder containing tokenizer files.
            config_filename: Tokenizer config filename.
            tokenizer_filename: Tokenizer filename.
            cache_dir: Path to cache directory.
            **kwargs: Additional arguments.

        Returns:
            Tokenizer: Loaded tokenizer.
        """
        tokenizer_filename = tokenizer_filename or cls.tokenizer_filename
        config_filename = config_filename or cls.tokenizer_config_filename
        subfolder = subfolder or cls.preprocessor_subfolder
        cache_dir = cache_dir or HEZAR_CACHE_DIR

        config = TokenizerConfig.load(
            hub_or_local_path,
            filename=config_filename,
            subfolder=subfolder,
            cache_dir=cache_dir,
        )

        if os.path.isdir(hub_or_local_path):
            tokenizer_path = os.path.join(hub_or_local_path, subfolder, tokenizer_filename)
        else:
            tokenizer_path = hf_hub_download(
                hub_or_local_path,
                filename=tokenizer_filename,
                subfolder=subfolder,
                cache_dir=cache_dir,
            )

        tokenizer = build_preprocessor(config.name, config, tokenizer_file=tokenizer_path, **kwargs)
        return tokenizer
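    # Loading sketch (illustrative): the repo id below is a placeholder, not a guaranteed Hub
    # repository. `Tokenizer.load` builds and returns the concrete subclass registered under
    # the loaded config's `name` (via `build_preprocessor`).
    #
    #   tokenizer = Tokenizer.load("username/some-model-repo")
    #   # or from a local directory that contains the saved tokenizer files
    #   tokenizer = Tokenizer.load("/path/to/saved/model")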
    def save(self, path, save_config=True, pretty=True, **kwargs):
        """
        Save the tokenizer and its configuration.

        Args:
            path (str): Path to save the tokenizer.
            save_config (bool): Whether to save the configuration.
            pretty (bool): Whether to format the saved JSON file with indentation.
        """
        os.makedirs(path, exist_ok=True)
        # save config
        if save_config:
            self.config.vocab_size = self.get_vocab_size(with_added_tokens=True)
            self.config.save(path, filename=self.tokenizer_config_filename, subfolder=self.preprocessor_subfolder)
        # save tokenizer.json
        save_path = os.path.join(path, self.preprocessor_subfolder, self.tokenizer_filename)
        self._tokenizer.save(save_path, pretty=pretty)
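    # Saving sketch (illustrative): writes the tokenizer config and tokenizer file under the
    # preprocessor subfolder of the given path. Exact filenames come from
    # `tokenizer_config_filename` / `tokenizer_filename` and may differ per subclass.
    #
    #   tokenizer.save("my-exported-model")
    #   # -> my-exported-model/<preprocessor_subfolder>/<tokenizer_config_filename>
    #   # -> my-exported-model/<preprocessor_subfolder>/<tokenizer_filename>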
    def push_to_hub(
        self,
        repo_id,
        subfolder=None,
        commit_message=None,
        private=False,
        config_filename=None,
        tokenizer_filename=None,
        **kwargs,
    ):
        """
        Push the tokenizer and its config to the Hub.

        Args:
            repo_id: The path (id or repo name) on the hub
            commit_message: Commit message for this push
            subfolder: Subfolder to save the files
            tokenizer_filename: Tokenizer filename
            config_filename: Tokenizer config filename
            private: Whether the repo should be private (ignored if the repo exists)
        """
        subfolder = subfolder or self.preprocessor_subfolder
        tokenizer_filename = tokenizer_filename or self.tokenizer_filename
        config_filename = config_filename or self.tokenizer_config_filename

        # create remote repo
        create_repo(repo_id, exist_ok=True, private=private)

        # save to tmp and prepare for push
        cache_path = tempfile.mkdtemp()
        # save tokenizer.json
        tokenizer_save_path = os.path.join(cache_path, subfolder, tokenizer_filename)
        self.save(cache_path, pretty=True)

        if commit_message is None:
            commit_message = "Hezar: Upload tokenizer and config"

        # upload config
        self.config.push_to_hub(
            repo_id=repo_id,
            filename=config_filename,
            subfolder=subfolder,
            commit_message=commit_message,
        )
        # upload tokenizer
        upload_file(
            repo_id=repo_id,
            path_or_fileobj=tokenizer_save_path,
            repo_type="model",
            path_in_repo=f"{subfolder}/{tokenizer_filename}",
            commit_message=commit_message,
        )

        logger.log_upload_success(
            name=f"{self.__class__.__name__}(name={self.config.name})",
            target_path=os.path.join(repo_id, subfolder, tokenizer_filename),
        )
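    # Hub-push sketch (illustrative): the repo id is a placeholder; pushing requires a valid
    # Hugging Face Hub token available in your environment.
    #
    #   tokenizer.push_to_hub("username/my-model", commit_message="Add tokenizer")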
    @property
    def model(self) -> "Model":
        return self._tokenizer.model

    @model.setter
    def model(self, model: "Model"):
        self._tokenizer.model = model  # noqa

    @property
    def decoder(self) -> "Decoder":
        return self._tokenizer.decoder

    @decoder.setter
    def decoder(self, decoder: "Decoder"):
        self._tokenizer.decoder = decoder  # noqa

    @property
    def padding(self):
        return self._tokenizer.padding

    @property
    def truncation(self) -> dict:
        return self._tokenizer.truncation

    @property
    def vocab(self):
        return self._tokenizer.get_vocab(with_added_tokens=True)

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=False)

    @property
    def special_ids(self):
        return [self.token_to_id(t) for t in self.special_tokens]

    @property
    def pad_token(self):
        return self.config.pad_token

    @property
    def bos_token(self):
        return self.config.bos_token

    @property
    def eos_token(self):
        return self.config.eos_token

    @property
    def unk_token(self):
        return self.config.unk_token

    @property
    def mask_token(self):
        return self.config.mask_token

    @property
    def cls_token(self):
        return self.config.cls_token

    @property
    def sep_token(self):
        return self.config.sep_token

    @property
    def pad_token_id(self):
        return self.token_to_id(self.config.pad_token)

    @property
    def bos_token_id(self):
        return self.token_to_id(self.config.bos_token)

    @property
    def eos_token_id(self):
        return self.token_to_id(self.config.eos_token)

    @property
    def unk_token_id(self):
        return self.token_to_id(self.config.unk_token)

    @property
    def mask_token_id(self):
        return self.token_to_id(self.config.mask_token)

    @property
    def cls_token_id(self):
        return self.token_to_id(self.config.cls_token)

    @property
    def sep_token_id(self):
        return self.token_to_id(self.config.sep_token)
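# Property sketch (illustrative): special-token strings come from the config, while their
# ids are looked up in the backend vocabulary. The example values are placeholders.
#
#   tokenizer.pad_token      # e.g. "[PAD]"
#   tokenizer.pad_token_id   # e.g. 0
#   tokenizer.vocab_size     # base vocab size, without added tokens
#   len(tokenizer)           # full vocab size, including added tokens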