@@ -369,13 +369,21 @@ def serialize_model(
369369 "neighborhood_radius" : getattr (model , "neighborhood_radius" , None ),
370370 "neighborhood_method" : getattr (model , "neighborhood_method" , "radius" ),
371371 "random_state" : getattr (model , "random_state" , None ),
372+ "is_fitted" : getattr (model , "_is_fitted" , False ),
372373 }
373374
374375 # Store fitted components if available
375- if hasattr (model , "components_" ) and model .components_ is not None :
376- result ["components" ] = model .components_ .tolist ()
377- if hasattr (model , "cell_types_" ) and model .cell_types_ is not None :
378- result ["cell_types" ] = list (model .cell_types_ )
376+ # Check both topic_cell_type_matrix_ (SpatialLDA) and components_ (sklearn)
377+ components = getattr (model , "topic_cell_type_matrix_" , None )
378+ if components is None and hasattr (model , "_lda_model" ) and model ._lda_model is not None :
379+ components = getattr (model ._lda_model , "components_" , None )
380+
381+ if components is not None :
382+ result ["components" ] = components .tolist ()
383+
384+ cell_types = getattr (model , "cell_types_" , None )
385+ if cell_types is not None :
386+ result ["cell_types" ] = list (cell_types )
379387
380388 return result
381389
@@ -437,6 +445,7 @@ def deserialize_model(
437445 if model_type == "spatial_lda" :
438446 try :
439447 from spatialtissuepy .lda import SpatialLDA
448+ from sklearn .decomposition import LatentDirichletAllocation
440449
441450 model = SpatialLDA (
442451 n_topics = data ["n_topics" ],
@@ -446,9 +455,22 @@ def deserialize_model(
446455 )
447456
448457 if data .get ("components" ):
449- model .components_ = np .array (data ["components" ])
458+ components = np .array (data ["components" ])
459+ model .topic_cell_type_matrix_ = components
460+
461+ # Reconstruct sklearn model for transformation
462+ lda_model = LatentDirichletAllocation (
463+ n_components = data ["n_topics" ],
464+ random_state = data .get ("random_state" ),
465+ )
466+ lda_model .components_ = components
467+ model ._lda_model = lda_model
468+
450469 if data .get ("cell_types" ):
451- model .cell_types_ = np .array (data ["cell_types" ])
470+ model .cell_types_ = list (data ["cell_types" ])
471+
472+ if data .get ("is_fitted" ):
473+ model ._is_fitted = True
452474
453475 return model
454476 except ImportError :
0 commit comments