Source code for hezar.preprocessors.tokenizers.wordpiece
from dataclasses import dataclass, field
from ...constants import DEFAULT_TOKENIZER_CONFIG_FILE, DEFAULT_TOKENIZER_FILE, Backends
from ...registry import register_preprocessor
from ...utils import is_backend_available
from .tokenizer import Tokenizer, TokenizerConfig

if is_backend_available(Backends.TOKENIZERS):
    from tokenizers import Tokenizer as HFTokenizer
    from tokenizers import decoders, models, trainers

_required_backends = [
    Backends.TOKENIZERS,
]


@dataclass
class WordPieceConfig(TokenizerConfig):
    name = "wordpiece_tokenizer"
    truncation_side: str = "right"
    stride: int = 0
    padding_side: str = "right"
    pad_to_multiple_of: int = 0
    pad_token: str = "[PAD]"
    unk_token: str = "[UNK]"
    sep_token: str = "[SEP]"
    cls_token: str = "[CLS]"
    mask_token: str = "[MASK]"
    pad_token_type_id: int = 0
    additional_special_tokens: list[str] | None = None
    wordpieces_prefix: str = "##"
    vocab_size: int = 30000
    min_frequency: int = 2
    limit_alphabet: int = 1000
    initial_alphabet: list = field(default_factory=list)
    show_progress: bool = True
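
# A minimal usage sketch (illustrative, not part of the module): every field above
# can be overridden at construction time, and updating from a dict is assumed to
# work the same way `train()` uses `config.update(...)` below.
#
#   config = WordPieceConfig(vocab_size=10000, min_frequency=5)
#   config.update({"limit_alphabet": 500})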


@register_preprocessor("wordpiece_tokenizer", config_class=WordPieceConfig)
class WordPieceTokenizer(Tokenizer):
    """
    A standard WordPiece tokenizer using 🤗 HuggingFace Tokenizers

    Args:
        config: Preprocessor config for the tokenizer
        **kwargs: Extra/manual config parameters
    """

    required_backends = _required_backends

    tokenizer_filename = DEFAULT_TOKENIZER_FILE
    tokenizer_config_filename = DEFAULT_TOKENIZER_CONFIG_FILE
    token_ids_name = "token_ids"

    def __init__(self, config, tokenizer_file=None, **kwargs):
        super().__init__(config, tokenizer_file=tokenizer_file, **kwargs)

    def build(self):
        tokenizer = HFTokenizer(models.WordPiece(unk_token=self.config.unk_token))  # noqa
        tokenizer.decoder = decoders.WordPiece(self.config.wordpieces_prefix)  # noqa
        return tokenizer
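
    # In plain 🤗 tokenizers terms, `build()` assembles the following (a sketch,
    # assuming the `tokenizers` backend is installed; values mirror the config
    # defaults):
    #
    #   tk = HFTokenizer(models.WordPiece(unk_token="[UNK]"))
    #   tk.decoder = decoders.WordPiece(prefix="##")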

    def train(self, files: str | list[str], **train_kwargs):
        """Train the model using the given files (a single path or a list of paths)"""
        self.config.update(train_kwargs)
        trainer = trainers.WordPieceTrainer(
            vocab_size=self.config.vocab_size,
            min_frequency=self.config.min_frequency,
            limit_alphabet=self.config.limit_alphabet,
            initial_alphabet=self.config.initial_alphabet,
            special_tokens=self.config.special_tokens,
            show_progress=self.config.show_progress,
            continuing_subword_prefix=self.config.wordpieces_prefix,
        )
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(files, trainer=trainer)
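
    # Example (a hedged sketch; the file path is hypothetical):
    #
    #   tokenizer = WordPieceTokenizer(WordPieceConfig())
    #   tokenizer.train("data/corpus.txt", vocab_size=10000)  # str or list of paths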

    def train_from_iterator(self, dataset: list[str], **train_kwargs):
        """Train the model using the given dataset (an in-memory list of texts)"""
        self.config.update(train_kwargs)
        trainer = trainers.WordPieceTrainer(
            vocab_size=self.config.vocab_size,
            min_frequency=self.config.min_frequency,
            limit_alphabet=self.config.limit_alphabet,
            initial_alphabet=self.config.initial_alphabet,
            special_tokens=self.config.special_tokens,
            show_progress=self.config.show_progress,
            continuing_subword_prefix=self.config.wordpieces_prefix,
        )
        self._tokenizer.train_from_iterator(dataset, trainer=trainer, length=len(dataset))
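

if __name__ == "__main__":
    # A hedged end-to-end sketch, not part of the library: train on an in-memory
    # corpus, then encode. Assumes the base `Tokenizer` builds a fresh backend
    # tokenizer via `build()` when no `tokenizer_file` is given, and that calling
    # the instance returns a mapping keyed by `token_ids` (per `token_ids_name`).
    corpus = [
        "hezar is a machine learning library",
        "wordpiece splits rare words into subword units",
    ]
    config = WordPieceConfig(vocab_size=1000, show_progress=False)
    tokenizer = WordPieceTokenizer(config)
    tokenizer.train_from_iterator(corpus)
    print(tokenizer(["wordpiece splits words"]))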