Bug in QRNN_LatentGates forward? #6

@ifsheldon

Description

I think the line below is buggy: we have a batch dimension n here, but the norm is calculated across all samples at once, giving a single scalar. The norm should instead be calculated for each sample separately, so that we get a norm tensor of shape (n,).

norm = tc.einsum('na,na->', psi, psi.conj())
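
For reference, here is a minimal sketch of what I mean, assuming tc is torch and psi has shape (n, a) with n the batch dimension and a the state dimension (the shapes and variable names below are made up for illustration). Keeping the batch index n in the einsum output gives one squared norm per sample instead of a single scalar:

```python
import torch as tc

n, a = 4, 8                                   # hypothetical batch size and state dimension
psi = tc.randn(n, a, dtype=tc.complex64)      # hypothetical batch of states

# Current behavior: summing over both indices collapses the batch -> scalar
norm_scalar = tc.einsum('na,na->', psi, psi.conj())

# Suggested fix: keep the batch index n -> one squared norm per sample, shape (n,)
norm_per_sample = tc.einsum('na,na->n', psi, psi.conj())

# Each sample can then be normalized independently by broadcasting over a
psi_normalized = psi / tc.sqrt(norm_per_sample.real).unsqueeze(-1)
```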
