From 44c77fe6cb505827bea87b8ca2aa8f175b4f748e Mon Sep 17 00:00:00 2001 From: studyingeugene Date: Fri, 3 Apr 2026 13:44:25 +0900 Subject: [PATCH] refactor: add backward compatibility with legacy hyperprior state dict --- compressai/models/base.py | 15 +++++++++++++-- compressai/models/utils.py | 21 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/compressai/models/base.py b/compressai/models/base.py index 2fd63654..472979fd 100644 --- a/compressai/models/base.py +++ b/compressai/models/base.py @@ -38,8 +38,16 @@ from torch import Tensor from compressai.entropy_models import EntropyBottleneck, GaussianConditional -from compressai.latent_codecs import LatentCodec -from compressai.models.utils import remap_old_keys, update_registered_buffers +from compressai.latent_codecs import ( + GainHyperpriorLatentCodec, + HyperpriorLatentCodec, + LatentCodec, +) +from compressai.models.utils import ( + remap_legacy_hyperprior_keys, + remap_old_keys, + update_registered_buffers, +) __all__ = [ "CompressionModel", @@ -96,6 +104,9 @@ def load_state_dict(self, state_dict, strict=True): if not any(x.startswith(name) for x in state_dict.keys()): continue + if isinstance(module, (HyperpriorLatentCodec, GainHyperpriorLatentCodec)): + state_dict = remap_legacy_hyperprior_keys(name, state_dict) + if isinstance(module, EntropyBottleneck): update_registered_buffers( module, diff --git a/compressai/models/utils.py b/compressai/models/utils.py index 9e5f5120..1acac16e 100644 --- a/compressai/models/utils.py +++ b/compressai/models/utils.py @@ -147,6 +147,27 @@ def remap_subkey(s: str) -> str: return new_state_dict +def remap_legacy_hyperprior_keys(module_name, state_dict): + old_prefix = f"{module_name}.hyper." + new_state_dict = OrderedDict() + + for k, v in state_dict.items(): + if k.startswith(f"{old_prefix}h_a."): + k = k.replace(f"{old_prefix}h_a.", f"{module_name}.h_a.", 1) + elif k.startswith(f"{old_prefix}h_s."): + k = k.replace(f"{old_prefix}h_s.", f"{module_name}.h_s.", 1) + elif k.startswith(f"{old_prefix}entropy_bottleneck."): + k = k.replace( + f"{old_prefix}entropy_bottleneck.", + f"{module_name}.z.entropy_bottleneck.", + 1, + ) + + new_state_dict[k] = v + + return new_state_dict + + def conv(in_channels, out_channels, kernel_size=5, stride=2): return nn.Conv2d( in_channels,