Skip to content

Commit 59dc807

Browse files
Validate tokenizer and model alignment before training
1 parent 58fa181 commit 59dc807

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

torchtitan/models/utils.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -468,3 +468,45 @@ def get_moe_model_nparams_and_flops(
468468
nparams = nparams - nparams_embedding
469469

470470
return nparams, num_flops_per_token
471+
472+
473+
def validate_tokenizer_model_alignment(
    tokenizer: "BaseTokenizer | None",
    model_args: "BaseModelArgs",
) -> None:
    """
    Validate that the tokenizer is compatible with the model configuration.

    Each check runs only when ``model_args`` defines the corresponding
    attribute:

    * ``vocab_size``: the model vocabulary must be at least as large as the
      tokenizer vocabulary. A larger (e.g. padded) model vocabulary is
      accepted; a smaller one is rejected.
    * ``eos_id``: if the tokenizer exposes a non-None ``eos_id``, it must
      equal ``model_args.eos_id``.

    Args:
        tokenizer: Tokenizer instance to validate. If None, no validation is
            performed.
        model_args: Model arguments object containing the configuration to
            validate against.

    Raises:
        ValueError: If the model vocabulary is smaller than the tokenizer
            vocabulary, or if the eos ids disagree.
    """
    if tokenizer is None:
        return

    # Validate vocab_size.
    if hasattr(model_args, "vocab_size"):
        tokenizer_vocab_size = tokenizer.get_vocab_size()
        model_vocab_size = model_args.vocab_size
        # Only a model vocabulary that is too small is an error: model
        # vocabularies are commonly padded beyond the tokenizer's size, so
        # requiring strict equality would reject valid configurations.
        if model_vocab_size < tokenizer_vocab_size:
            raise ValueError(
                f"Model vocab_size ({model_vocab_size}) is smaller than "
                f"tokenizer vocab_size ({tokenizer_vocab_size}). "
                "The tokenizer can produce token ids outside the model's "
                "vocabulary, which will cause training errors. "
                "Please ensure the tokenizer and model configuration are aligned."
            )

    # Validate eos_id.
    if hasattr(model_args, "eos_id"):
        # A tokenizer without an eos_id attribute (or with eos_id=None) is
        # treated as having nothing to compare, so the check is skipped.
        tokenizer_eos_id = getattr(tokenizer, "eos_id", None)
        model_eos_id = model_args.eos_id
        if tokenizer_eos_id is not None and tokenizer_eos_id != model_eos_id:
            raise ValueError(
                f"Tokenizer eos_id ({tokenizer_eos_id}) does not match "
                f"model eos_id ({model_eos_id}). "
                "This mismatch may cause training errors. "
                "Please ensure the tokenizer and model configuration are aligned."
            )

torchtitan/train.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
)
2626
from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
2727
from torchtitan.distributed import ParallelDims, utils as dist_utils
28+
from torchtitan.models.utils import validate_tokenizer_model_alignment
2829
from torchtitan.protocols.model_converter import build_model_converters
2930
from torchtitan.tools import utils
3031
from torchtitan.tools.logging import init_logger, logger
@@ -134,6 +135,8 @@ def __init__(self, job_config: JobConfig):
134135
model_args.update_from_config(job_config)
135136
self.model_args = model_args
136137

138+
validate_tokenizer_model_alignment(self.tokenizer, model_args)
139+
137140
logger.info(
138141
f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}"
139142
)

0 commit comments

Comments
 (0)