diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 0b3d499..3a580e6 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -459,7 +459,8 @@ def from_dataset(cls, if hasattr(dataset, 'intensities') and dataset.intensities is not None: intensities = dataset.intensities.to(dtype=Ws.dtype)[:,...] weights = t.sqrt(intensities) - Ws *= (weights / t.mean(weights)) + Ws *= (weights / t.mean(weights)).reshape( + (len(weights),) + (1,)*(Ws.ndim - 1)) if hasattr(dataset, 'mask') and dataset.mask is not None: mask = dataset.mask.to(t.bool) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index db05a75..c398be0 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -45,6 +45,28 @@ def test_center_probe(lab_ptycho_cxi): rtol=1e-3 ) +def test_lab_ptycho_data_loading(lab_ptycho_cxi): + + print('\nTesting a few unusual data loading scenarios.') + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(lab_ptycho_cxi) + + # Test that it will properly load an initialization for the weights + # from the intensities with OPRP on + dataset.intensities = t.rand(len(dataset)) + + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=4, + dm_rank=1, + ) + + # And test the case without OPRP + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=2, + ) + + @pytest.mark.slow def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): @@ -71,8 +93,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): print('Running reconstruction on provided reconstruction_device,', reconstruction_device) - model.to(device=reconstruction_device) - dataset.get_as(device=reconstruction_device) + #model.to(device=reconstruction_device) + #dataset.get_as(device=reconstruction_device) for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): print(model.report())