Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ hest = { git = "https://github.com/mahmoodlab/HEST.git", rev = "2c777630b1e8a74d

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel.force-include]
"conf" = "seal/conf"
"cache" = "seal/cache"
19 changes: 17 additions & 2 deletions seal/models/encoder_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,22 @@
import shutil
import os

from pathlib import Path
import argparse

def find_config_yaml():
# 1. Check for local override first (preserves original behavior)
local_config = Path("conf/config.yaml")
if local_config.exists():
return str(local_config)

# 2. Fall back to the bundled package config
package_config = Path(__file__).resolve().parent.parent / "conf" / "config.yaml"
if package_config.exists():
return str(package_config)

# 3. Last resort fallback to avoid a silent failure
raise FileNotFoundError("Could not find config.yaml locally or in the installed package.")

def get_constants(norm='imagenet'):
IMAGENET_MEAN = [0.485, 0.456, 0.406]
Expand Down Expand Up @@ -781,7 +796,7 @@ def load_img_model_from_checkpoint(
import argparse
from seal.models.load_model import ModelMixin

default_config_path = argparse.Namespace(config='conf/config.yaml')
default_config_path = argparse.Namespace(config=find_config_yaml())

if checkpoint_path is None:
print(f"Warning: No image checkpoint provided for checkpoint_dir={checkpoint_dir}")
Expand Down Expand Up @@ -842,7 +857,7 @@ def load_gene_model_from_checkpoint(
"""
from seal.models.load_model import ModelMixin

default_config_path = argparse.Namespace(config='conf/config.yaml')
default_config_path = argparse.Namespace(config=find_config_yaml())
model_config = update_config(default_config_path)
default_config = update_config(default_config_path)

Expand Down
15 changes: 12 additions & 3 deletions seal/models/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,21 @@
import torch.nn as nn
from seal.models.da_model import AdversarialDiscriminator
from seal.models.encoder_factory import encoder_factory

from pathlib import Path

HF_API_KEY = os.getenv("HF_API_KEY")

def find_organ_ids():
rel_path = "cache/organ_ids.json"
id_file = Path(rel_path)

if not id_file.exists():
id_file = Path(__file__).resolve().parent.parent / rel_path
if not id_file.exists():
raise FileNotFoundError(f'Could not locate {rel_path} locally or in package.')

with open(id_file, 'r') as f:
return json.load(f) # You can return this directly without assigning it to a variable

class ModelMixin():

Expand Down Expand Up @@ -533,8 +543,7 @@ def __init__(self, encoder, out_dim, rec_dim, eval_transforms, precision,
if self.use_adapter:
self.adapter = self.encoder.adapter

with open("cache/organ_ids.json", 'r') as f:
organ_ids = json.load(f)
organ_ids = find_organ_ids()
num_organs = len(organ_ids)

if self.organ_token:
Expand Down