diff --git a/experiments/fashion/fashionclasses_macow.yaml b/experiments/fashion/fashionclasses_macow.yaml new file mode 100644 index 0000000..44529bb --- /dev/null +++ b/experiments/fashion/fashionclasses_macow.yaml @@ -0,0 +1,139 @@ +--- +__object__: src.usflows.explib.base.ExperimentCollection +name: fashion_ablation_macow +experiments: + - &exp_rad_logN + __object__: src.usflows.explib.hyperopt.HyperoptExperiment + name: cfair_full_normal + skip: false + device: cpu + scheduler: &scheduler + __object__: ray.tune.schedulers.ASHAScheduler + max_t: 1000000 + grace_period: 1000000 + reduction_factor: 2 + num_hyperopt_samples: &num_hyperopt_samples 1 + gpus_per_trial: &gpus_per_trial 0 + cpus_per_trial: &cpus_per_trial 1 + tuner_params: &tuner_params + metric: val_loss + mode: min + trial_config: + logging: + images: false + "image_shape": [28, 28] + dataset: &dataset + class: + __class__: src.usflow.explib.datasets.FashionMnistSplit + params: + space_to_depth_factor: 4 + dataloc: /home/faried/Projects/USFlows/data/fashion + label: 0 + epochs: &epochs 200000 + patience: &patience 1 + batch_size: &batch_size + __eval__: tune.choice([32]) + optim_cfg: &optim + optimizer: + __class__: src.usflows.sophia.SophiaG + params: + lr: + __eval__: 1e-3 + weight_decay: 0.0 + + model_cfg: + type: + __class__: src.usflows.flows.USFlow + params: + soft_training: + __eval__: tune.choice([False]) + training_noise_prior: + __object__: pyro.distributions.Uniform + low: + __eval__: 1e-20 + high: 0.01 + prior_scale: 1.0 + coupling_blocks: + __eval__: tune.choice([10]) + lu_transform: 1 + householder: 0 + conditioner_cls: + __class__: src.usflows.networks.ConvNet2D + conditioner_args: + c_in: 16 + c_hidden: + __eval__: tune.choice([32]) + num_layers: + __eval__: tune.choice([3]) + padding: same + kernel_size: 3 + rescale_hidden: 1 + normalize_layers: + __eval__: tune.choice([True]) + gating: + __eval__: tune.choice([True]) + in_dims: [16, 7, 7] + affine_conjugation: true + nonlinearity: + 
__eval__: tune.choice([torch.nn.ReLU()]) + base_distribution: + __object__: pyro.distributions.Normal + loc: + __eval__: torch.zeros([16, 7, 7]).to("cpu") + scale: + __eval__: torch.ones([1]).to("cpu") + - __overwrites__: *exp_rad_logN + name: fashion1_radial_logN + trial_config: + dataset: + params: + label: 1 + - __overwrites__: *exp_rad_logN + name: fashion2_radial_logN + trial_config: + dataset: + params: + label: 2 + - __overwrites__: *exp_rad_logN + name: fashion3_radial_logN + trial_config: + dataset: + params: + label: 3 + - __overwrites__: *exp_rad_logN + name: fashion4_radial_logN + trial_config: + dataset: + params: + label: 4 + - __overwrites__: *exp_rad_logN + name: fashion5_radial_logN + trial_config: + dataset: + params: + label: 5 + - __overwrites__: *exp_rad_logN + name: fashion6_radial_logN + trial_config: + dataset: + params: + label: 6 + - __overwrites__: *exp_rad_logN + name: fashion7_radial_logN + trial_config: + dataset: + params: + label: 7 + - __overwrites__: *exp_rad_logN + name: fashion8_radial_logN + trial_config: + dataset: + params: + label: 8 + - __overwrites__: *exp_rad_logN + name: fashion9_radial_logN + trial_config: + dataset: + params: + label: 9 + \ No newline at end of file diff --git a/experiments/fashion/fashionclasses_veriflow.yaml b/experiments/fashion/fashionclasses_veriflow.yaml new file mode 100644 index 0000000..0931266 --- /dev/null +++ b/experiments/fashion/fashionclasses_veriflow.yaml @@ -0,0 +1,149 @@ +--- +__object__: src.usflows.explib.base.ExperimentCollection +name: fashion_ablation_veriflow +experiments: + - &exp_rad_logN + __object__: src.usflows.explib.hyperopt.HyperoptExperiment + name: cfair_full_radial_logN + skip: false + device: cpu + scheduler: &scheduler + __object__: ray.tune.schedulers.ASHAScheduler + max_t: 1000000 + grace_period: 1000000 + reduction_factor: 2 + num_hyperopt_samples: &num_hyperopt_samples 1 + gpus_per_trial: &gpus_per_trial 0 + cpus_per_trial: &cpus_per_trial 1 + tuner_params: 
&tuner_params + metric: val_loss + mode: min + trial_config: + logging: + images: false + "image_shape": [28, 28] + dataset: &dataset + class: + __class__: src.usflows.explib.datasets.FashionMnistSplit + params: + space_to_depth_factor: 4 + dataloc: /home/faried/Projects/USFlows/data/fashion + label: 0 + epochs: &epochs 200000 + patience: &patience 1 + batch_size: &batch_size + __eval__: tune.choice([32]) + optim_cfg: &optim + optimizer: + __class__: src.usflows.sophia.SophiaG + params: + lr: + __eval__: 1e-3 + weight_decay: 0.0 + + model_cfg: + type: + __class__: src.usflows.flows.USFlow + params: + soft_training: + __eval__: tune.choice([False]) + training_noise_prior: + __object__: pyro.distributions.Uniform + low: + __eval__: 1e-20 + high: 0.01 + prior_scale: 1.0 + coupling_blocks: + __eval__: tune.choice([10]) + lu_transform: 1 + householder: 0 + conditioner_cls: + __class__: src.usflows.networks.ConvNet2D + conditioner_args: + c_in: 16 + c_hidden: + __eval__: tune.choice([32]) + num_layers: + __eval__: tune.choice([3]) + padding: same + kernel_size: 3 + rescale_hidden: 1 + normalize_layers: + __eval__: tune.choice([True]) + gating: + __eval__: tune.choice([True]) + in_dims: [16, 7, 7] + affine_conjugation: true + nonlinearity: + __eval__: tune.choice([torch.nn.ReLU()]) + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + device: cpu + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([16, 7, 7]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * 75 + rate: + __eval__: torch.rand([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp_rad_logN + name: fashion1_radial_logN + trial_config: + dataset: + params: + label: 1 + - __overwrites__: *exp_rad_logN + name: fashion2_radial_logN + trial_config: + dataset: + params: + label: 2 + - __overwrites__: *exp_rad_logN + name: 
fashion3_radial_logN + trial_config: + dataset: + params: + label: 3 + - __overwrites__: *exp_rad_logN + name: fashion4_radial_logN + trial_config: + dataset: + params: + label: 4 + - __overwrites__: *exp_rad_logN + name: fashion5_radial_logN + trial_config: + dataset: + params: + label: 5 + - __overwrites__: *exp_rad_logN + name: fashion6_radial_logN + trial_config: + dataset: + params: + label: 6 + - __overwrites__: *exp_rad_logN + name: fashion7_radial_logN + trial_config: + dataset: + params: + label: 7 + - __overwrites__: *exp_rad_logN + name: fashion8_radial_logN + trial_config: + dataset: + params: + label: 8 + - __overwrites__: *exp_rad_logN + name: fashion9_radial_logN + trial_config: + dataset: + params: + label: 9 + \ No newline at end of file diff --git a/experiments/mnist/mnist.yaml b/experiments/mnist/mnist.yaml index a2d2251..9da921d 100644 --- a/experiments/mnist/mnist.yaml +++ b/experiments/mnist/mnist.yaml @@ -6,7 +6,7 @@ experiments: __object__: src.explib.hyperopt.HyperoptExperiment name: mnist_full_radial_logN device: cpu - skip: true + skip: False scheduler: &scheduler __object__: ray.tune.schedulers.ASHAScheduler max_t: 1000000 @@ -14,7 +14,7 @@ experiments: reduction_factor: 2 num_hyperopt_samples: &num_hyperopt_samples 1 gpus_per_trial: &gpus_per_trial 0 - cpus_per_trial: &cpus_per_trial 1 + cpus_per_trial: &cpus_per_trial 16 tuner_params: &tuner_params metric: val_loss mode: min @@ -38,7 +38,7 @@ experiments: __class__: src.usflows.sophia.SophiaG params: lr: - __eval__: 1e-4 + __eval__: 1e-3 weight_decay: 0.0 model_cfg: @@ -54,7 +54,7 @@ experiments: high: 0.01 prior_scale: 1.0 coupling_blocks: - __eval__: tune.choice([10]) + __eval__: tune.choice([15]) lu_transform: 1 householder: 0 conditioner_cls: diff --git a/experiments/mnist/mnist_digits.yaml b/experiments/mnist/mnist_digits.yaml index 08a7fe7..ea6e996 100644 --- a/experiments/mnist/mnist_digits.yaml +++ b/experiments/mnist/mnist_digits.yaml @@ -1,155 +1,156 @@ --- __object__: 
src.explib.base.ExperimentCollection -name: mnist_digit_basedist_comparison +name: mnist_gigits_logN experiments: - - &digit_0 - __object__: src.explib.base.ExperimentCollection - name: mnist_basedist_comparison - experiments: - - &exp_nice_lu_laplace - __object__: src.explib.hyperopt.HyperoptExperiment - name: mnist_nice_lu_laplace - scheduler: &scheduler - __object__: ray.tune.schedulers.ASHAScheduler - max_t: 1000000 - grace_period: 1000000 - reduction_factor: 2 - num_hyperopt_samples: &num_hyperopt_samples 20 - gpus_per_trial: &gpus_per_trial 0 - cpus_per_trial: &cpus_per_trial 1 - tuner_params: &tuner_params - metric: val_loss - mode: min - trial_config: - dataset: &dataset - __object__: src.explib.datasets.MnistSplit - digit: 0 - scale: true - epochs: &epochs 200000 - patience: &patience 50 - batch_size: &batch_size + - &exp_rad_logN + __object__: src.explib.hyperopt.HyperoptExperiment + name: mnist0 + device: cpu + skip: false + scheduler: &scheduler + __object__: ray.tune.schedulers.ASHAScheduler + max_t: 1000000 + grace_period: 1000000 + reduction_factor: 2 + num_hyperopt_samples: &num_hyperopt_samples 1 + gpus_per_trial: &gpus_per_trial 0 + cpus_per_trial: &cpus_per_trial 16 + tuner_params: &tuner_params + metric: val_loss + mode: min + trial_config: + logging: + images: false + "image_shape": [28, 28] + dataset: &dataset + class: + __class__: src.explib.datasets.MnistSplit + params: + dataloc: /home/faried/Projects/USFlows/data/mnist + space_to_depth_factor: 4 + device: cpu + digit: 0 + epochs: 200000 + patience: 5 + batch_size: + __eval__: tune.choice([32]) + optim_cfg: + optimizer: + __class__: src.usflows.sophia.SophiaG + params: + lr: + __eval__: 1e-4 + weight_decay: 0.0 + + model_cfg: + type: + __class__: src.usflows.flows.USFlow + params: + soft_training: + __eval__: tune.choice([False]) + training_noise_prior: + __object__: pyro.distributions.Uniform + low: + __eval__: 1e-20 + high: 0.01 + prior_scale: 1.0 + coupling_blocks: + __eval__: 
tune.choice([5]) + lu_transform: 1 + householder: 0 + conditioner_cls: + __class__: src.usflows.networks.ConvNet2D + conditioner_args: + c_in: 16 + c_hidden: __eval__: tune.choice([32]) - optim_cfg: &optim - optimizer: - __class__: torch.optim.Adam - params: - lr: - __eval__: tune.loguniform(1e-4, 1e-2) - weight_decay: 0.0 - - model_cfg: - type: - __class__: &model src.veriflow.flows.NiceFlow - params: - soft_training: true - training_noise_prior: - __object__: pyro.distributions.Uniform - low: 0.0 - high: 0.001 - prior_scale: 1.0 - use_lu: true - coupling_layers: &coupling_layers - __eval__: tune.choice([2, 3, 4, 5]) - coupling_nn_layers: &coupling_nn_layers - __eval__: tune.choice([[w]*l for l in [1, 2] for w in [10, 20, 50, 100, 200]]) - nonlinearity: &nonlinearity - __eval__: tune.choice([torch.nn.ReLU()]) - split_dim: - __eval__: tune.choice([i for i in range(1, 51)]) - base_distribution: - __object__: pyro.distributions.Laplace - loc: - __eval__: torch.zeros(100) - scale: - __eval__: torch.ones(100) - - &exp_nice_lu_normal - __overwrites__: *exp_nice_lu_laplace - name: mnist_nice_lu_normal - trial_config: - model_cfg: - params: - base_distribution: - __exact__: - __object__: pyro.distributions.Normal - loc: - __eval__: torch.zeros(100) - scale: - __eval__: torch.ones(100) - - &exp_nice_rand_laplace - __overwrites__: *exp_nice_lu_laplace - name: mnist_nice_rand_laplace - trial_config: - model_cfg: - params: - use_lu: false - masktype: random - - &exp_nice_rand_normal - __overwrites__: *exp_nice_lu_laplace - name: mnist_nice_rand_normal - trial_config: - model_cfg: - params: - use_lu: false - masktype: random - base_distribution: - __exact__: - __object__: pyro.distributions.Normal - loc: - __eval__: torch.zeros(100) - scale: - __eval__: torch.ones(100) - - &digit_1 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 1 - - &digit_2 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 2 - - &digit_3 - 
__overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 3 - - &digit_4 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 4 - - &digit_5 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 5 - - &digit_6 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 6 - - &digit_7 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 7 - - &digit_8 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 8 - - &digit_9 - __overwrites__: *digit_0 - experiments: - trial_config: - dataset: - digit: 9 - \ No newline at end of file + num_layers: + __eval__: tune.choice([3]) + padding: same + kernel_size: 3 + rescale_hidden: 1 + normalize_layers: + __eval__: tune.choice([True]) + gating: + __eval__: tune.choice([True]) + in_dims: [16, 7, 7] + affine_conjugation: true + nonlinearity: + __eval__: tune.choice([torch.nn.ReLU()]) + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + device: cpu + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([16, 7, 7]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.LogNormal + loc: + __eval__: torch.ones([1]).to("cpu") * 6 + scale: + __eval__: torch.ones([1]).to("cpu") * .35 + device: cpu + - + __overwrites__: *exp_rad_logN + name: mnist1 + trial_config: + dataset: + params: + digit: 1 + - + __overwrites__: *exp_rad_logN + name: mnist2 + trial_config: + dataset: + params: + digit: 2 + - + __overwrites__: *exp_rad_logN + name: mnist3 + trial_config: + dataset: + params: + digit: 3 + - + __overwrites__: *exp_rad_logN + name: mnist4 + trial_config: + dataset: + params: + digit: 4 + - + __overwrites__: *exp_rad_logN + name: mnist5 + trial_config: + dataset: + params: + digit: 5 + - + __overwrites__: *exp_rad_logN + name: mnist6 + trial_config: + dataset: + params: + digit: 6 + - + __overwrites__: *exp_rad_logN + name: mnist7 + trial_config: + dataset: + 
params: + digit: 7 + - + __overwrites__: *exp_rad_logN + name: mnist8 + trial_config: + dataset: + params: + digit: 8 + - + __overwrites__: *exp_rad_logN + name: mnist9 + trial_config: + dataset: + params: + digit: 9 diff --git a/experiments/synthetic/gaussian_mixture.yaml b/experiments/synthetic/gaussian_mixture.yaml new file mode 100644 index 0000000..b5e8289 --- /dev/null +++ b/experiments/synthetic/gaussian_mixture.yaml @@ -0,0 +1,406 @@ +--- +__object__: src.usflows.explib.base.ExperimentCollection +name: gaussian_mixture_experiments +experiments: + - &exp2d + __object__: src.usflows.explib.hyperopt.HyperoptExperiment + name: gaussian_mixture_2D + device: cpu + scheduler: &scheduler + __object__: ray.tune.schedulers.ASHAScheduler + max_t: 1000000 + grace_period: 1000000 + reduction_factor: 2 + num_hyperopt_samples: &num_hyperopt_samples 1 + gpus_per_trial: &gpus_per_trial 0 + cpus_per_trial: &cpus_per_trial 16 + tuner_params: &tuner_params + metric: val_loss + mode: min + trial_config: + logging: + images: false + "image_shape": [28, 28] + dataset: &dataset + class: + __class__: src.usflows.explib.datasets.DistributionSplit + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(2)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + num_train: 10000 + num_val: 2000 + num_test: 2000 + epochs: &epochs 200000 + patience: &patience 5 + batch_size: &batch_size + __eval__: tune.choice([32]) + optim_cfg: &optim + optimizer: + __class__: src.usflows.sophia.SophiaG + params: + lr: + __eval__: 1e-3 + weight_decay: 0.0 + model_cfg: + type: + __class__: src.usflows.flows.USFlow + params: + soft_training: + __eval__: tune.choice([False]) + training_noise_prior: + __object__: pyro.distributions.Uniform + low: + __eval__: 1e-20 + high: 0.01 + prior_scale: 1.0 + coupling_blocks: + __eval__: tune.choice([10]) + lu_transform: 1 + 
householder: 0 + conditioner_cls: + __class__: pyro.nn.DenseNN + conditioner_args: + input_dim: 2 + hidden_dims: [32, 32] + param_dims: [2] + in_dims: [2] + affine_conjugation: true + nonlinearity: + __eval__: tune.choice([torch.nn.ReLU()]) + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + device: cpu + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([2]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([2.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_3D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(3)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [3] + conditioner_args: + input_dim: 3 + hidden_dims: [32, 32] + param_dims: [3] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([3]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([3.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_4D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(4)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [4] + 
conditioner_args: + input_dim: 4 + hidden_dims: [32, 32] + param_dims: [4] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([4]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([4.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_5D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(5)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [5] + conditioner_args: + input_dim: 5 + hidden_dims: [32, 32] + param_dims: [5] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([5]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([5.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_6D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(6)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [6] + conditioner_args: + input_dim: 6 + hidden_dims: [32, 32] + param_dims: [6] + base_distribution: + __object__: 
src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([6]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([6.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_7D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(7)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [7] + conditioner_args: + input_dim: 7 + hidden_dims: [32, 32] + param_dims: [7] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([7]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([7.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_8D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(8)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [8] + conditioner_args: + input_dim: 8 + hidden_dims: [32, 32] + param_dims: [8] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: 
torch.zeros([8]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([8.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_9D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(9)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [9] + conditioner_args: + input_dim: 9 + hidden_dims: [32, 32] + param_dims: [9] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([9]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([9.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_10D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(10)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [10] + conditioner_args: + input_dim: 10 + hidden_dims: [32, 32] + param_dims: [10] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([10]).to("cpu") + norm_distribution: + 
__object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([10.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_100D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0] * 100, [1.0] * 100]) + covariance_matrix: + __eval__: torch.stack([torch.eye(100)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [100] + conditioner_args: + input_dim: 100 + hidden_dims: [32, 32] + param_dims: [100] + base_distribution: + __object__: src.usflows.distributions.RadialDistribution + p: + __eval__: float("1") + loc: + __eval__: torch.zeros([100]).to("cpu") + norm_distribution: + __object__: src.usflows.distributions.GammaMM + concentration: + __eval__: torch.rand([20]).to("cpu") * torch.sqrt(torch.tensor([100.0])) + rate: + __eval__: torch.ones([20]).to("cpu") + mixture_weights: + __eval__: torch.ones([20]).to("cpu") / 20 + device: cpu \ No newline at end of file diff --git a/experiments/synthetic/gaussian_mixture_standart_base.yaml b/experiments/synthetic/gaussian_mixture_standart_base.yaml new file mode 100644 index 0000000..7082d1e --- /dev/null +++ b/experiments/synthetic/gaussian_mixture_standart_base.yaml @@ -0,0 +1,290 @@ +--- +__object__: src.usflows.explib.base.ExperimentCollection +name: gaussian_mixture_experiments +experiments: + - &exp2d + __object__: src.usflows.explib.hyperopt.HyperoptExperiment + name: gaussian_mixture_2D + device: cpu + scheduler: &scheduler + __object__: ray.tune.schedulers.ASHAScheduler + max_t: 1000000 + grace_period: 1000000 + reduction_factor: 2 + num_hyperopt_samples: &num_hyperopt_samples 1 + gpus_per_trial: &gpus_per_trial 0 + cpus_per_trial: &cpus_per_trial 16 + tuner_params: &tuner_params + metric: val_loss + mode: 
min + trial_config: + logging: + images: false + "image_shape": [28, 28] + dataset: &dataset + class: + __class__: src.usflows.explib.datasets.DistributionSplit + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(2)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + num_train: 10000 + num_val: 2000 + num_test: 2000 + epochs: &epochs 200000 + patience: &patience 5 + batch_size: &batch_size + __eval__: tune.choice([32]) + optim_cfg: &optim + optimizer: + __class__: src.usflows.sophia.SophiaG + params: + lr: + __eval__: 1e-3 + weight_decay: 0.0 + model_cfg: + type: + __class__: src.usflows.flows.USFlow + params: + soft_training: + __eval__: tune.choice([False]) + training_noise_prior: + __object__: pyro.distributions.Uniform + low: + __eval__: 1e-20 + high: 0.01 + prior_scale: 1.0 + coupling_blocks: + __eval__: tune.choice([10]) + lu_transform: 1 + householder: 0 + conditioner_cls: + __class__: pyro.nn.DenseNN + conditioner_args: + input_dim: 2 + hidden_dims: [32, 32] + param_dims: [2] + in_dims: [2] + affine_conjugation: true + nonlinearity: + __eval__: tune.choice([torch.nn.ReLU()]) + base_distribution: + __object__: src.usflows.distributions.Normal + loc: + __eval__: torch.zeros([2]).to("cpu") + scale: + __eval__: torch.tensor(1.0).to("cpu") + device: cpu + - __overwrites__: *exp2d + name: gaussian_mixture_3D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(3)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [3] + conditioner_args: + input_dim: 3 + hidden_dims: [32, 32] + param_dims: [3] + base_distribution: + __object__: src.usflows.distributions.Normal + loc: + __eval__: torch.zeros([3]).to("cpu") + - 
__overwrites__: *exp2d + name: gaussian_mixture_4D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(4)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [4] + conditioner_args: + input_dim: 4 + hidden_dims: [32, 32] + param_dims: [4] + base_distribution: + loc: + __eval__: torch.zeros([4]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_5D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(5)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [5] + conditioner_args: + input_dim: 5 + hidden_dims: [32, 32] + param_dims: [5] + base_distribution: + loc: + __eval__: torch.zeros([5]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_6D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(6)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [6] + conditioner_args: + input_dim: 6 + hidden_dims: [32, 32] + param_dims: [6] + base_distribution: + loc: + __eval__: torch.zeros([6]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_7D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(7)] * 2) + mixture_weights: + 
__eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [7] + conditioner_args: + input_dim: 7 + hidden_dims: [32, 32] + param_dims: [7] + base_distribution: + loc: + __eval__: torch.zeros([7]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_8D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(8)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [8] + conditioner_args: + input_dim: 8 + hidden_dims: [32, 32] + param_dims: [8] + base_distribution: + loc: + __eval__: torch.zeros([8]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_9D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(9)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [9] + conditioner_args: + input_dim: 9 + hidden_dims: [32, 32] + param_dims: [9] + base_distribution: + loc: + __eval__: torch.zeros([9]).to("cpu") + - __overwrites__: *exp2d + name: gaussian_mixture_10D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]) + covariance_matrix: + __eval__: torch.stack([torch.eye(10)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [10] + conditioner_args: + input_dim: 10 + hidden_dims: [32, 32] + param_dims: [10] + base_distribution: + loc: + __eval__: torch.zeros([10]).to("cpu") 
+ - __overwrites__: *exp2d + name: gaussian_mixture_100D + trial_config: + dataset: + params: + distribution: + __object__: src.usflows.distributions.GMM + loc: + __eval__: torch.tensor([[-1.0] * 100, [1.0] * 100]) + covariance_matrix: + __eval__: torch.stack([torch.eye(100)] * 2) + mixture_weights: + __eval__: torch.tensor([0.5, 0.5]) + model_cfg: + params: + in_dims: [100] + conditioner_args: + input_dim: 100 + hidden_dims: [32, 32] + param_dims: [100] + base_distribution: + loc: + __eval__: torch.zeros([100]).to("cpu") \ No newline at end of file diff --git a/scripts/eval.py b/scripts/eval.py new file mode 100644 index 0000000..7382121 --- /dev/null +++ b/scripts/eval.py @@ -0,0 +1,590 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +from scipy.stats import wasserstein_distance, ks_2samp, norm +from sklearn.neighbors import KernelDensity +from sklearn.preprocessing import StandardScaler +from scipy.stats import chi2, binned_statistic +from sklearn.metrics import mutual_info_score +import scipy.stats as stats +from scipy.stats import binomtest, wilcoxon +from sklearn.neighbors import KernelDensity +import pandas as pd + +from src.usflows.explib.config_parser import from_checkpoint +from src.usflows.distributions import Independent +import os +import torch +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +class RadialFlowEvaluator: + def __init__(self, flow, data, device='cpu'): + """ + Evaluator for USFlow models with RadialDistribution base distribution. 
+ + Args: + flow: Trained USFlow model with RadialDistribution base + data: Dataset tensor for evaluation + device: Device for computation + """ + self.flow = flow.to(device) + self.data = data.to(device) + self.device = device + + self.dim = torch.prod(torch.tensor(self.data.shape[1:])).item() + + # Precompute latent representations + with torch.no_grad(): + if isinstance(self.flow.base_distribution, Independent): + loc = self.flow.base_distribution._base_distribution.loc.to(device) + else: + loc = self.flow.base_distribution.loc.to(device) + self.latents = self.flow.backward(self.data) - loc + self.latents = self.latents.view(self.latents.shape[0], -1) + # Get p-norm from base distribution + self.p = flow.base_distribution.p + + def wasserstein_norm_distance(self, n_samples=10000): + """ + Compute Wasserstein distance between: + 1. Norm distribution of base distribution + 2. Empirical p-norms of latent representations + + Returns: + wasserstein_dist: Wasserstein distance + """ + # Get empirical latent norms + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + + # Sample from base norm distribution + base_norm_dist = self.flow.base_distribution.norm_distribution + sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() + + # Compute Wasserstein distance + return wasserstein_distance(latent_norms, sample_norms) + + def ks_norm_statistic(self, n_samples=10000): + """ + Compute Kolmogorov-Smirnov statistic for norm distributions. + + Returns: + ks_stat: KS statistic + p_value: Associated p-value + """ + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + base_norm_dist = self.flow.base_distribution.norm_distribution + sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() + + return ks_2samp(latent_norms, sample_norms) + + def qq_plot_norms(self, ax=None, n_samples=10000): + """ + Generate QQ-plot comparing: + 1. Quantiles of empirical latent norms + 2. 
Quantiles of base norm distribution samples + """ + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + base_norm_dist = self.flow.base_distribution.norm_distribution + sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() + + latent_quantiles = np.quantile(latent_norms, np.linspace(0, 1, 100)) + sample_quantiles = np.quantile(sample_norms, np.linspace(0, 1, 100)) + + if ax is None: + fig, ax = plt.subplots(figsize=(8, 6)) + + ax.scatter(sample_quantiles, latent_quantiles, alpha=0.7) + min_val = min(sample_quantiles.min(), latent_quantiles.min()) + max_val = max(sample_quantiles.max(), latent_quantiles.max()) + ax.plot([min_val, max_val], [min_val, max_val], 'r--') + ax.set_title('QQ-plot of Latent Norms') + ax.set_xlabel('Base Distribution Quantiles') + ax.set_ylabel('Data Latent Quantiles') + return ax + + def kde_plot_norms(self, ax=None, n_samples=10000): + """ + Generate KDE plots comparing: + 1. Empirical latent norms distribution + 2. Base norm distribution + """ + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + base_norm_dist = self.flow.base_distribution.norm_distribution + sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() + + # Create KDE models + kde_latent = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(latent_norms.reshape(-1, 1)) + kde_base = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(sample_norms.reshape(-1, 1)) + + # Evaluate on grid + x_grid = np.linspace( + min(latent_norms.min(), sample_norms.min()), + max(latent_norms.max(), sample_norms.max()), + 1000 + )[:, np.newaxis] + + log_dens_latent = kde_latent.score_samples(x_grid) + log_dens_base = kde_base.score_samples(x_grid) + + if ax is None: + fig, ax = plt.subplots(figsize=(10, 6)) + + ax.plot(x_grid, np.exp(log_dens_latent), label='Data Latents') + ax.plot(x_grid, np.exp(log_dens_base), label='Base Distribution') + ax.set_title('KDE of Norm Distributions') + ax.set_xlabel('Norm Value') + 
ax.set_ylabel('Density') + ax.legend() + return ax + + def pp_plot_norms(self, ax=None, n_samples=10000): + """ + Generate PP-plot comparing: + 1. Empirical CDF of latent norms + 2. Theoretical CDF of base norm distribution + + Args: + ax: Matplotlib axis (optional) + n_samples: Number of samples for theoretical distribution + + Returns: + ax: Matplotlib axis + """ + # Get empirical latent norms + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + + # Compute empirical CDF + n = len(latent_norms) + empirical_cdf = np.arange(1, n+1) / n + sorted_norms = np.sort(latent_norms) + + # Get theoretical CDF (if available) + base_norm_dist = self.flow.base_distribution.norm_distribution + if hasattr(base_norm_dist, 'cdf'): + # Use analytical CDF if available + theoretical_cdf = base_norm_dist.cdf( + torch.tensor(sorted_norms).to(self.device) + ).detach().cpu().numpy() + else: + # Approximate via sampling + sample_norms = base_norm_dist.sample((n_samples,)).detach().cpu().numpy() + sample_sorted = np.sort(sample_norms) + theoretical_cdf = np.searchsorted(sample_sorted, sorted_norms) / n_samples + + # Create plot + if ax is None: + fig, ax = plt.subplots(figsize=(8, 6)) + + ax.scatter(theoretical_cdf, empirical_cdf, alpha=0.7) + ax.plot([0, 1], [0, 1], 'r--') + ax.set_title('PP-plot of Norm Distributions') + ax.set_xlabel('Theoretical CDF (Base Distribution)') + ax.set_ylabel('Empirical CDF (Data Latents)') + ax.grid(True) + + return ax + + def binned_uniformity_test(self, n_bins=10): + """ + Binned uniformity test for latent directions. + Computes chi-squared statistic for binned directional data. 
+
+        Returns:
+            chi2_stat: Chi-squared statistic
+            p_value: Associated p-value
+        """
+        # Normalize to unit sphere
+        directions = self.latents / torch.norm(self.latents, p=self.p, dim=1, keepdim=True)
+        directions = directions.cpu().numpy()
+
+        # Create bins in each dimension
+        bin_edges = np.linspace(-1, 1, n_bins + 1)
+        bin_indices = np.zeros(len(directions), dtype=int)
+
+        # Multi-dimensional binning
+        for dim in range(self.dim):
+            bin_indices_dim = np.clip(np.digitize(directions[:, dim], bin_edges) - 1, 0, n_bins - 1)  # a coordinate of exactly 1.0 belongs to the last bin, not to index n_bins
+            bin_indices += bin_indices_dim * (n_bins ** dim)
+
+        # Count bins
+        unique_bins, counts = np.unique(bin_indices, return_counts=True)
+        n_observed = len(unique_bins)
+
+        # Expected counts (uniform distribution)
+        total_bins = n_bins ** self.dim
+        expected = len(directions) / total_bins
+
+        # Chi-squared test
+        chi2_stat = np.sum((counts - expected) ** 2 / expected) + (total_bins - n_observed) * expected  # every empty bin contributes (0 - E)^2 / E = E
+        p_value = chi2.sf(chi2_stat, df=total_bins - 1)  # Pearson df counts all cells, not only occupied ones; sf avoids 1 - cdf cancellation
+
+        return chi2_stat, p_value
+
+    def hs_independence_test(self, n_permutations=1000):
+        """
+        Hilbert-Schmidt Independence Criterion for:
+        H0: Norm and direction are independent
+
+        Returns:
+            hsic_value: HSIC statistic
+            p_value: Estimated p-value via permutation test
+        """
+        # Compute norms and directions
+        norms = torch.norm(self.latents, p=self.p, dim=1).unsqueeze(1)
+        directions = self.latents / norms
+
+        # Center and scale
+        norms = (norms - norms.mean()) / norms.std()
+        directions = (directions - directions.mean(dim=0)) / directions.std(dim=0)
+
+        # Compute kernels
+        K_n = self._rbf_kernel(norms)
+        K_d = self._rbf_kernel(directions)
+
+        # Center kernels
+        n = len(norms)
+        H = torch.eye(n, dtype=K_n.dtype, device=K_n.device) - torch.ones_like(K_n) / n  # build H with the kernels' dtype/device; a CPU float32 H breaks the matmul for CUDA or float64 latents
+        K_n = H @ K_n @ H
+        K_d = H @ K_d @ H
+
+        # Compute HSIC
+        hsic_value = torch.trace(K_n @ K_d) / (n * n)
+
+        # Permutation test for p-value
+        permuted_values = []
+        for _ in range(n_permutations):
+            perm_idx = torch.randperm(n)
+            K_d_perm = K_d[perm_idx][:, perm_idx]
+            permuted_values.append(torch.trace(K_n @ K_d_perm).item())
+
+
permuted_values = np.array(permuted_values) / (n * n) + p_value = (permuted_values >= hsic_value.item()).mean() + + return hsic_value.item(), p_value + + def _rbf_kernel(self, X, sigma=None): + """Compute RBF kernel matrix""" + n = X.shape[0] + X_norm = torch.sum(X**2, dim=1).reshape(-1, 1) + pairwise_dist = X_norm + X_norm.T - 2 * torch.mm(X, X.T) + + if sigma is None: + sigma = torch.median(pairwise_dist[pairwise_dist > 0]).sqrt() + + gamma = 1.0 / (2 * sigma**2) + K = torch.exp(-gamma * pairwise_dist) + return K + + def test_uniformity_simplex(self, alpha=0.05, method='energy', n_samples_ref=1000, n_boot=1000): + """ + Test uniformity of normalized absolute latents on the simplex. + + Args: + alpha: Significance level + method: 'energy' for energy distance test, 'bhattacharyya' for transformed residuals test + n_samples_ref: Number of reference samples for energy distance + n_boot: Number of bootstrap samples for p-value calculation + + Returns: + p_value: Computed p-value for uniformity test + reject: Boolean indicating rejection of uniformity + """ + if self.p != 1: + raise ValueError("Uniformity test requires L1 norm (p=1), current p={}".format(self.p)) + + # Compute absolute values and normalize to simplex + abs_latents = torch.abs(self.latents) + row_sums = abs_latents.sum(dim=1, keepdim=True) + valid_rows = (row_sums > 1e-8).squeeze() + + if valid_rows.sum() < 10: # Ensure sufficient valid samples + raise ValueError("Insufficient non-zero latent vectors for uniformity test") + + u = abs_latents[valid_rows] / row_sums[valid_rows] + u_np = u.cpu().numpy() + + if method == 'energy': + return self._energy_uniformity_test(u_np, alpha, n_samples_ref, n_boot) + elif method == 'bhattacharyya': + return self._bhattacharyya_uniformity_test(u_np, alpha) + else: + raise ValueError("Unknown method: {}".format(method)) + + def _energy_uniformity_test(self, u, alpha, n_samples_ref, n_boot): + """Energy distance test for uniformity on simplex""" + d = u.shape[1] + n = 
u.shape[0]
+
+        # Generate reference uniform sample
+        ref = self._simulate_uniform_simplex(n_samples_ref, d)
+
+        # Compute observed energy distance
+        stat_obs = self._energy_distance(u, ref)
+
+        # Bootstrap distribution under null
+        stat_boot = []
+        for _ in range(n_boot):
+            boot_sample = self._simulate_uniform_simplex(n, d)
+            stat_boot.append(self._energy_distance(boot_sample, ref))
+
+        # Calculate p-value
+        p_value = np.mean(np.array(stat_boot) >= stat_obs)
+        reject = p_value < alpha
+        return p_value, reject
+
+    def _bhattacharyya_uniformity_test(self, u, alpha):
+        """Bhattacharyya transformation test for uniformity"""
+        # Transform to negative logs
+        y = -np.log(u)
+
+        # Compute residuals (centered logs)
+        residuals = y - y.mean(axis=1, keepdims=True)
+
+        # Flatten residuals and test against standard Gumbel
+        flat_residuals = residuals.flatten()
+        ks_stat, p_value = stats.kstest(flat_residuals, 'gumbel_r')
+        reject = p_value < alpha
+        return p_value, reject
+
+    def _simulate_uniform_simplex(self, n, d):
+        """Generate uniform samples on simplex using exponential distribution"""
+        exp_samples = np.random.exponential(scale=1.0, size=(n, d))
+        row_sums = exp_samples.sum(axis=1, keepdims=True)
+        return exp_samples / row_sums
+
+    def _energy_distance(self, X, Y):
+        """Compute energy distance between samples X and Y"""
+        n = X.shape[0]
+        m = Y.shape[0]
+
+        # Compute squared pairwise distances; clamp at 0 because rounding can leave tiny negatives (e.g. on the d_xx diagonal) that sqrt() would turn into NaN
+        xx = np.sum(X**2, axis=1)
+        yy = np.sum(Y**2, axis=1)
+        xy = np.dot(X, Y.T)
+
+        d_xx = np.maximum(xx[:, None] + xx[None, :] - 2 * np.dot(X, X.T), 0.0)
+        d_yy = np.maximum(yy[:, None] + yy[None, :] - 2 * np.dot(Y, Y.T), 0.0)
+        d_xy = np.maximum(xx[:, None] + yy[None, :] - 2 * xy, 0.0)
+
+        term1 = np.sum(np.sqrt(d_xy)) / (n * m)
+        term2 = np.sum(np.sqrt(d_xx)) / (n * n)
+        term3 = np.sum(np.sqrt(d_yy)) / (m * m)
+
+        return 2 * term1 - term2 - term3
+
+    def test_sign_symmetry(self, alpha=0.05, method='sign', combine='stouffer'):
+        """
+        Test sign symmetry with options for high-dimensional aggregation.
+
+        Args:
+            alpha: Significance level
+            method: 'sign' or 'wilcoxon'
+            combine: 'fisher', 'stouffer', or None for Bonferroni
+
+        Returns:
+            result: Dictionary containing p-values and rejection decision
+        """
+        if self.p != 1:
+            raise ValueError("Sign symmetry test requires L1 norm (p=1), current p={}".format(self.p))
+
+        p_values = []
+        z_scores = [] # For Stouffer's method
+
+        # Compute p-values for each dimension
+        for j in range(self.latents.shape[1]):
+            coord = self.latents[:, j].cpu().numpy()
+
+            if method == 'sign':
+                n_pos = (coord > 0).sum()
+                test_result = binomtest(n_pos, len(coord), p=0.5, alternative='two-sided')
+                p_val = test_result.pvalue
+                z_scores.append((n_pos - len(coord)/2) / np.sqrt(len(coord)/4))
+
+            elif method == 'wilcoxon':
+                _, p_val = wilcoxon(coord, zero_method='wilcox', alternative='two-sided')
+                # For Fisher only (Stouffer not recommended with Wilcoxon in high-d)
+                z_scores.append(norm.isf(p_val/2) * np.sign(np.median(coord)))  # isf(p/2) == ppf(1 - p/2) without rounding 1 - p/2 up to 1 for tiny p
+
+            p_values.append(p_val)
+
+        # Handle different combination methods
+        combined_p = None
+        if combine == 'fisher':
+            chi2_stat = -2 * np.sum(np.log(p_values))
+            df = 2 * len(p_values)
+            combined_p = chi2.sf(chi2_stat, df)  # survival function keeps precision where 1 - cdf underflows to 0
+            reject = combined_p < alpha
+
+        elif combine == 'stouffer':
+            if method != 'sign':
+                raise ValueError("Stouffer method requires sign test")
+            z_combined = np.sum(z_scores) / np.sqrt(len(z_scores))
+            combined_p = 2 * norm.sf(np.abs(z_combined)) # Two-sided
+            reject = combined_p < alpha
+
+        else: # Bonferroni
+            per_test_alpha = alpha / len(p_values)  # correct for the number of tests actually run (one per latent coordinate)
+            reject = any(p < per_test_alpha for p in p_values)
+
+        return {
+            'p_values': p_values,
+            'reject': reject,
+            'combined_p': combined_p,
+            'method': f"{method} with {combine}" if combine else f"{method} with Bonferroni"
+        }
+
+    def test_l1_radial_symmetry(self, alpha=0.05, sign_method='wilcoxon',
+                                sign_combine='fisher', uniform_method='energy'):
+        """
+        Combined test with improved high-dimensional handling.
+ + Args: + alpha: Overall significance level + sign_method: 'sign' or 'wilcoxon' + sign_combine: 'fisher', 'stouffer', or None + uniform_method: 'energy' or 'bhattacharyya' + + Returns: + result: Dictionary with test outcomes + """ + if self.p != 1: + raise ValueError("L1-radial test requires p=1, current p={}".format(self.p)) + + # Test sign symmetry with alpha/2 + sign_result = self.test_sign_symmetry( + alpha=alpha/2, + method=sign_method, + combine=sign_combine + ) + + # Test uniformity with alpha/2 + uniformity_pval, uniformity_reject = self.test_uniformity_simplex( + alpha=alpha/2, method=uniform_method + ) + + # Combine results + l1_radial_rejected = sign_result['reject'] or uniformity_reject + + return { + 'sign_pvals': sign_result['p_values'], + 'sign_reject': sign_result['reject'], + 'sign_combined_p': sign_result['combined_p'], + 'sign_method': sign_result['method'], + 'uniformity_pval': uniformity_pval, + 'uniformity_reject': uniformity_reject, + 'l1_radial_rejected': l1_radial_rejected + } + +def pp_plot_multiple_norms(evaluators, labels, colors=None, n_samples=10000): + """ + Plot multiple PP-curves on the same axis. + + Args: + evaluators: List of RadialFlowEvaluator instances. + labels: List of labels for each model. + colors: Optional list of colors. + n_samples: Number of samples for theoretical CDF. 
+ """ + fig, ax = plt.subplots(figsize=(8, 6)) + + for i, evaluator in enumerate(evaluators): + label = labels[i] + color = None if colors is None else colors[i % len(colors)] + _pp_plot_single(evaluator, ax, n_samples, label, color) + + ax.plot([0, 1], [0, 1], 'k--', label="y = x") + ax.set_title('PP-plot of Norm Distributions') + ax.set_xlabel('Theoretical CDF (Base Distribution)') + ax.set_ylabel('Empirical CDF (Data Latents)') + ax.grid(True) + ax.legend() + fig.tight_layout() + fig.savefig("pp_plot_combined-mixure.png", dpi=300) + fig.savefig("pp_plot_combined-mixure.pgf", dpi=300) + plt.close(fig) + + +def _pp_plot_single(evaluator, ax, n_samples, label=None, color=None): + """ + Plot a single evaluator on a shared axis. + """ + latent_norms = torch.norm(evaluator.latents, p=evaluator.p, dim=1).cpu().numpy() + n = len(latent_norms) + empirical_cdf = np.arange(1, n + 1) / n + sorted_norms = np.sort(latent_norms) + + base_norm_dist = evaluator.flow.base_distribution.norm_distribution + if hasattr(base_norm_dist, 'cdf'): + theoretical_cdf = base_norm_dist.cdf( + torch.tensor(sorted_norms).to(evaluator.device) + ).detach().cpu().numpy() + else: + sample_norms = base_norm_dist.sample((n_samples,)).detach().cpu().numpy() + sample_sorted = np.sort(sample_norms) + theoretical_cdf = np.searchsorted(sample_sorted, sorted_norms) / n_samples + + ax.plot(theoretical_cdf, empirical_cdf, label=label, color=color, alpha=0.8) + +if __name__ == '__main__': + + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + + plt.style.use('ggplot') + base_dir = "/home/faried/Projects/USFlows/reports/mnist_ablation_best_veriflow" + subfolders = sorted(os.listdir(base_dir)) + colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] # Add more if needed + + evaluators = [] + labels = [] + + for subfolder in 
subfolders:
+        model_dir = os.path.join(base_dir, subfolder)
+        if not os.path.isdir(model_dir):
+            continue
+        i = int(subfolder[-1])
+
+        # Locate model files
+        pkl_files = sorted([f for f in os.listdir(model_dir) if f.endswith(".pkl")])
+        pt_files = sorted([f for f in os.listdir(model_dir) if f.endswith(".pt")])
+
+        if not pkl_files or not pt_files:
+            print(f"Skipping {model_dir} (missing files)")
+            continue
+
+        pkl_path = os.path.join(model_dir, pkl_files[-1])
+        pt_path = os.path.join(model_dir, pt_files[-1])
+        model = from_checkpoint(pkl_path, pt_path)
+
+        # Load test set
+        from src.usflows.explib.datasets import MnistDequantized
+        mnisti = MnistDequantized(dataloc="/home/faried/Projects/USFlows/data/mnist", space_to_depth_factor=4, digit=i, train=False)[:1000][0]
+
+        evaluator = RadialFlowEvaluator(model, mnisti)
+        evaluators.append(evaluator)
+
+        # Label derived from folder name
+        labels.append(f"MNIST Digit ${i}$")
+
+    if evaluators:
+        pp_plot_multiple_norms(evaluators, labels, colors=colors)
+        print("Saved combined PP-plot to 'pp_plot_combined-mixure.png'")
+    else:
+        print("No valid models found.")
+
+
+
+
+
+
+
diff --git a/src/usflows/distributions.py b/src/usflows/distributions.py
index b236c06..61cb7c0 100644
--- a/src/usflows/distributions.py
+++ b/src/usflows/distributions.py
@@ -54,20 +54,22 @@ def log_prob(self, x: torch.Tensor) -> torch.Tensor:
 
 class Chi(Distribution):
     arg_constraints = {"df": constraints.positive}
-    support = constraints.positive
+    support = constraints.positive # This will be updated to include scale
     has_enumerate_support = False
 
-    def __init__(self, df: int, validate_args=None):
+    def __init__(self, df: int, scale: float = 1.0, validate_args=None):
         """
         Initialize the Chi distribution with degrees of freedom `df`.
         Args:
             df (Tensor): degrees of freedom.
+            scale (float): scale parameter.
             validate_args (bool, optional): Whether to validate input parameters. Default: None.
""" self.chi2 = Chi2(df) self.df = df + self.scale = scale super(Chi, self).__init__( - self.chi2._batch_shape, self.chi2._event_shape, validate_args=validate_args + self.chi2._batch_shape, self.chi2._event_shape, validate_args=validate_args # This will be updated to include scale ) def sample(self, sample_shape=torch.Size()): @@ -78,18 +80,18 @@ def sample(self, sample_shape=torch.Size()): Returns: Tensor: A sample of the specified shape. """ - return torch.sqrt(self.chi2.sample(sample_shape)) + return self.scale * torch.sqrt(self.chi2.sample(sample_shape)) def log_prob(self, value): """ Calculate the log probability of a given value. Args: value (Tensor): The value at which to evaluate the log probability. - Returns: - Tensor: The log probability of the value. + Returns: Tensor: The log probability of the value. """ + value = value / self.scale y = value**2 - return self.chi2.log_prob(y) + torch.log(value * 2) + return self.chi2.log_prob(y) + torch.log(value * 2) - torch.log(torch.tensor(self.scale)) def cdf(self, value): """ @@ -98,7 +100,8 @@ def cdf(self, value): value (Tensor): The value at which to evaluate the CDF. Returns: Tensor: The CDF of the value. - """ + """ + value = value / self.scale y = value**2 return self.chi2.cdf(y) @@ -108,7 +111,7 @@ def entropy(self): Returns: Tensor: The entropy of the distribution. 
""" - return self.chi2.entropy() / 2 + torch.log(torch.tensor(2)) + return self.chi2.entropy() / 2 + torch.log(torch.tensor(2)) + torch.log(torch.tensor(self.scale)) class DistributionModule(Module): @@ -224,9 +227,14 @@ def __init__( self.to(device) def _get_distribution_params(self) -> Dict[str, torch.Tensor]: + if self.scale_unconstrained.dim() == 0: + # If scale is a scalar, we need to expand it to match loc's shape + scale = softplus(self.scale_unconstrained).expand_as(self.loc) + else: + scale = softplus(self.scale_unconstrained) return { "loc": self.loc, - "scale": softplus(self.scale_unconstrained) + "scale": scale } class Categorical(DistributionModule): @@ -242,26 +250,6 @@ def __init__( def _get_distribution_params(self) -> Dict[str, torch.Tensor]: return {"logits": self.logits} -class GMM(DistributionModule): - def __init__( - self, - loc: torch.Tensor, - scale: torch.Tensor, - mixture_weights: torch.Tensor, - device: str = "cpu", - ): - super().__init__(torch.distributions.MixtureSameFamily) - self.normal_batch = Normal(loc, scale, n_batch_dims=1) - self.mixture_distribution = Categorical(mixture_weights) - self.to(device) - - def _get_distribution_params(self) -> Dict[str, Distribution]: - return { - "mixture_distribution": self.mixture_distribution.distribution, - "component_distribution": self.normal_batch.distribution - } - - class UniformUnitLpBall(torch.distributions.Distribution): """Implements a uniform distribution on the unit ball.""" @@ -517,7 +505,7 @@ def log_prob(self, x: torch.Tensor) -> torch.Tensor: event_dims = tuple(range(x.dim() - len(self.event_shape), x.dim())) r = x.norm(p=self.p, dim=event_dims) - log_prob_norm = self.norm_distribution.log_prob(r.unsqueeze(-1)) + log_prob_norm = self.norm_distribution.log_prob(r.unsqueeze(-1)).squeeze(-1) log_dV = self.log_delta_volume(self.p, r) return log_prob_norm - log_dV @@ -785,15 +773,17 @@ def _get_distribution_params(self): params = self._get_constrained_params() # Permute component 
dimension to last - permute_order = list(range(1, params[0].dim())) + [0] - params_perm = [p.permute(permute_order) for p in params] - + #params_perm = [] + #for i, name in enumerate(self.param_names): + # permute_order = list(range(1, params[i].dim())) + [0] + # params_perm.append(params[i].permute(permute_order)) + # Create component distribution - comp_dist = self.component_distribution_class(**dict(zip(self.param_names, params_perm))) + comp_dist = self.component_distribution_class(**dict(zip(self.param_names, params))) # Permute mixture logits - logits_perm = self.mixture_logits.permute(permute_order) - mix_dist = dist.Categorical(logits=logits_perm) + #logits_perm = self.mixture_logits.permute(permute_order) + mix_dist = dist.Categorical(logits=self.mixture_logits) return { "mixture_distribution": mix_dist, @@ -804,6 +794,30 @@ def _get_distribution_params(self): def distribution(self) -> Distribution: return super().distribution + +class GMM(MixtureModel): + def __init__( + self, + loc: torch.Tensor, + covariance_matrix: torch.Tensor, + mixture_weights: torch.Tensor, + device: str = "cpu", + ): + param_constraints = { + "loc": constraints.real, + "covariance_matrix": constraints.positive_definite + } + + super().__init__( + dist.MultivariateNormal, + ["loc", "covariance_matrix"], + param_constraints, + loc, + covariance_matrix, + mixture_weights=mixture_weights, + device=device + ) + class LogNormalMM(MixtureModel): """Mixture of Log-Normal distributions.""" def __init__(self, loc, scale, mixture_weights, device="cpu"): @@ -834,4 +848,3 @@ def __init__(self, scale, concentration, mixture_weights, device="cpu"): mixture_weights=mixture_weights, device=device ) - diff --git a/src/usflows/explib/config_parser.py b/src/usflows/explib/config_parser.py index a5fa6b2..68132e3 100644 --- a/src/usflows/explib/config_parser.py +++ b/src/usflows/explib/config_parser.py @@ -1,7 +1,7 @@ from copy import deepcopy from importlib import import_module from pathlib import 
Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, Union from pickle import load import yaml @@ -9,6 +9,7 @@ # Convenience import for direct access in config files via "__eval__" from ray import tune import torch +import pyro def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]) -> Dict[str, Any]: """Updates the dictionary d with the dictionary u. diff --git a/src/usflows/explib/datasets.py b/src/usflows/explib/datasets.py index f6d93b7..3da1440 100644 --- a/src/usflows/explib/datasets.py +++ b/src/usflows/explib/datasets.py @@ -205,7 +205,7 @@ def __init__( """ super().__init__(*args, **kwargs) if isinstance(generator, str): - generator = GENERATORS[generator] + generator = GENERATORS[generator] self.dataset = generator(**params)[0] @@ -521,3 +521,79 @@ def get_test(self) -> torch.utils.data.Dataset: def get_val(self) -> torch.utils.data.Dataset: return self.val + +class DistributionDataset(torch.utils.data.Dataset): + """ + Dataset that generates samples from a given distribution. + """ + + def __init__( + self, + distribution: torch.distributions.Distribution, + num_samples: int, + device: torch.device = None, + ): + """ + Initialize dataset with a distribution and number of samples. + + Args: + distribution: Distribution to sample from + num_samples: Number of samples to generate + device: Device to store samples on + """ + super().__init__() + self.distribution = distribution + self.num_samples = num_samples + self.device = device + self.data = self.distribution.sample((num_samples,)).to(device) + + # Dummy labels for compatibility + self.labels = torch.zeros(num_samples, dtype=torch.long, device=device) + + def __getitem__(self, index: int): + return self.data[index], self.labels[index] + + def __len__(self): + return self.num_samples + + def to(self, device: torch.device): + """ + Move dataset to a different device. 
+ + Args: + device: Device to move samples to + """ + self.data = self.data.to(device) + self.labels = self.labels.to(device) + self.device = device + return self + + + +class DistributionSplit(SimpleSplit): + """ + Data split that generates train/val/test from a distribution. + """ + + def __init__( + self, + distribution: torch.distributions.Distribution, + num_train: int, + num_val: int, + num_test: int, + device: torch.device = None, + ): + """ + Create train/val/test splits from a distribution. + + Args: + distribution: Distribution to sample from + num_train: Number of training samples + num_val: Number of validation samples + num_test: Number of test samples + device: Device to store samples on + """ + train = DistributionDataset(distribution, num_train, device) + val = DistributionDataset(distribution, num_val, device) + test = DistributionDataset(distribution, num_test, device) + super().__init__(train, test, val) \ No newline at end of file diff --git a/src/usflows/explib/eval.py b/src/usflows/explib/eval.py index e0f9e80..03af51f 100644 --- a/src/usflows/explib/eval.py +++ b/src/usflows/explib/eval.py @@ -1,3 +1,4 @@ +from typing import Optional import numpy as np import torch import matplotlib.pyplot as plt @@ -10,8 +11,10 @@ from scipy.stats import binomtest, wilcoxon from sklearn.neighbors import KernelDensity +from src.usflows.distributions import RadialDistribution + class RadialFlowEvaluator: - def __init__(self, flow, data, device='cpu'): + def __init__(self, flow, data, device='cpu', p: Optional[float] = None, norm_distribution: Optional[torch.distributions.Distribution] = None, loc: Optional[torch.Tensor] = None): """ Evaluator for USFlow models with RadialDistribution base distribution. 
@@ -26,13 +29,33 @@ def __init__(self, flow, data, device='cpu'): self.dim = torch.prod(torch.tensor(self.data.shape[1:])).item() + if isinstance(flow.base_distribution, RadialDistribution): + # Get p-norm from base distribution + self.p = flow.base_distribution.p + self.norm_distribution = flow.base_distribution.norm_distribution + else: + if p is None: + raise ValueError("p-norm must be specified for non-RadialDistribution base distributions") + if not isinstance(p, (int, float)): + raise TypeError("p must be an integer or float") + if p <= 0: + raise ValueError("p must be a positive number") + self.p = p + self.norm_distribution = norm_distribution + + if hasattr(flow.base_distribution, 'loc'): + self.loc = flow.base_distribution.loc.to(device) + else: + if loc is None: + raise ValueError("loc must be specified for non-RadialDistribution base distributions") + self.loc = loc.to(device) + # Precompute latent representations with torch.no_grad(): - self.latents = self.flow.backward(self.data) - flow.base_distribution.loc.to(device) + self.latents = self.flow.backward(self.data) - self.loc self.latents = self.latents.view(self.latents.shape[0], -1) - # Get p-norm from base distribution - self.p = flow.base_distribution.p - + + def wasserstein_norm_distance(self, n_samples=10000): """ Compute Wasserstein distance between: @@ -46,7 +69,7 @@ def wasserstein_norm_distance(self, n_samples=10000): latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() # Sample from base norm distribution - base_norm_dist = self.flow.base_distribution.norm_distribution + base_norm_dist = self.norm_distribution sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() # Compute Wasserstein distance @@ -61,7 +84,7 @@ def ks_norm_statistic(self, n_samples=10000): p_value: Associated p-value """ latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() - base_norm_dist = self.flow.base_distribution.norm_distribution + base_norm_dist = self.norm_distribution 
sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() return ks_2samp(latent_norms, sample_norms) @@ -72,8 +95,20 @@ def qq_plot_norms(self, ax=None, n_samples=10000): 1. Quantiles of empirical latent norms 2. Quantiles of base norm distribution samples """ + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + plt.style.use('ggplot') + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() - base_norm_dist = self.flow.base_distribution.norm_distribution + base_norm_dist = self.norm_distribution sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() latent_quantiles = np.quantile(latent_norms, np.linspace(0, 1, 100)) @@ -97,9 +132,22 @@ def kde_plot_norms(self, ax=None, n_samples=10000): 1. Empirical latent norms distribution 2. Base norm distribution """ - latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() - base_norm_dist = self.flow.base_distribution.norm_distribution - sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + plt.style.use('ggplot') + + with torch.no_grad(): + latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() + base_norm_dist = self.norm_distribution + sample_norms = base_norm_dist.sample((n_samples,)).cpu().numpy() # Create KDE models kde_latent = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(latent_norms.reshape(-1, 1)) @@ -139,6 +187,18 @@ def pp_plot_norms(self, ax=None, n_samples=10000): Returns: ax: Matplotlib axis """ + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 
16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + plt.style.use('ggplot') + # Get empirical latent norms latent_norms = torch.norm(self.latents, p=self.p, dim=1).cpu().numpy() @@ -148,10 +208,10 @@ def pp_plot_norms(self, ax=None, n_samples=10000): sorted_norms = np.sort(latent_norms) # Get theoretical CDF (if available) - base_norm_dist = self.flow.base_distribution.norm_distribution - if hasattr(base_norm_dist.distribution, 'cdf'): + base_norm_dist = self.norm_distribution + if hasattr(base_norm_dist, 'cdf'): # Use analytical CDF if available - theoretical_cdf = base_norm_dist.distribution.cdf( + theoretical_cdf = base_norm_dist.cdf( torch.tensor(sorted_norms).to(self.device) ).detach().cpu().numpy() else: @@ -462,4 +522,84 @@ def test_l1_radial_symmetry(self, alpha=0.05, sign_method='wilcoxon', 'uniformity_pval': uniformity_pval, 'uniformity_reject': uniformity_reject, 'l1_radial_rejected': l1_radial_rejected - } \ No newline at end of file + } + + def nll_norm_scatter_plot(self, ref_distribution, ax=None, n_samples=10000): + """ + Scatter plot of log-probabilities of latent norms vs base distribution. 
+ + Args: + ref_distribution: Reference distribution for nll computation + ax: Matplotlib axis (optional) + n_samples: Number of samples for base distribution + """ + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + plt.style.use('ggplot') + + if ax is None: + fig, ax = plt.subplots() + + # Sample from the reference distribution + base_samples = ref_distribution.sample((n_samples,)).to(self.device) + base_samples = base_samples.view(base_samples.shape[0], -1) + + # Compute log-probabilities + with torch.no_grad(): + nlls = -ref_distribution.log_prob(base_samples).cpu().numpy() + latent_norms = (self.flow.backward(base_samples) - self.loc).norm(p=self.p, dim=1).cpu().numpy() + + # Scatter plot + ax.scatter(nlls, latent_norms, alpha=0.5) + ax.set_xlabel("Negative Log-Likelihood") + ax.set_ylabel("Latent Norm") + ax.set_title("Negative Log-Likelihood vs Latent Norm") + + return ax + + def logprob_reference_scatter_plot(self, ref_distribution, ax=None, n_samples=10000): + """ + Scatter estimated log-probs against reference distribution log-probs. 
+ Args: + ref_distribution: Reference distribution for log-prob computation + ax: Matplotlib axis (optional) + n_samples: Number of samples for base distribution + """ + plt.rcParams.update({ + "pgf.texsystem": "pdflatex", + "text.usetex": False, + "pgf.rcfonts": False, + "font.size": 14, + "axes.labelsize": 16, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 12 + }) + plt.style.use('ggplot') + + if ax is None: + fig, ax = plt.subplots() + + # Sample from the reference distribution + base_samples = ref_distribution.sample((n_samples,)).to(self.device) + + # Compute log-probabilities + with torch.no_grad(): + ref_log_probs = ref_distribution.log_prob(base_samples).cpu().numpy() + learned_log_probs = self.flow.log_prob(base_samples).cpu().numpy() + + # Scatter plot + ax.scatter(ref_log_probs, learned_log_probs, alpha=0.5) + ax.set_xlabel("Reference Log-Probability") + ax.set_ylabel("Estimated Log-Probability") + ax.set_title("Log-Probability Comparison") + + return ax \ No newline at end of file