diff --git a/segm/model/factory.py b/segm/model/factory.py index 199b834..75c71f8 100644 --- a/segm/model/factory.py +++ b/segm/model/factory.py @@ -59,11 +59,12 @@ def create_vit(model_cfg): drop_block_rate=None, ) - default_cfg["input_size"] = ( + default_cfg.update({"input_size": ( 3, model_cfg["image_size"][0], model_cfg["image_size"][1], - ) + )}) + model = VisionTransformer(**model_cfg) if backbone == "vit_base_patch8_384": path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")