from __future__ import annotations
import os
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Tuple
import numpy as np
import torch
from huggingface_hub import create_repo, hf_hub_download, upload_file
from ...builders import build_preprocessor
from ...configs import PreprocessorConfig
from ...constants import (
DEFAULT_TOKENIZER_CONFIG_FILE,
DEFAULT_TOKENIZER_FILE,
HEZAR_CACHE_DIR,
Backends,
PaddingType,
)
from ...utils import (
Logger,
convert_batch_dict_dtype,
is_backend_available,
pad_batch_items,
)
from ..preprocessor import Preprocessor
if is_backend_available(Backends.TOKENIZERS):
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.decoders import Decoder
from tokenizers.models import Model
logger = Logger(__name__)
@dataclass
class TokenizerConfig(PreprocessorConfig):
"""
Configuration for the Tokenizer.
Args:
truncation_side (str): Truncation direction for tokenization.
stride (int): Stride for tokenization.
padding_side (str): Padding direction for tokenization.
        pad_to_multiple_of (int): Pad to a multiple of this value (deprecated in the config; pass it to __call__ instead).
pad_token_type_id (int): ID of the padding token type.
bos_token (str): Beginning of sequence token.
eos_token (str): End of sequence token.
unk_token (str): Unknown token.
sep_token (str): Separator token.
pad_token (str): Padding token.
cls_token (str): Classification token.
mask_token (str): Mask token.
additional_special_tokens (List[str]): Additional special tokens.
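
    Example:
        An illustrative sketch (the token values below are assumptions, not library defaults)::

            config = TokenizerConfig(
                padding_side="right",
                pad_token="[PAD]",
                unk_token="[UNK]",
            )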
"""
name = "tokenizer"
    # Fields defaulting to the literal string "deprecated" are sentinels; prefer the
    # call-time arguments of `Tokenizer.__call__` over setting these in the config.
    max_length: int = "deprecated"
    truncation: str = "deprecated"
    truncation_side: str = None
    padding: str = "deprecated"
    padding_side: str = None
    stride: int = None
    pad_to_multiple_of: int = "deprecated"
pad_token_type_id: int = 0
bos_token: str = None
eos_token: str = None
unk_token: str = None
sep_token: str = None
pad_token: str = None
cls_token: str = None
mask_token: str = None
additional_special_tokens: List[str] = None
def __post_init__(self):
super().__post_init__()
if self.max_length != "deprecated":
logger.warning(
"Setting `max_length` in the tokenizer config is deprecated and will be removed in the future!"
)
if self.padding != "deprecated":
logger.warning(
"Setting `padding` in the tokenizer config is deprecated and will be removed in the future!"
)
if self.truncation != "deprecated":
logger.warning(
"Setting `truncation` in the tokenizer config is deprecated and will be removed in the future!"
)
class Tokenizer(Preprocessor):
"""
Base tokenizer class. Mostly copied from :class:`~tokenizers.implementations.BaseTokenizer`.
Args:
config: A TokenizerConfig instance.
tokenizer_file (str): A tokenizer.json file to load the whole tokenizer from.
**kwargs: Extra config parameters that merge into the main config.
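
    Example:
        A minimal sketch of typical usage (the path placeholder is hypothetical)::

            tokenizer = Tokenizer.load("<hub-or-local-path>")
            outputs = tokenizer(["Hello world!"], return_tensors="torch")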
"""
    required_backends: List[str | Backends] = [Backends.TOKENIZERS]
tokenizer_filename = DEFAULT_TOKENIZER_FILE
tokenizer_config_filename = DEFAULT_TOKENIZER_CONFIG_FILE
token_ids_name = "token_ids"
uncastable_keys = ["word_ids", "tokens", "offsets_mapping"]
def __init__(self, config: TokenizerConfig, tokenizer_file=None, **kwargs):
super().__init__(config, **kwargs)
self._tokenizer = self.from_file(tokenizer_file) if tokenizer_file is not None else self.build()
self.special_tokens = self._get_all_special_tokens()
def _get_all_special_tokens(self):
"""
Get a list of all special tokens.
Returns:
List[str]: List of special tokens.
"""
_special_tokens = [
self.config.bos_token,
self.config.eos_token,
self.config.unk_token,
self.config.sep_token,
self.config.pad_token,
self.config.cls_token,
self.config.mask_token,
]
_special_tokens = [token for token in _special_tokens if token in self.vocab]
if self.config.additional_special_tokens is not None:
for token in self.config.additional_special_tokens:
if token not in _special_tokens:
_special_tokens.append(token)
valid_tokens = [token for token in _special_tokens if token is not None]
return valid_tokens
@staticmethod
def from_file(path):
"""
Create a tokenizer from a file.
Args:
path (str): Path to the tokenizer file.
Returns:
HFTokenizer: The created tokenizer.
"""
tokenizer = HFTokenizer.from_file(path)
return tokenizer
def build(self):
"""
Build the tokenizer.
Returns:
HFTokenizer: The built tokenizer.
"""
raise NotImplementedError
def encode(self, inputs, is_pretokenized: bool = False, add_special_tokens: bool = True, **kwargs):
"""
Tokenize a list of inputs (could be raw or tokenized inputs).
Args:
inputs: List of inputs.
is_pretokenized: Whether the inputs are already tokenized.
add_special_tokens: Whether to add special tokens to the inputs. Defaults to True.
**kwargs: Additional keyword arguments.
        Returns:
            List[Encoding]: A list of `tokenizers.Encoding` objects.
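
        Example:
            A minimal sketch, assuming `tokenizer` is a loaded `Tokenizer` subclass instance::

                encodings = tokenizer.encode(["hello world"], add_special_tokens=True)
                first_ids = encodings[0].ids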
"""
if isinstance(inputs, str):
inputs = [inputs]
elif isinstance(inputs, list) and is_pretokenized:
if isinstance(inputs[0], str):
inputs = [inputs]
return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens)
def decode(self, ids: List[int], skip_special_tokens: bool = True, **kwargs):
"""
Decode a list of token IDs.
Args:
ids (List[int]): List of token IDs.
skip_special_tokens (bool): Whether to skip special tokens during decoding.
**kwargs: Additional keyword arguments.
Returns:
List[str]: List of decoded strings.
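
        Example:
            A minimal sketch, assuming `tokenizer` is a loaded instance (the ids are made up)::

                texts = tokenizer.decode([[2, 15, 9, 3]], skip_special_tokens=True)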
"""
        # Convert tensors/arrays to plain lists before inspecting elements
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()
        # Wrap a single sequence into a batch of one
        if isinstance(ids[0], int):
            ids = [ids]
        return self._tokenizer.decode_batch(ids, skip_special_tokens=skip_special_tokens)
def pad_encoded_batch(
self,
inputs,
padding: str | PaddingType = None,
max_length: Optional[int] = None,
truncation: bool = True,
return_tensors: Optional[str] = None,
include_keys: Optional[List[str]] = None,
exclude_keys: List = None,
):
"""
Pad a batch of encoded inputs.
Args:
inputs: Input batch of encoded tokens.
padding (str | PaddingType): Padding type.
max_length (Optional[int]): Max input length (only if padding is set to "max_length").
truncation (bool): Whether to allow truncation.
return_tensors (Optional[str]): The type of tensors to return.
            include_keys (Optional[List[str]]): If provided, only pad these keys.
exclude_keys (List): A list of keys to exclude when padding.
Returns:
Dict: Padded inputs.
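
        Example:
            A minimal sketch, assuming `tokenizer` is a loaded instance (the ids are made up)::

                batch = {
                    "token_ids": [[2, 15, 3], [2, 8, 61, 24, 3]],
                    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
                }
                padded = tokenizer.pad_encoded_batch(batch, padding="longest", return_tensors="torch")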
"""
if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], Mapping):
inputs = {key: [example[key] for example in inputs] for key in inputs[0].keys()}
exclude_keys = exclude_keys or []
exclude_keys += self.uncastable_keys # avoid possible errors
inputs = convert_batch_dict_dtype(inputs, dtype="list", skip_keys=exclude_keys)
include_keys = include_keys or list(inputs.keys())
        for key, batch in inputs.items():
            if key in exclude_keys:
                continue
            if key in include_keys:
                # Attention masks are padded with 0; everything else with the pad token id
                pad_id = 0 if key == "attention_mask" else self.pad_token_id
                padded_ids = pad_batch_items(
                    batch,
                    padding=padding,
                    padding_side=self.config.padding_side,
                    pad_id=pad_id,
                    max_length=max_length,
                    truncation=truncation,
                )
                inputs[key] = padded_ids
inputs = convert_batch_dict_dtype(inputs, dtype=return_tensors, skip_keys=exclude_keys)
return inputs
def __call__(
self,
inputs: List[str] | List[Tuple[str, str]],
device: str | torch.device = None,
add_special_tokens: bool = True,
padding=None,
truncation=None,
max_length: int = None,
return_tensors: str = "list",
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: int = None,
return_tokens: bool = None,
return_token_type_ids: bool = None,
return_attention_mask: bool = True,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
return_word_ids: bool = False,
verbose: bool = True,
**kwargs,
):
"""
        Tokenize a batch of string inputs and return the relevant properties, e.g., token ids, attention mask, etc.
        Args:
            inputs: A list of string inputs to tokenize
            device: Device to move the output tensors to (only applied when `return_tensors` is "torch")
            add_special_tokens: Whether to add special tokens or not
            padding: Determines how to pad inputs
            truncation: Determines how to truncate inputs
            max_length: Max input length of the sequences
            return_tensors: The type of the tensors in the returned batch, e.g., torch, numpy, list
            stride: Stride level
            is_split_into_words: Are inputs pre-tokenized or raw string inputs
            pad_to_multiple_of: Pad inputs by a factor of this value
            return_tokens: Whether to return token lists
            return_token_type_ids: Whether to return token type ids
            return_attention_mask: Whether to return attention masks
            return_overflowing_tokens: Whether to return overflowing tokens
            return_special_tokens_mask: Whether to return the special tokens mask
            return_offsets_mapping: Whether to return offsets
            return_length: Whether to return input lengths
            return_word_ids: Whether to return word ids
            verbose: Verbosity flag
            **kwargs: Extra arguments that reside here and are therefore ignored
Returns:
A dictionary of encoded inputs like
{"token_ids": [batch_size x input_len], "attention_mask": [batch_size x input_len], ...}
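
        Example:
            A minimal sketch (the path placeholder is hypothetical)::

                tokenizer = Tokenizer.load("<hub-or-local-path>")
                outputs = tokenizer(
                    ["Hello world!", "A second, slightly longer input"],
                    padding="longest",
                    return_tensors="torch",
                )
                token_ids = outputs["token_ids"]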
"""
if isinstance(inputs, list) and not len(inputs):
raise ValueError("Tokenizer cannot process an empty list!")
if "padding_strategy" in kwargs:
logger.warning(
"`padding_strategy` was deprecated in favor of `padding`!"
" This warning will change to an error in the future!"
)
if "truncation_strategy" in kwargs:
logger.warning(
"`truncation_strategy` was deprecated in favor of `truncation`!"
" This warning will change to an error in the future!"
)
return_tensors = return_tensors or "list"
# Convert to batch if input is a single string or a list of words (is split into words for sequence labeling)
if isinstance(inputs, str) or (is_split_into_words and not isinstance(inputs[0], list)):
inputs = [inputs]
is_batch = False
else:
is_batch = True
self.set_truncation_and_padding(
padding=padding,
truncation=truncation,
padding_side=self.config.padding_side,
truncation_side=self.config.truncation_side,
max_length=max_length,
stride=self.config.stride,
pad_to_multiple_of=pad_to_multiple_of
)
encodings = self.encode(
inputs,
add_special_tokens=add_special_tokens,
is_pretokenized=is_split_into_words,
)
encodings_dict = [
self._convert_encodings(
encoding=encoding,
return_tokens=return_tokens,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
return_word_ids=return_word_ids,
)
for encoding in encodings
]
# Permute output dict from [batch_0: Dict[key, value], ...] to Dict[key, [batch_0, batch_1, ...], ...]
sanitized_outputs = {}
for key in encodings_dict[0].keys():
stack = [e for item in encodings_dict for e in item[key]]
sanitized_outputs[key] = stack
# If returning overflowing tokens, we need to return a mapping
# from the batch idx to the original sample
if return_overflowing_tokens:
overflow_to_sample_mapping = []
for i, encodings_ in enumerate(encodings_dict):
                overflow_to_sample_mapping += [i] * len(encodings_[self.token_ids_name])
sanitized_outputs["overflow_to_sample_mapping"] = overflow_to_sample_mapping
# Squeeze tensor if the original input is a single string and return_tensors is `list`
if (return_tensors == "list" or return_tensors is None) and not is_batch:
sanitized_outputs = {
key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
for key, value in sanitized_outputs.items()
}
outputs = convert_batch_dict_dtype(sanitized_outputs, dtype=return_tensors, skip_keys=self.uncastable_keys)
if device and return_tensors == "torch":
outputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
return outputs
def set_truncation_and_padding(
self,
padding=None,
truncation=None,
padding_side=None,
truncation_side=None,
max_length: int = None,
stride: int = None,
pad_to_multiple_of: int = None,
):
        """
        Configure the backend tokenizer's truncation and padding settings, skipping re-configuration when the
        requested settings are already in effect.
        """
if truncation == "no_truncation" or truncation is None or max_length is None:
if self.truncation is not None:
self.no_truncation()
else:
if truncation is True:
truncation = "longest_first"
target = {
"max_length": max_length,
"stride": stride,
"strategy": truncation,
"direction": truncation_side,
}
if self.truncation is None:
current = None
else:
current = {k: self.truncation.get(k, None) for k in target}
if current != target:
self.enable_truncation(**target)
if padding == "no_padding" or padding is None:
if self.padding is not None:
self.no_padding()
else:
target = {
"length": max_length if padding == PaddingType.MAX_LENGTH else None,
"direction": padding_side,
"pad_id": self.token_to_id(self.pad_token),
"pad_token": self.pad_token,
"pad_type_id": self.config.pad_token_type_id,
"pad_to_multiple_of": pad_to_multiple_of,
}
if self.padding != target:
self.enable_padding(**target)
def _convert_encodings(
self,
encoding,
return_tokens: bool = None,
return_token_type_ids: bool = None,
return_attention_mask: bool = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
return_word_ids: bool = False,
):
if return_overflowing_tokens and encoding.overflowing is not None:
encodings = [encoding] + encoding.overflowing
else:
encodings = [encoding]
encoding_dict = defaultdict(list)
for e in encodings:
encoding_dict[self.token_ids_name].append(e.ids)
if return_token_type_ids:
encoding_dict["token_type_ids"].append(e.type_ids)
if return_attention_mask:
encoding_dict["attention_mask"].append(e.attention_mask)
if return_special_tokens_mask:
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
if return_offsets_mapping:
encoding_dict["offsets_mapping"].append(e.offsets)
if return_length:
encoding_dict["length"].append(len(e.ids))
if return_tokens:
text = self._tokenizer.decode(e.ids)
tokens = self.get_tokens_from_offsets(text, e.ids, e.offsets)
encoding_dict["tokens"].append(tokens)
if return_word_ids:
encoding_dict["word_ids"].append(e.word_ids)
return encoding_dict
    def convert_tokens_to_ids(self, tokens: str | List[str]) -> List[int]:
if isinstance(tokens, str):
tokens = [tokens]
return [self._tokenizer.token_to_id(token) for token in tokens]
def convert_ids_to_tokens(self, ids: int | List[int], skip_special_tokens: bool = False):
if isinstance(ids, int):
ids = [ids]
tokens = []
for index in ids:
index = int(index)
if skip_special_tokens and index in self.special_ids:
continue
tokens.append(self._tokenizer.id_to_token(index))
return tokens
def num_special_tokens_to_add(self, is_pair: bool) -> int:
return self._tokenizer.num_special_tokens_to_add(is_pair)
def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]:
return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)
def enable_padding(
self,
direction: str = "right",
pad_to_multiple_of: int = None,
pad_id: int = 0,
pad_type_id: int = 0,
pad_token: str = None,
length: int = None,
):
return self._tokenizer.enable_padding(
direction=direction,
pad_to_multiple_of=pad_to_multiple_of,
pad_id=pad_id,
pad_type_id=pad_type_id,
pad_token=pad_token,
length=length,
)
def no_padding(self):
return self._tokenizer.no_padding()
def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy, direction=direction)
def no_truncation(self):
return self._tokenizer.no_truncation()
def add_tokens(self, tokens) -> int:
return self._tokenizer.add_tokens(tokens)
def add_special_tokens(self, special_tokens) -> int:
return self._tokenizer.add_special_tokens(special_tokens)
def token_to_id(self, token: str) -> int:
return self._tokenizer.token_to_id(token)
def id_to_token(self, id: int) -> str:
return self._tokenizer.id_to_token(id)
def get_added_vocab(self) -> Dict[str, int]:
"""
Returns the added tokens in the vocabulary as a dictionary of token to index.
Returns:
`Dict[str, int]`: The added tokens.
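
        Example:
            A minimal sketch, assuming `tokenizer` is a loaded instance (the resulting id is illustrative)::

                tokenizer.add_tokens(["<new_token>"])
                added = tokenizer.get_added_vocab()  # e.g. {"<new_token>": 30000}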
"""
base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
added_vocab = {token: index for token, index in full_vocab.items() if token not in base_vocab}
return added_vocab
def __len__(self) -> int:
"""
Size of the full vocabulary with the added tokens.
"""
return self._tokenizer.get_vocab_size(with_added_tokens=True)
def get_tokens_from_offsets(
self,
text: str | List[str],
ids: List[int],
offsets_mapping: List[Tuple[int, int]],
):
"""
Extract human-readable tokens using the original text and offsets mapping
Args:
text: Raw string text
ids: Token ids
offsets_mapping: A list of tuples representing offsets
Returns:
A list of tokens
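
        Example:
            A minimal sketch, assuming `tokenizer` is a loaded instance::

                encoding = tokenizer.encode(["hello world"])[0]
                tokens = tokenizer.get_tokens_from_offsets("hello world", encoding.ids, encoding.offsets)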
"""
if not isinstance(text, str):
raise ValueError(f"Expected str type for `text`, got `{type(text)}({text})`")
        if isinstance(offsets_mapping, list) and not isinstance(offsets_mapping[0], tuple):
            raise ValueError(f"Expected a list of tuples for `offsets_mapping`, got List[{type(offsets_mapping[0])}]")
tokens = []
for offset in offsets_mapping:
offset_start, offset_end = offset
tokens.append(text[offset_start:offset_end])
for i, token in enumerate(tokens):
if ids[i] in self.special_ids:
tokens[i] = self._tokenizer.id_to_token(ids[i])
return tokens
@classmethod
def load(
cls,
hub_or_local_path,
subfolder=None,
config_filename=None,
tokenizer_filename=None,
cache_dir=None,
**kwargs,
) -> "Tokenizer":
"""
Load a tokenizer from a specified path or Hub repository.
        Args:
            hub_or_local_path: Path or Hub repository ID.
            subfolder: Subfolder containing tokenizer files.
            config_filename: Tokenizer config filename.
            tokenizer_filename: Tokenizer filename.
            cache_dir: Path to the cache directory.
            **kwargs: Additional arguments.
Returns:
Tokenizer: Loaded tokenizer.
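
        Example:
            A minimal sketch (the repo id placeholder is hypothetical)::

                tokenizer = Tokenizer.load("<user>/<repo>")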
"""
tokenizer_filename = tokenizer_filename or cls.tokenizer_filename
config_filename = config_filename or cls.tokenizer_config_filename
subfolder = subfolder or cls.preprocessor_subfolder
cache_dir = cache_dir or HEZAR_CACHE_DIR
config = TokenizerConfig.load(
hub_or_local_path,
filename=config_filename,
subfolder=subfolder,
cache_dir=cache_dir,
)
if os.path.isdir(hub_or_local_path):
tokenizer_path = os.path.join(hub_or_local_path, subfolder, tokenizer_filename)
else:
tokenizer_path = hf_hub_download(
hub_or_local_path,
filename=tokenizer_filename,
subfolder=subfolder,
cache_dir=cache_dir,
)
tokenizer = build_preprocessor(config.name, config, tokenizer_file=tokenizer_path, **kwargs)
return tokenizer
def save(self, path, save_config=True, pretty=True):
"""
Save the tokenizer and its configuration.
Args:
path (str): Path to save the tokenizer.
save_config (bool): Whether to save the configuration.
pretty (bool): Whether to format the saved JSON file with indentation.
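
        Example:
            A minimal sketch; the tokenizer file is written under the preprocessor subfolder of `path`::

                tokenizer.save("path/to/save_dir")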
"""
os.makedirs(path, exist_ok=True)
# save config
if save_config:
self.config.vocab_size = self.get_vocab_size(with_added_tokens=True)
self.config.save(path, filename=self.tokenizer_config_filename, subfolder=self.preprocessor_subfolder)
# save tokenizer.json
save_path = os.path.join(path, self.preprocessor_subfolder, self.tokenizer_filename)
self._tokenizer.save(save_path, pretty=pretty)
def push_to_hub(
self,
repo_id,
commit_message=None,
subfolder=None,
tokenizer_filename=None,
config_filename=None,
private=False,
):
"""
Push tokenizer and config to the Hub
Args:
            repo_id: The path (id or repo name) on the Hub
            commit_message: Commit message for this push
            subfolder: Subfolder to save the files in
            tokenizer_filename: Tokenizer filename
            config_filename: Tokenizer config filename
            private: Whether the repo should be private (ignored if the repo already exists)
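
        Example:
            A minimal sketch (the repo id placeholder is hypothetical)::

                tokenizer.push_to_hub("<user>/<repo>", commit_message="Upload tokenizer")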
"""
subfolder = subfolder or self.preprocessor_subfolder
tokenizer_filename = tokenizer_filename or self.tokenizer_filename
config_filename = config_filename or self.tokenizer_config_filename
# create remote repo
create_repo(repo_id, exist_ok=True, private=private)
# save to tmp and prepare for push
cache_path = tempfile.mkdtemp()
# save tokenizer.json
tokenizer_save_path = os.path.join(cache_path, subfolder, tokenizer_filename)
self.save(cache_path, pretty=True)
if commit_message is None:
commit_message = "Hezar: Upload tokenizer and config"
# upload config
self.config.push_to_hub(
repo_id=repo_id,
filename=config_filename,
subfolder=subfolder,
commit_message=commit_message,
)
# upload tokenizer
upload_file(
repo_id=repo_id,
path_or_fileobj=tokenizer_save_path,
repo_type="model",
path_in_repo=f"{subfolder}/{tokenizer_filename}",
commit_message=commit_message,
)
logger.log_upload_success(
name=f"{self.__class__.__name__}(name={self.config.name})",
target_path=os.path.join(repo_id, subfolder, tokenizer_filename),
)
@property
def model(self) -> "Model":
return self._tokenizer.model
@model.setter
def model(self, model: "Model"):
self._tokenizer.model = model # noqa
@property
def decoder(self) -> "Decoder":
return self._tokenizer.decoder
@decoder.setter
def decoder(self, decoder: "Decoder"):
self._tokenizer.decoder = decoder # noqa
@property
def padding(self):
return self._tokenizer.padding
@property
def truncation(self) -> dict:
return self._tokenizer.truncation
@property
def vocab(self):
return self._tokenizer.get_vocab(with_added_tokens=True)
@property
def vocab_size(self) -> int:
"""
`int`: Size of the base vocabulary (without the added tokens).
"""
return self._tokenizer.get_vocab_size(with_added_tokens=False)
@property
def special_ids(self):
return [self.token_to_id(t) for t in self.special_tokens]
@property
def pad_token(self):
return self.config.pad_token
@property
def bos_token(self):
return self.config.bos_token
@property
def eos_token(self):
return self.config.eos_token
@property
def unk_token(self):
return self.config.unk_token
@property
def mask_token(self):
return self.config.mask_token
@property
def cls_token(self):
return self.config.cls_token
@property
def sep_token(self):
return self.config.sep_token
@property
def pad_token_id(self):
return self.token_to_id(self.config.pad_token)
@property
def bos_token_id(self):
return self.token_to_id(self.config.bos_token)
@property
def eos_token_id(self):
return self.token_to_id(self.config.eos_token)
@property
def unk_token_id(self):
return self.token_to_id(self.config.unk_token)
@property
def mask_token_id(self):
return self.token_to_id(self.config.mask_token)
@property
def cls_token_id(self):
return self.token_to_id(self.config.cls_token)
@property
def sep_token_id(self):
return self.token_to_id(self.config.sep_token)