[Bug] Multitask FixedNoise likelihood incorrectly computes data posterior #2687

@awang124

🐛 Bug

I'm performing multitask sparse variational GP regression with a fixed noise tensor C of shape (N, T, T). To make this work, I've subclassed FixedGaussianNoise to return a block-diagonal noise covariance, and overridden the _shaped_noise_covar method of _MultitaskGaussianLikelihoodBase to bypass the learned-noise operations. At prediction time, likelihood(model(X)).covariance_matrix - model(X).covariance_matrix should equal torch.block_diag(*C). This almost holds, except that in some blocks the off-diagonal (cross-covariance) entry is 1e-4 instead of the corresponding entry of C.
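To make the expected identity concrete, here is a minimal NumPy sketch of the property described above. It is illustrative only: the `block_diag` helper, the small sizes, and the diagonal stand-in for the latent covariance are assumptions for the example and do not involve GPyTorch.

```python
import numpy as np

def block_diag(blocks):
    """Place (N, T, T) blocks along the diagonal of an (N*T, N*T) matrix."""
    n, t, _ = blocks.shape
    out = np.zeros((n * t, n * t), dtype=blocks.dtype)
    for i, block in enumerate(blocks):
        out[i * t:(i + 1) * t, i * t:(i + 1) * t] = block
    return out

rng = np.random.default_rng(0)
N, T = 6, 2  # hypothetical small sizes; the report uses N = 10000, T = 2

# Symmetric positive semi-definite noise blocks, built like C in the report.
A = rng.standard_normal((N, T, T))
C = A @ np.transpose(A, (0, 2, 1))

# Stand-in latent covariance with zero cross-task terms (diagonal for brevity).
latent_cov = np.diag(rng.uniform(0.5, 1.5, size=N * T))

# What the likelihood is expected to produce: latent covariance plus noise.
data_cov = latent_cov + block_diag(C)

# The difference should recover every noise block, off-diagonals included.
diff = data_cov - latent_cov
for n in range(N):
    rows = slice(n * T, (n + 1) * T)  # point n occupies rows [n*T, (n+1)*T)
    assert np.allclose(diff[rows, rows], C[n])
```

With T = 2, point n occupies rows and columns 2n to 2n+1, so the slices [8:10] and [10:12] used in the reproduction correspond to C[4] and C[5].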

To reproduce

** Code snippet to reproduce **

import numpy as np
import torch
import gpytorch
import linear_operator

# INITIALIZE C
np.random.seed(0)
torch.manual_seed(0)
C = np.random.randn(10000, 2, 2)
C = C @ np.transpose(C, (0, 2, 1))
C = torch.tensor(C, dtype=torch.float32)

# DEFINE MULTITASK FIXED-NOISE LIKELIHOOD
class FixedNoise(gpytorch.likelihoods.noise_models.FixedGaussianNoise):
    def forward(self, *params, shape=None, noise=None, **kwargs):
        op = linear_operator.operators.DenseLinearOperator(self.noise)
        return linear_operator.operators.BlockDiagLinearOperator(op)

# Monkey-patch: skip the learned-noise path and return the fixed noise covariance directly
def _shaped_noise_covar(self, shape, add_noise=True, interleaved=True, *params, **kwargs):
    return self.noise_covar(*params, **kwargs)

gpytorch.likelihoods._MultitaskGaussianLikelihoodBase._shaped_noise_covar = _shaped_noise_covar

# INITIALIZE MODEL / LIKELIHOOD
model = MyModel()
noise_model = FixedNoise(C)
likelihood  = gpytorch.likelihoods._MultitaskGaussianLikelihoodBase(
    num_tasks=2, noise_covar=noise_model
)
likelihood.has_task_noise = False

# PREDICT (X is the tensor of test inputs; its definition is not included in this snippet)
model.eval()
likelihood.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    latent_post = model(X)
    data_post = likelihood(latent_post)

# EXAMPLE SUCCESSFUL BLOCK
print(torch.block_diag(*C)[8:10, 8:10])
print((data_post.covariance_matrix - latent_post.covariance_matrix)[8:10, 8:10])

# EXAMPLE FAILED BLOCK
print(torch.block_diag(*C)[10:12, 10:12])
print((data_post.covariance_matrix - latent_post.covariance_matrix)[10:12, 10:12])

Expected Behavior

My model defines the task GPs independently, so the cross-task (off-diagonal) entries of latent_post.covariance_matrix are 0. The off-diagonal entry of each block in likelihood(model(X)).covariance_matrix should therefore equal the off-diagonal entry of the corresponding block in torch.block_diag(*C).
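One possible diagnostic, offered as a hypothesis rather than a confirmed cause: a mismatch of exactly 1e-4 would be consistent with the noise tensor being clamped elementwise to a minimum value somewhere along the way. GPyTorch appears to have a minimum fixed-noise setting (gpytorch.settings.min_fixed_noise) whose float32 default I believe is 1e-4, but whether that clamp applies here is an assumption to verify. The sketch below uses only NumPy, rebuilds C with the same seed, and checks whether the failing block is one whose off-diagonal entry falls below that assumed floor. ASSUMED_FLOOR is a name introduced for illustration.

```python
import numpy as np

# Assumed floor: the value observed in the failing blocks. This is a
# hypothesis about a float32 minimum-noise clamp, not a confirmed cause.
ASSUMED_FLOOR = 1e-4

# Same construction as the reproduction snippet (NumPy part only).
np.random.seed(0)
A = np.random.randn(10000, 2, 2)
C = A @ np.transpose(A, (0, 2, 1))

# Cross-task noise covariance of each point.
off_diag = C[:, 0, 1]

# Points whose off-diagonal entry an elementwise clamp would overwrite.
suspect = np.flatnonzero(off_diag < ASSUMED_FLOOR)

# Point n occupies rows/columns [2n, 2n + 2) of torch.block_diag(*C), so the
# "successful" block [8:10] is point 4 and the "failed" block [10:12] is point 5.
print("point 4 off-diagonal:", off_diag[4])
print("point 5 off-diagonal:", off_diag[5])
print("points below the assumed floor:", suspect.size, "of", len(C))

# What an elementwise clamp at the floor would do to block 5.
clamped_block = np.maximum(C[5], ASSUMED_FLOOR)
print(clamped_block)
```

With this seed, point 4 has a positive off-diagonal entry and point 5 a negative one, which matches the successful and failed blocks in the reproduction. If the full set of mismatched blocks coincides with `suspect`, that would point to clamping of the noise values rather than to the covariance addition itself.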

System information

Please complete the following information:
GPyTorch version: 1.14.2
PyTorch version: 2.9.1+cu128
OS: Fedora 42

Additional context

My full model definition for completeness:

class MyModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_tasks=2, num_inducing_points=100):
        inducing_points = torch.rand(num_tasks, num_inducing_points, 2)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_tasks])
        )
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(
            batch_shape=torch.Size([num_tasks])
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
        )
    def forward(self, X):
        mean = self.mean_module(X)
        covar = self.covar_module(X)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
