
Fix MaskGIT initialization error with newer version of the Transformers library#108

Open
AlienKevin wants to merge 1 commit into bytedance:main from AlienKevin:fix_lm_head_post_init

Conversation

@AlienKevin

@yucornetto @TACJu Recent versions of the Transformers library break model initialization in maskgit.py. Transformers' post_init function forcefully resets the output dimension of self.model.lm_head to vocab_size=self.target_codebook_size + self.condition_num_classes + 2, which causes a dimension mismatch with the intended self.target_codebook_size:

[rank0]: Traceback (most recent call last):
[rank0]:   File "1d-tokenizer/sample_imagenet_titok.py", line 137, in <module>
[rank0]:     main()
[rank0]:     ~~~~^^
[rank0]:   File "1d-tokenizer/sample_imagenet_titok.py", line 79, in main
[rank0]:     titok_generator = demo_util.get_titok_generator(config)
[rank0]:   File "1d-tokenizer/demo_util.py", line 63, in get_titok_generator
[rank0]:     generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu"))
[rank0]:     ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "1d-tokenizer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 2624, in load_state_dict
[rank0]:     raise RuntimeError(
[rank0]:     ...<3 lines>...
[rank0]:     )
[rank0]: RuntimeError: Error(s) in loading state_dict for ImageBert:
[rank0]:        size mismatch for model.lm_head.weight: copying a param with shape torch.Size([4096, 768]) from checkpoint, the shape in current model is torch.Size([5098, 768]).
[rank0]:        size mismatch for model.lm_head.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5098]).
[rank0]:[W818 00:31:23.366445854 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

My proposed fix is to initialize lm_head after self.model.post_init(), which preserves the current initialization method for lm_head while ensuring its output dimension stays unchanged.
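A minimal sketch of the reordering idea (class and parameter names are illustrative, not the exact maskgit.py code): post_init() re-initializes the model's weights based on config.vocab_size, so any head it touches ends up sized to vocab_size; constructing lm_head only afterwards keeps it at target_codebook_size, matching the checkpoint.

```python
import torch.nn as nn

class MaskGITSketch(nn.Module):
    """Hypothetical sketch of the initialization-order fix."""

    def __init__(self, target_codebook_size=4096, condition_num_classes=1000,
                 hidden_size=768):
        super().__init__()
        vocab_size = target_codebook_size + condition_num_classes + 2  # 5098
        # ... self.model = <Transformers model built with config.vocab_size=vocab_size> ...
        # self.model.post_init()  # newer Transformers: re-initializes weights from config,
        #                         # resizing any existing head to vocab_size (5098)
        # Fix: create lm_head AFTER post_init(), so post_init never sees it and its
        # output dimension stays at target_codebook_size (4096), as in the checkpoint.
        self.lm_head = nn.Linear(hidden_size, target_codebook_size)

m = MaskGITSketch()
print(m.lm_head.out_features)  # 4096, matching model.lm_head.weight in the checkpoint
```

With this ordering, load_state_dict finds a [4096, 768] lm_head.weight in the current model and the size-mismatch error above goes away.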

My testing environment (pip freeze):
absl-py==2.3.1
accelerate==1.10.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
astunparse==1.6.3
beautifulsoup4==4.13.4
braceexpand==0.1.7
certifi==2025.8.3
charset-normalizer==3.4.3
click==8.2.1
diffusers==0.34.0
einops==0.8.1
filelock==3.19.1
flatbuffers==25.2.10
fsspec==2025.7.0
ftfy==6.3.1
gast==0.6.0
gdown==5.2.0
gitdb==4.0.12
gitpython==3.1.45
google-pasta==0.2.0
grpcio==1.74.0
h5py==3.14.0
hf-xet==1.1.7
huggingface-hub==0.34.4
idna==3.10
importlib-metadata==8.7.0
jinja2==3.1.6
keras==3.11.2
libclang==18.1.1
markdown==3.8.2
markdown-it-py==4.0.0
markupsafe==3.0.2
mdurl==0.1.2
ml-dtypes==0.5.3
mpmath==1.3.0
namex==0.1.0
networkx==3.5
numpy==2.3.2
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvtx-cu12==12.8.90
omegaconf==2.3.0
open-clip-torch==3.1.0
opt-einsum==3.4.0
optree==0.17.0
packaging==25.0
pillow==11.3.0
platformdirs==4.3.8
protobuf==6.32.0
psutil==7.0.0
pydantic==2.11.7
pydantic-core==2.33.2
pygments==2.19.2
pysocks==1.7.1
pyyaml==6.0.2
regex==2025.7.34
requests==2.32.4
rich==14.1.0
safetensors==0.6.2
scipy==1.16.1
sentry-sdk==2.35.0
setuptools==80.9.0
six==1.17.0
smmap==5.0.2
soupsieve==2.7
sympy==1.14.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tensorflow==2.20.0
termcolor==3.1.0
timm==1.0.19
tokenizers==0.21.4
torch==2.8.0
torch-fidelity==0.3.0
torchinfo==1.8.0
torchvision==0.23.0
tqdm==4.67.1
transformers==4.55.2
triton==3.4.0
typing-extensions==4.14.1
typing-inspection==0.4.1
urllib3==2.5.0
wandb==0.21.1
wcwidth==0.2.13
webdataset==1.0.2
werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.3
zipp==3.23.0

