
[Bug] NNVariationalStrategy returns only diagonal covariance in eval mode #2740

@DiogoRibeiro7

Description


🐛 Bug

Suggested labels: bug, variational

In eval mode, NNVariationalStrategy currently returns only a diagonal predictive covariance, even though an inline comment in the implementation says eval mode should return the full covariance.

To reproduce

** Code snippet to reproduce **

import torch
import gpytorch
from gpytorch.variational import MeanFieldVariationalDistribution
from gpytorch.variational.nearest_neighbor_variational_strategy import NNVariationalStrategy

class Model(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points, k=3):
        # Mean-field variational distribution over the inducing points
        variational_distribution = MeanFieldVariationalDistribution(inducing_points.size(-2))
        # Each point is conditioned on its k nearest inducing points
        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, k=k)
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

inducing_points = torch.randn(16, 2)
model = Model(inducing_points)
model.eval()  # the bug is specific to eval mode

test_x = torch.randn(4, 2)
out = model(test_x)
# Reports a diagonal linear operator rather than a full (dense) covariance
print(type(out.lazy_covariance_matrix))
# Prints torch.Size([4, 4]) either way: densifying a diagonal operator still
# yields an n x n matrix, so the shape alone does not reveal the problem
print(out.covariance_matrix.shape)

** Stack trace/error message **

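No exception is raised. In eval mode, however, out.lazy_covariance_matrix is a diagonal representation (a DiagLinearOperator) rather than a full predictive covariance, so every cross-covariance between distinct test points is zero.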
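Because `out.covariance_matrix.shape` is `torch.Size([4, 4])` whether or not the operator is diagonal, a more direct check is to inspect the off-diagonal entries of the dense matrix. The following is a minimal torch-only sketch; the helper name `max_offdiag_abs` is hypothetical and not part of GPyTorch.

```python
import torch


def max_offdiag_abs(cov: torch.Tensor) -> float:
    """Largest absolute off-diagonal entry of a square matrix or a batch of them.

    Returns 0.0 when every matrix is diagonal, i.e. when the test points are
    modelled as uncorrelated. (Hypothetical helper, not a GPyTorch API.)
    """
    if cov.dim() < 2 or cov.shape[-1] != cov.shape[-2]:
        raise ValueError(f"expected square matrices, got shape {tuple(cov.shape)}")
    n = cov.shape[-1]
    # Zero out the main diagonal (the mask broadcasts over any batch dimensions),
    # so only cross-point covariances remain.
    diag_mask = torch.eye(n, dtype=torch.bool, device=cov.device)
    offdiag_only = cov.detach().masked_fill(diag_mask, 0.0)
    return offdiag_only.abs().max().item()


if __name__ == "__main__":
    diagonal = torch.diag(torch.tensor([0.5, 1.0, 2.0]))
    correlated = torch.tensor([[1.0, 0.3], [0.3, 1.0]])
    print(max_offdiag_abs(diagonal))    # 0.0
    print(max_offdiag_abs(correlated))  # approximately 0.3 (float32 rounding)
```

Applied to the reproduction, `max_offdiag_abs(out.covariance_matrix)` would be expected to return `0.0` if the eval-mode output is diagonal-only as reported, and a positive value once a full covariance with correlated test points is returned.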

Expected Behavior

In eval mode, NNVariationalStrategy.forward should return the full predictive covariance, or the implementation/docs should be updated to make the diagonal-only behavior explicit.

There is already an inline TODO in gpytorch/variational/nearest_neighbor_variational_strategy.py:

  • This method needs to return the full covariance in eval mode, not just the predictive variance.
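To illustrate why the missing off-diagonal terms matter: for any linear combination of predictions w^T f (for example, the average of several test outputs), Var(w^T f) = w^T Sigma w, which depends on every cross-covariance in Sigma. The torch-only sketch below uses a scaled RBF kernel matrix on nearby 1-D inputs purely as a hypothetical stand-in for a full predictive covariance (it is not NNVariationalStrategy's actual output) and compares it with the diagonal-only version.

```python
import torch


def rbf_cov(x: torch.Tensor, lengthscale: float = 1.0, outputscale: float = 1.0) -> torch.Tensor:
    """Scaled RBF kernel matrix on 1-D inputs, used here only as a stand-in covariance."""
    sqdist = (x.unsqueeze(-1) - x.unsqueeze(-2)) ** 2
    return outputscale * torch.exp(-0.5 * sqdist / lengthscale**2)


def var_of_linear_combination(cov: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Var(w^T f) = w^T Sigma w for a Gaussian vector f with covariance Sigma."""
    return weights @ cov @ weights


x = torch.tensor([0.0, 0.1, 0.2, 0.3], dtype=torch.float64)  # four nearby test inputs
full = rbf_cov(x)
diag_only = torch.diag_embed(torch.diagonal(full))  # keeps marginal variances, drops correlations
w = torch.full((4,), 0.25, dtype=torch.float64)  # weights that average the four values

var_full = var_of_linear_combination(full, w).item()
var_diag = var_of_linear_combination(diag_only, w).item()
print(var_full)  # about 0.988
print(var_diag)  # 0.25
```

With the diagonal-only covariance, the variance of the average is understated by roughly a factor of four in this example, because the four nearby inputs are almost perfectly correlated under the stand-in kernel.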

System information


  • GPyTorch Version: current main
  • PyTorch Version: current supported versions
  • Computer OS: N/A

Additional context

The relevant code path currently computes predictive_covar, immediately squeezes it to a variance tensor, and returns MultivariateNormal(..., DiagLinearOperator(predictive_var)) in eval mode.

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
