Source code for hezar.preprocessors.tokenizers.tokenizer

from __future__ import annotations

import os
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Tuple

import numpy as np
import torch
from huggingface_hub import create_repo, hf_hub_download, upload_file

from ...builders import build_preprocessor
from ...configs import PreprocessorConfig
from ...constants import (
    DEFAULT_TOKENIZER_CONFIG_FILE,
    DEFAULT_TOKENIZER_FILE,
    HEZAR_CACHE_DIR,
    Backends,
    PaddingType,
)
from ...utils import (
    Logger,
    convert_batch_dict_dtype,
    is_backend_available,
    pad_batch_items,
)
from ..preprocessor import Preprocessor


if is_backend_available(Backends.TOKENIZERS):
    from tokenizers import Tokenizer as HFTokenizer
    from tokenizers.decoders import Decoder
    from tokenizers.models import Model

logger = Logger(__name__)


@dataclass
class TokenizerConfig(PreprocessorConfig):
    """
    Configuration for the Tokenizer.

    Args:
        truncation_side (str): Truncation direction for tokenization.
        stride (int): Stride for tokenization.
        padding_side (str): Padding direction for tokenization.
        pad_to_multiple_of (int): Pad to a multiple of this value.
        pad_token_type_id (int): ID of the padding token type.
        bos_token (str): Beginning of sequence token.
        eos_token (str): End of sequence token.
        unk_token (str): Unknown token.
        sep_token (str): Separator token.
        pad_token (str): Padding token.
        cls_token (str): Classification token.
        mask_token (str): Mask token.
        additional_special_tokens (List[str]): Additional special tokens.
    """
    name = "tokenizer"
    max_length: int = "deprecated"
    truncation: str = "deprecated"
    truncation_side: str = None
    padding: str = "deprecated"
    padding_side: str = None
    stride: int = None
    pad_to_multiple_of: int = "deprecated"
    pad_token_type_id: int = 0
    bos_token: str = None
    eos_token: str = None
    unk_token: str = None
    sep_token: str = None
    pad_token: str = None
    cls_token: str = None
    mask_token: str = None
    additional_special_tokens: List[str] = None

    def __post_init__(self):
        super().__post_init__()
        if self.max_length != "deprecated":
            logger.warning(
                "Setting `max_length` in the tokenizer config is deprecated and will be removed in the future!"
            )
        if self.padding != "deprecated":
            logger.warning(
                "Setting `padding` in the tokenizer config is deprecated and will be removed in the future!"
            )
        if self.truncation != "deprecated":
            logger.warning(
                "Setting `truncation` in the tokenizer config is deprecated and will be removed in the future!"
            )
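
# Illustrative sketch (comments only, so nothing runs at import time): a concrete
# tokenizer config is typically built with explicit special tokens. The token
# strings and settings below are assumptions for the example, not library defaults:
#
#   config = TokenizerConfig(
#       unk_token="[UNK]",
#       pad_token="[PAD]",
#       cls_token="[CLS]",
#       sep_token="[SEP]",
#       mask_token="[MASK]",
#       padding_side="right",
#       truncation_side="right",
#   )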

class Tokenizer(Preprocessor):
    """
    Base tokenizer class. Mostly copied from :class:`~tokenizers.implementations.BaseTokenizer`.

    Args:
        config: A TokenizerConfig instance.
        tokenizer_file (str): A tokenizer.json file to load the whole tokenizer from.
        **kwargs: Extra config parameters that merge into the main config.
    """
    required_backends: List[str | Backends] = []

    tokenizer_filename = DEFAULT_TOKENIZER_FILE
    tokenizer_config_filename = DEFAULT_TOKENIZER_CONFIG_FILE
    token_ids_name = "token_ids"
    uncastable_keys = ["word_ids", "tokens", "offsets_mapping"]

    def __init__(self, config: TokenizerConfig, tokenizer_file=None, **kwargs):
        super().__init__(config, **kwargs)
        self._tokenizer = self.from_file(tokenizer_file) if tokenizer_file is not None else self.build()
        self.special_tokens = self._get_all_special_tokens()

    def _get_all_special_tokens(self):
        """
        Get a list of all special tokens.

        Returns:
            List[str]: List of special tokens.
        """
        _special_tokens = [
            self.config.bos_token,
            self.config.eos_token,
            self.config.unk_token,
            self.config.sep_token,
            self.config.pad_token,
            self.config.cls_token,
            self.config.mask_token,
        ]
        _special_tokens = [token for token in _special_tokens if token in self.vocab]
        if self.config.additional_special_tokens is not None:
            for token in self.config.additional_special_tokens:
                if token not in _special_tokens:
                    _special_tokens.append(token)
        valid_tokens = [token for token in _special_tokens if token is not None]
        return valid_tokens

    @staticmethod
    def from_file(path):
        """
        Create a tokenizer from a file.

        Args:
            path (str): Path to the tokenizer file.

        Returns:
            HFTokenizer: The created tokenizer.
        """
        tokenizer = HFTokenizer.from_file(path)
        return tokenizer

    def build(self):
        """
        Build the tokenizer.

        Returns:
            HFTokenizer: The built tokenizer.
        """
        raise NotImplementedError

    def encode(self, inputs, is_pretokenized: bool = False, add_special_tokens: bool = True, **kwargs):
        """
        Tokenize a list of inputs (either raw or pre-tokenized inputs).

        Args:
            inputs: List of inputs.
            is_pretokenized: Whether the inputs are already tokenized.
            add_special_tokens: Whether to add special tokens to the inputs. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of `Encoding` objects as returned by the backend tokenizer's `encode_batch`.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, list) and is_pretokenized:
            if isinstance(inputs[0], str):
                inputs = [inputs]
        return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens)

    def decode(self, ids: List[int], skip_special_tokens: bool = True, **kwargs):
        """
        Decode a list (or batch) of token IDs.

        Args:
            ids (List[int]): List of token IDs.
            skip_special_tokens (bool): Whether to skip special tokens during decoding.
            **kwargs: Additional keyword arguments.

        Returns:
            List[str]: List of decoded strings.
        """
        # Convert tensors/arrays to plain lists first so the batching check below
        # also covers 1-D tensor inputs.
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()
        if isinstance(ids[0], int):
            ids = [ids]
        return self._tokenizer.decode_batch(ids, skip_special_tokens=skip_special_tokens)
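
    # Usage sketch for `encode`/`decode` (comments only, so nothing runs at import
    # time; the repo id below is hypothetical):
    #
    #   tokenizer = Tokenizer.load("username/some-tokenizer-repo")
    #   encodings = tokenizer.encode(["Hello world!", "Another sentence."])
    #   ids = [e.ids for e in encodings]   # `encode_batch` returns Encoding objects
    #   texts = tokenizer.decode(ids)      # -> list of decoded strings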

    def pad_encoded_batch(
        self,
        inputs,
        padding: str | PaddingType = None,
        max_length: Optional[int] = None,
        truncation: bool = True,
        return_tensors: Optional[str] = None,
        include_keys: Optional[List[str]] = None,
        exclude_keys: List = None,
    ):
        """
        Pad a batch of encoded inputs.

        Args:
            inputs: Input batch of encoded tokens.
            padding (str | PaddingType): Padding type.
            max_length (Optional[int]): Max input length (only used if `padding` is set to "max_length").
            truncation (bool): Whether to allow truncation.
            return_tensors (Optional[str]): The type of tensors to return.
            include_keys (Optional[List[str]]): Only pad the given set of keys.
            exclude_keys (List): A list of keys to exclude when padding.

        Returns:
            Dict: Padded inputs.
        """
        if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], Mapping):
            inputs = {key: [example[key] for example in inputs] for key in inputs[0].keys()}

        exclude_keys = exclude_keys or []
        exclude_keys += self.uncastable_keys  # avoid possible errors

        inputs = convert_batch_dict_dtype(inputs, dtype="list", skip_keys=exclude_keys)

        include_keys = include_keys or list(inputs.keys())

        for key, batch in inputs.items():
            if key in exclude_keys:
                continue
            if key in include_keys:
                pad_id = 0 if key == "attention_mask" else self.pad_token_id
                padded_ids = pad_batch_items(
                    inputs[key],
                    padding=padding,
                    padding_side=self.config.padding_side,
                    pad_id=pad_id,
                    max_length=max_length,
                    truncation=truncation,
                )
                inputs[key] = padded_ids

        inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors, skip_keys=exclude_keys)

        return inputs
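
    # Sketch of padding a batch of per-example dicts (illustrative only; key names
    # follow `token_ids_name` and the attention-mask convention used above, and the
    # "longest" padding value is an assumption):
    #
    #   batch = [
    #       {"token_ids": [2, 10, 11, 3], "attention_mask": [1, 1, 1, 1]},
    #       {"token_ids": [2, 12, 3], "attention_mask": [1, 1, 1]},
    #   ]
    #   padded = tokenizer.pad_encoded_batch(batch, padding="longest", return_tensors="torch")
    #   # -> {"token_ids": tensor of shape (2, 4), "attention_mask": tensor of shape (2, 4)}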

    def __call__(
        self,
        inputs: List[str] | List[Tuple[str, str]],
        device: str | torch.device = None,
        add_special_tokens: bool = True,
        padding=None,
        truncation=None,
        max_length: int = None,
        return_tensors: str = "list",
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: int = None,
        return_tokens: bool = None,
        return_token_type_ids: bool = None,
        return_attention_mask: bool = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        return_word_ids: bool = False,
        verbose: bool = True,
        **kwargs,
    ):
        """
        Tokenize a batch of string inputs and return the relevant properties, e.g. token ids, attention mask, etc.

        Args:
            inputs: A list of string inputs to tokenize
            device: Device to move the output tensors to (only applied when `return_tensors` is "torch")
            add_special_tokens: Whether to add special tokens or not
            padding: Determines how to pad inputs
            truncation: Determines how to truncate inputs
            max_length: Max input length of the sequences
            return_tensors: The type of the returned tensors in the batch, e.g. pt, np, list
            stride: Stride level
            is_split_into_words: Whether the inputs are pre-tokenized (split into words) or raw strings
            pad_to_multiple_of: Pad inputs by a factor of this value
            return_tokens: Whether to return token lists
            return_token_type_ids: Whether to return token type ids
            return_attention_mask: Whether to return attention masks
            return_overflowing_tokens: Whether to return overflowing tokens
            return_special_tokens_mask: Whether to return the special tokens mask
            return_offsets_mapping: Whether to return offsets
            return_length: Whether to return input lengths
            return_word_ids: Whether to return word ids
            verbose: Verbosity flag (not used in this method)
            **kwargs: Extra arguments land here and are ignored

        Returns:
            A dictionary of encoded inputs like
            {"token_ids": [batch_size x input_len], "attention_mask": [batch_size x input_len], ...}
        """
        if isinstance(inputs, list) and not len(inputs):
            raise ValueError("Tokenizer cannot process an empty list!")

        if "padding_strategy" in kwargs:
            logger.warning(
                "`padding_strategy` was deprecated in favor of `padding`!"
                " This warning will change to an error in the future!"
            )
        if "truncation_strategy" in kwargs:
            logger.warning(
                "`truncation_strategy` was deprecated in favor of `truncation`!"
                " This warning will change to an error in the future!"
            )

        return_tensors = return_tensors or "list"

        # Convert to a batch if the input is a single string or a single list of words
        # (i.e. already split into words for sequence labeling)
        if isinstance(inputs, str) or (is_split_into_words and not isinstance(inputs[0], list)):
            inputs = [inputs]
            is_batch = False
        else:
            is_batch = True

        self.set_truncation_and_padding(
            padding=padding,
            truncation=truncation,
            padding_side=self.config.padding_side,
            truncation_side=self.config.truncation_side,
            max_length=max_length,
            stride=self.config.stride,
            pad_to_multiple_of=pad_to_multiple_of,
        )
        encodings = self.encode(
            inputs,
            add_special_tokens=add_special_tokens,
            is_pretokenized=is_split_into_words,
        )
        encodings_dict = [
            self._convert_encodings(
                encoding=encoding,
                return_tokens=return_tokens,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                return_word_ids=return_word_ids,
            )
            for encoding in encodings
        ]
        # Permute the output dict from [batch_0: Dict[key, value], ...] to Dict[key, [batch_0, batch_1, ...], ...]
        sanitized_outputs = {}
        for key in encodings_dict[0].keys():
            stack = [e for item in encodings_dict for e in item[key]]
            sanitized_outputs[key] = stack

        # If returning overflowing tokens, we also need to return a mapping
        # from each batch index to its original sample
        if return_overflowing_tokens:
            overflow_to_sample_mapping = []
            for i, encodings_ in enumerate(encodings_dict):
                # Encodings are keyed by `token_ids_name`, not "input_ids"
                overflow_to_sample_mapping += [i] * len(encodings_[self.token_ids_name])
            sanitized_outputs["overflow_to_sample_mapping"] = overflow_to_sample_mapping

        # Squeeze the outputs if the original input is a single string and `return_tensors` is "list"
        if (return_tensors == "list" or return_tensors is None) and not is_batch:
            sanitized_outputs = {
                key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
                for key, value in sanitized_outputs.items()
            }

        outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)

        if device and return_tensors == "torch":
            outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}

        return outputs
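
    # Typical end-to-end call (illustrative only; the repo id is hypothetical and the
    # padding/device values are assumptions):
    #
    #   tokenizer = Tokenizer.load("username/some-tokenizer-repo")
    #   outputs = tokenizer(
    #       ["Hello world!", "Another sentence."],
    #       padding="longest",
    #       return_tensors="torch",
    #       device="cuda",
    #   )
    #   outputs["token_ids"].shape        # -> (batch_size, max_sequence_length)
    #   outputs["attention_mask"].shape   # -> (batch_size, max_sequence_length)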

    def set_truncation_and_padding(
        self,
        padding=None,
        truncation=None,
        padding_side=None,
        truncation_side=None,
        max_length: int = None,
        stride: int = None,
        pad_to_multiple_of: int = None,
    ):
        # Set truncation and padding on the backend tokenizer
        if truncation == "no_truncation" or truncation is None or max_length is None:
            if self.truncation is not None:
                self.no_truncation()
        else:
            if truncation is True:
                truncation = "longest_first"
            target = {
                "max_length": max_length,
                "stride": stride,
                "strategy": truncation,
                "direction": truncation_side,
            }
            if self.truncation is None:
                current = None
            else:
                current = {k: self.truncation.get(k, None) for k in target}
            if current != target:
                self.enable_truncation(**target)

        if padding == "no_padding" or padding is None:
            if self.padding is not None:
                self.no_padding()
        else:
            target = {
                "length": max_length if padding == PaddingType.MAX_LENGTH else None,
                "direction": padding_side,
                "pad_id": self.token_to_id(self.pad_token),
                "pad_token": self.pad_token,
                "pad_type_id": self.config.pad_token_type_id,
                "pad_to_multiple_of": pad_to_multiple_of,
            }
            if self.padding != target:
                self.enable_padding(**target)

    def _convert_encodings(
        self,
        encoding,
        return_tokens: bool = None,
        return_token_type_ids: bool = None,
        return_attention_mask: bool = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        return_word_ids: bool = False,
    ):
        if return_overflowing_tokens and encoding.overflowing is not None:
            encodings = [encoding] + encoding.overflowing
        else:
            encodings = [encoding]

        encoding_dict = defaultdict(list)
        for e in encodings:
            encoding_dict[self.token_ids_name].append(e.ids)
            if return_token_type_ids:
                encoding_dict["token_type_ids"].append(e.type_ids)
            if return_attention_mask:
                encoding_dict["attention_mask"].append(e.attention_mask)
            if return_special_tokens_mask:
                encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
            if return_offsets_mapping:
                encoding_dict["offsets_mapping"].append(e.offsets)
            if return_length:
                encoding_dict["length"].append(len(e.ids))
            if return_tokens:
                text = self._tokenizer.decode(e.ids)
                tokens = self.get_tokens_from_offsets(text, e.ids, e.offsets)
                encoding_dict["tokens"].append(tokens)
            if return_word_ids:
                encoding_dict["word_ids"].append(e.word_ids)

        return encoding_dict

    def convert_tokens_to_ids(self, tokens: str | List[str]) -> int | List[int]:
        if isinstance(tokens, str):
            tokens = [tokens]
        return [self._tokenizer.token_to_id(token) for token in tokens]

    def convert_ids_to_tokens(self, ids: int | List[int], skip_special_tokens: bool = False):
        if isinstance(ids, int):
            ids = [ids]
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.special_ids:
                continue
            tokens.append(self._tokenizer.id_to_token(index))
        return tokens
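
    # Quick sketch of the token/id conversion helpers (illustrative values only):
    #
    #   tokenizer.convert_tokens_to_ids(["hello", "world"])   # -> e.g. [1234, 5678]
    #   tokenizer.convert_ids_to_tokens([1234, 5678])         # -> ["hello", "world"]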

    def num_special_tokens_to_add(self, is_pair: bool) -> int:
        return self._tokenizer.num_special_tokens_to_add(is_pair)

    def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]:
        return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)

    def get_vocab_size(self, with_added_tokens: bool = True) -> int:
        return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)

    def enable_padding(
        self,
        direction: str = "right",
        pad_to_multiple_of: int = None,
        pad_id: int = 0,
        pad_type_id: int = 0,
        pad_token: str = None,
        length: int = None,
    ):
        return self._tokenizer.enable_padding(
            direction=direction,
            pad_to_multiple_of=pad_to_multiple_of,
            pad_id=pad_id,
            pad_type_id=pad_type_id,
            pad_token=pad_token,
            length=length,
        )

    def no_padding(self):
        return self._tokenizer.no_padding()

    def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
        return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy, direction=direction)

    def no_truncation(self):
        return self._tokenizer.no_truncation()

    def add_tokens(self, tokens) -> int:
        return self._tokenizer.add_tokens(tokens)

    def add_special_tokens(self, special_tokens) -> int:
        return self._tokenizer.add_special_tokens(special_tokens)

    def token_to_id(self, token: str) -> int:
        return self._tokenizer.token_to_id(token)

    def id_to_token(self, id: int) -> str:
        return self._tokenizer.id_to_token(id)

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.

        Returns:
            `Dict[str, int]`: The added tokens.
        """
        base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
        full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
        added_vocab = {token: index for token, index in full_vocab.items() if token not in base_vocab}
        return added_vocab

    def __len__(self) -> int:
        """
        Size of the full vocabulary with the added tokens.
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=True)

    def get_tokens_from_offsets(
        self,
        text: str | List[str],
        ids: List[int],
        offsets_mapping: List[Tuple[int, int]],
    ):
        """
        Extract human-readable tokens using the original text and the offsets mapping.

        Args:
            text: Raw string text
            ids: Token ids
            offsets_mapping: A list of tuples representing offsets

        Returns:
            A list of tokens
        """
        if not isinstance(text, str):
            raise ValueError(f"Expected str type for `text`, got `{type(text)}({text})`")
        if isinstance(offsets_mapping, list) and not isinstance(offsets_mapping[0], tuple):
            raise ValueError(f"Expected a list of tuples for `offsets_mapping`, got List[{type(offsets_mapping[0])}]")
        tokens = []
        for offset in offsets_mapping:
            offset_start, offset_end = offset
            tokens.append(text[offset_start:offset_end])
        # Replace the (empty) surface form of special tokens with their canonical token strings
        for i, token in enumerate(tokens):
            if ids[i] in self.special_ids:
                tokens[i] = self._tokenizer.id_to_token(ids[i])
        return tokens
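
    # Sketch of recovering surface tokens from offsets (illustrative only; assumes a
    # concrete tokenizer instance has already been loaded):
    #
    #   text = "Hello world!"
    #   encoding = tokenizer.encode(text)[0]
    #   tokens = tokenizer.get_tokens_from_offsets(text, encoding.ids, encoding.offsets)
    #   # Special tokens (e.g. CLS/SEP) are mapped back to their canonical strings.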

    @classmethod
    def load(
        cls,
        hub_or_local_path,
        subfolder=None,
        config_filename=None,
        tokenizer_filename=None,
        cache_dir=None,
        **kwargs,
    ) -> "Tokenizer":
        """
        Load a tokenizer from a specified path or Hub repository.

        Args:
            hub_or_local_path: Path or Hub repository ID.
            subfolder: Subfolder containing tokenizer files.
            config_filename: Tokenizer config filename.
            tokenizer_filename: Tokenizer filename.
            cache_dir: Path to the cache directory.
            **kwargs: Additional arguments.

        Returns:
            Tokenizer: Loaded tokenizer.
        """
        tokenizer_filename = tokenizer_filename or cls.tokenizer_filename
        config_filename = config_filename or cls.tokenizer_config_filename
        subfolder = subfolder or cls.preprocessor_subfolder
        cache_dir = cache_dir or HEZAR_CACHE_DIR

        config = TokenizerConfig.load(
            hub_or_local_path,
            filename=config_filename,
            subfolder=subfolder,
            cache_dir=cache_dir,
        )

        if os.path.isdir(hub_or_local_path):
            tokenizer_path = os.path.join(hub_or_local_path, subfolder, tokenizer_filename)
        else:
            tokenizer_path = hf_hub_download(
                hub_or_local_path,
                filename=tokenizer_filename,
                subfolder=subfolder,
                cache_dir=cache_dir,
            )

        tokenizer = build_preprocessor(config.name, config, tokenizer_file=tokenizer_path, **kwargs)

        return tokenizer
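
    # Loading sketch (illustrative only; the repo id and local path are hypothetical):
    #
    #   tokenizer = Tokenizer.load("username/some-tokenizer-repo")   # from the Hub
    #   tokenizer = Tokenizer.load("/path/to/saved/model_dir")       # from a local save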

    def save(self, path, save_config=True, pretty=True):
        """
        Save the tokenizer and its configuration.

        Args:
            path (str): Path to save the tokenizer.
            save_config (bool): Whether to save the configuration.
            pretty (bool): Whether to format the saved JSON file with indentation.
        """
        os.makedirs(path, exist_ok=True)
        # Save the config
        if save_config:
            self.config.vocab_size = self.get_vocab_size(with_added_tokens=True)
            self.config.save(path, filename=self.tokenizer_config_filename, subfolder=self.preprocessor_subfolder)
        # Save tokenizer.json
        save_path = os.path.join(path, self.preprocessor_subfolder, self.tokenizer_filename)
        self._tokenizer.save(save_path, pretty=pretty)

    def push_to_hub(
        self,
        repo_id,
        commit_message=None,
        subfolder=None,
        tokenizer_filename=None,
        config_filename=None,
        private=False,
    ):
        """
        Push the tokenizer and its config to the Hub.

        Args:
            repo_id: The path (id or repo name) on the Hub
            commit_message: Commit message for this push
            subfolder: Subfolder to save the files in
            tokenizer_filename: Tokenizer filename
            config_filename: Tokenizer config filename
            private: Whether the repo should be private (ignored if the repo already exists)
        """
        subfolder = subfolder or self.preprocessor_subfolder
        tokenizer_filename = tokenizer_filename or self.tokenizer_filename
        config_filename = config_filename or self.tokenizer_config_filename

        # Create the remote repo
        create_repo(repo_id, exist_ok=True, private=private)

        # Save to a temporary path and prepare for the push
        cache_path = tempfile.mkdtemp()

        # Save tokenizer.json
        tokenizer_save_path = os.path.join(cache_path, subfolder, tokenizer_filename)
        self.save(cache_path, pretty=True)

        if commit_message is None:
            commit_message = "Hezar: Upload tokenizer and config"

        # Upload the config
        self.config.push_to_hub(
            repo_id=repo_id,
            filename=config_filename,
            subfolder=subfolder,
            commit_message=commit_message,
        )
        # Upload the tokenizer
        upload_file(
            repo_id=repo_id,
            path_or_fileobj=tokenizer_save_path,
            repo_type="model",
            path_in_repo=f"{subfolder}/{tokenizer_filename}",
            commit_message=commit_message,
        )
        logger.log_upload_success(
            name=f"{self.__class__.__name__}(name={self.config.name})",
            target_path=os.path.join(repo_id, subfolder, tokenizer_filename),
        )
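
    # Saving / pushing sketch (illustrative only; the output path and repo id are hypothetical):
    #
    #   tokenizer.save("/path/to/output_dir")                    # writes the config + tokenizer.json
    #   tokenizer.push_to_hub("username/some-tokenizer-repo")    # uploads both to the Hub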

    @property
    def model(self) -> "Model":
        return self._tokenizer.model

    @model.setter
    def model(self, model: "Model"):
        self._tokenizer.model = model  # noqa

    @property
    def decoder(self) -> "Decoder":
        return self._tokenizer.decoder

    @decoder.setter
    def decoder(self, decoder: "Decoder"):
        self._tokenizer.decoder = decoder  # noqa

    @property
    def padding(self):
        return self._tokenizer.padding

    @property
    def truncation(self) -> dict:
        return self._tokenizer.truncation

    @property
    def vocab(self):
        return self._tokenizer.get_vocab(with_added_tokens=True)

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=False)

    @property
    def special_ids(self):
        return [self.token_to_id(t) for t in self.special_tokens]

    @property
    def pad_token(self):
        return self.config.pad_token

    @property
    def bos_token(self):
        return self.config.bos_token

    @property
    def eos_token(self):
        return self.config.eos_token

    @property
    def unk_token(self):
        return self.config.unk_token

    @property
    def mask_token(self):
        return self.config.mask_token

    @property
    def cls_token(self):
        return self.config.cls_token

    @property
    def sep_token(self):
        return self.config.sep_token

    @property
    def pad_token_id(self):
        return self.token_to_id(self.config.pad_token)

    @property
    def bos_token_id(self):
        return self.token_to_id(self.config.bos_token)

    @property
    def eos_token_id(self):
        return self.token_to_id(self.config.eos_token)

    @property
    def unk_token_id(self):
        return self.token_to_id(self.config.unk_token)

    @property
    def mask_token_id(self):
        return self.token_to_id(self.config.mask_token)

    @property
    def cls_token_id(self):
        return self.token_to_id(self.config.cls_token)

    @property
    def sep_token_id(self):
        return self.token_to_id(self.config.sep_token)
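

if __name__ == "__main__":
    # Minimal smoke-test sketch (the repo id is hypothetical; requires network access
    # and the `tokenizers` backend). `Tokenizer.load` resolves the concrete registered
    # tokenizer class from the config stored in the repo.
    tok = Tokenizer.load("username/some-tokenizer-repo")  # hypothetical repo id
    out = tok(["Hello world!"], return_tensors="list")
    print(out["token_ids"])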