Skip to content

Commit d4ad080

Browse files
committed
Fix SpatialLDA serialization to support model reconstruction and transformation
1 parent 3d2c1bf commit d4ad080

1 file changed

Lines changed: 28 additions & 6 deletions

File tree

spatialtissuepy/mcp/serialization.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,21 @@ def serialize_model(
369369
"neighborhood_radius": getattr(model, "neighborhood_radius", None),
370370
"neighborhood_method": getattr(model, "neighborhood_method", "radius"),
371371
"random_state": getattr(model, "random_state", None),
372+
"is_fitted": getattr(model, "_is_fitted", False),
372373
}
373374

374375
# Store fitted components if available
375-
if hasattr(model, "components_") and model.components_ is not None:
376-
result["components"] = model.components_.tolist()
377-
if hasattr(model, "cell_types_") and model.cell_types_ is not None:
378-
result["cell_types"] = list(model.cell_types_)
376+
# Check both topic_cell_type_matrix_ (SpatialLDA) and components_ (sklearn)
377+
components = getattr(model, "topic_cell_type_matrix_", None)
378+
if components is None and hasattr(model, "_lda_model") and model._lda_model is not None:
379+
components = getattr(model._lda_model, "components_", None)
380+
381+
if components is not None:
382+
result["components"] = components.tolist()
383+
384+
cell_types = getattr(model, "cell_types_", None)
385+
if cell_types is not None:
386+
result["cell_types"] = list(cell_types)
379387

380388
return result
381389

@@ -437,6 +445,7 @@ def deserialize_model(
437445
if model_type == "spatial_lda":
438446
try:
439447
from spatialtissuepy.lda import SpatialLDA
448+
from sklearn.decomposition import LatentDirichletAllocation
440449

441450
model = SpatialLDA(
442451
n_topics=data["n_topics"],
@@ -446,9 +455,22 @@ def deserialize_model(
446455
)
447456

448457
if data.get("components"):
449-
model.components_ = np.array(data["components"])
458+
components = np.array(data["components"])
459+
model.topic_cell_type_matrix_ = components
460+
461+
# Reconstruct sklearn model for transformation
462+
lda_model = LatentDirichletAllocation(
463+
n_components=data["n_topics"],
464+
random_state=data.get("random_state"),
465+
)
466+
lda_model.components_ = components
467+
model._lda_model = lda_model
468+
450469
if data.get("cell_types"):
451-
model.cell_types_ = np.array(data["cell_types"])
470+
model.cell_types_ = list(data["cell_types"])
471+
472+
if data.get("is_fitted"):
473+
model._is_fitted = True
452474

453475
return model
454476
except ImportError:

0 commit comments

Comments
 (0)