Source code for hezar.constants

"""
Home to all constant variables in Hezar
"""

import os
from enum import Enum


HEZAR_HUB_ID = "hezarai"
HEZAR_CACHE_DIR = os.getenv("HEZAR_CACHE_DIR", f'{os.path.expanduser("~")}/.cache/hezar')

DEFAULT_MODEL_FILE = "model.pt"
DEFAULT_MODEL_CONFIG_FILE = "model_config.yaml"
DEFAULT_TRAINER_SUBFOLDER = "train"
DEFAULT_TRAINER_CONFIG_FILE = "train_config.yaml"
DEFAULT_TRAINER_CSV_LOG_FILE = "training_logs.csv"
DEFAULT_TRAINER_STATE_FILE = "trainer_state.yaml"
DEFAULT_OPTIMIZER_FILE = "optimizer.pt"
DEFAULT_LR_SCHEDULER_FILE = "lr_scheduler.pt"
DEFAULT_PREPROCESSOR_SUBFOLDER = "preprocessor"
DEFAULT_NORMALIZER_CONFIG_FILE = "normalizer_config.yaml"
DEFAULT_IMAGE_PROCESSOR_CONFIG_FILE = "image_processor_config.yaml"
DEFAULT_FEATURE_EXTRACTOR_CONFIG_FILE = "feature_extractor_config.yaml"
DEFAULT_TOKENIZER_FILE = "tokenizer.json"
DEFAULT_TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml"
DEFAULT_DATASET_CONFIG_FILE = "dataset_config.yaml"
DEFAULT_EMBEDDING_FILE = "embedding.bin"
DEFAULT_EMBEDDING_CONFIG_FILE = "embedding_config.yaml"
DEFAULT_EMBEDDING_SUBFOLDER = "embedding"

TQDM_BAR_FORMAT = "{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}"


[docs] class ExplicitEnum(str, Enum): def __str__(self): return self.value
[docs] @classmethod def list(cls): return [x.value for x in cls.__members__.values()]
[docs] class Backends(ExplicitEnum): """ All required dependency packages and libraries. Note that the values here must be the exact module names used for importing, for example if you set PILLOW the value must be `PIL` not `pillow`, `pil`, etc. """ PYTORCH = "torch" TORCHVISION = "torchvision" TRANSFORMERS = "transformers" DATASETS = "datasets" TOKENIZERS = "tokenizers" ACCELERATE = "accelerate" TENSORBOARD = "tensorboard" SOUNDFILE = "soundfile" LIBROSA = "librosa" WANDB = "wandb" GENSIM = "gensim" PILLOW = "PIL" OPENCV = "cv2" JIWER = "jiwer" NLTK = "nltk" SCIKIT = "sklearn" SEQEVAL = "seqeval" ROUGE = "rouge_score"
[docs] class TaskType(ExplicitEnum): AUDIO_CLASSIFICATION = "audio_classification" BACKBONE = "backbone" IMAGE2TEXT = "image2text" LANGUAGE_MODELING = "language_modeling" MASK_FILLING = "mask_filling" SEQUENCE_LABELING = "sequence_labeling" SPEECH_RECOGNITION = "speech_recognition" TEXT_CLASSIFICATION = "text_classification" TEXT_DETECTION = "text_detection" TEXT_GENERATION = "text_generation"
[docs] class ConfigType(ExplicitEnum): BASE = "base" MODEL = "model" DATASET = "dataset" PREPROCESSOR = "preprocessor" EMBEDDING = "embedding" TRAINER = "trainer" OPTIMIZER = "optimizer" CRITERION = "criterion" LR_SCHEDULER = "lr_scheduler" METRIC = "metric"
[docs] class RegistryType(ExplicitEnum): MODEL = "model" DATASET = "dataset" PREPROCESSOR = "preprocessor" EMBEDDING = "embedding" TRAINER = "trainer" OPTIMIZER = "optimizer" CRITERION = "criterion" LR_SCHEDULER = "lr_scheduler" METRIC = "metric"
[docs] class LossType(ExplicitEnum): L1 = "l1" NLL = "nll" NLL_2D = "nll_2d" POISSON_NLL = "poisson_nll" GAUSSIAN_NLL = "gaussian_nll" MSE = "mse" BCE = "bce" BCE_WITH_LOGITS = "bce_with_logits" CROSS_ENTROPY = "cross_entropy" TRIPLE_MARGIN = "triple_margin" CTC = "ctc"
[docs] class PrecisionType(ExplicitEnum): NO = "no" FP8 = "fp8" FP16 = "fp16" BF16 = "bf16"
[docs] class OptimizerType(ExplicitEnum): ADAM = "adam" ADAMW = "adamw" SDG = "sdg"
[docs] class LRSchedulerType(ExplicitEnum): CONSTANT = "constant" LAMBDA = "lambda" STEP = "step" REDUCE_ON_PLATEAU = "reduce_on_plateau" MULTI_STEP = "multi_step" ONE_CYCLE = "one_cycle" LINEAR = "linear" EXPONENTIAL = "exponential" CYCLIC = "cyclic" SEQUENTIAL = "sequential" POLYNOMIAL = "polynomial" COSINE_ANEALING = "cosine_anealing"
[docs] class SplitType(ExplicitEnum): TRAIN = "train" EVAL = "eval" VALID = "validation" TEST = "test"
[docs] class MetricType(ExplicitEnum): ACCURACY = "accuracy" F1 = "f1" RECALL = "recall" PRECISION = "precision" SEQEVAL = "seqeval" CER = "cer" WER = "wer" BLEU = "bleu" ROUGE = "rouge"
[docs] class RepoType(ExplicitEnum): DATASET = "dataset" MODEL = "model"
[docs] class ImageType(ExplicitEnum): NUMPY = "numpy" PILLOW = "pillow" TORCH = "torch"
[docs] class ChannelsAxisSide(ExplicitEnum): FIRST = "first" LAST = "last"
[docs] class PaddingType(ExplicitEnum): MAX_LENGTH = "max_length" LONGEST = "longest"
[docs] class Color(ExplicitEnum): HEADER = "\033[95m" NORMAL = "\033[0m" BOLD = "\033[1m" UNDERLINE = "\033[4m" ITALIC = "\33[3m" BLUE = "\033[94m" CYAN = "\033[96m" GREEN = "\033[92m" YELLOW = "\033[93m" RED = "\033[91m" GREY = "\33[90m"