Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions compressai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,16 @@
from torch import Tensor

from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.latent_codecs import LatentCodec
from compressai.models.utils import remap_old_keys, update_registered_buffers
from compressai.latent_codecs import (
GainHyperpriorLatentCodec,
HyperpriorLatentCodec,
LatentCodec,
)
from compressai.models.utils import (
remap_legacy_hyperprior_keys,
remap_old_keys,
update_registered_buffers,
)

__all__ = [
"CompressionModel",
Expand Down Expand Up @@ -96,6 +104,9 @@ def load_state_dict(self, state_dict, strict=True):
if not any(x.startswith(name) for x in state_dict.keys()):
continue

if isinstance(module, (HyperpriorLatentCodec, GainHyperpriorLatentCodec)):
state_dict = remap_legacy_hyperprior_keys(name, state_dict)

if isinstance(module, EntropyBottleneck):
update_registered_buffers(
module,
Expand Down
21 changes: 21 additions & 0 deletions compressai/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,27 @@ def remap_subkey(s: str) -> str:
return new_state_dict


def remap_legacy_hyperprior_keys(module_name, state_dict):
    """Rename legacy ``{module_name}.hyper.*`` keys to the current layout.

    Older checkpoints nested the hyperprior pieces under a ``hyper``
    submodule. Current models expect ``h_a``/``h_s`` directly under the
    codec and the entropy bottleneck under ``{module_name}.z``. Keys that
    match no legacy prefix are passed through unchanged; iteration order
    of the input state dict is preserved.
    """
    legacy_root = f"{module_name}.hyper."
    # Anchored prefix rewrites, checked in order (prefixes are disjoint).
    prefix_map = {
        f"{legacy_root}h_a.": f"{module_name}.h_a.",
        f"{legacy_root}h_s.": f"{module_name}.h_s.",
        f"{legacy_root}entropy_bottleneck.": f"{module_name}.z.entropy_bottleneck.",
    }

    remapped = OrderedDict()
    for key, value in state_dict.items():
        for old, new in prefix_map.items():
            if key.startswith(old):
                # Equivalent to str.replace(old, new, 1): the prefix is
                # guaranteed to sit at index 0 by the startswith check.
                key = new + key[len(old):]
                break
        remapped[key] = value

    return remapped


def conv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.Conv2d(
in_channels,
Expand Down
Loading