Possible error with parallel computation #11

@lla1dlaw

First, I would like to say bravo. This is one of the cleanest and most feature-complete complex-valued neural network abstractions I have come across, if not the cleanest. I am currently doing research with a different complex-valued neural network library, complexPyTorch (available: https://github.com/wavefrontshaping/complexPyTorch), which has functionality similar to yours. However, I ran into a problem when using torch.nn.DataParallel on my university's HPC cluster to train on multiple GPUs:

The core problem is that torch.nn.DataParallel splits each batch across multiple GPUs, but it does not aggregate the batch normalization running statistics (the saved mean and variance) back from all of the GPUs. Instead, the final model on the main GPU retains only the running statistics computed from its own sub-batch, so they misrepresent the overall dataset and cause poor performance during evaluation.
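The effect can be reproduced on a CPU without any GPUs. This is a minimal sketch under two assumptions: `batch.chunk(2)` stands in for the scatter that `nn.DataParallel` performs along dimension 0, and only the update made by the device[0] replica survives (PyTorch's `DataParallel` documentation states that in-place updates to buffers on the device[0] replica are recorded, while the other replicas are discarded after the forward pass). It uses the standard real-valued `nn.BatchNorm1d` for simplicity rather than a complex layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Simulated batch of 8 samples with 4 features; the second half is shifted
# so the two "devices" see clearly different statistics.
batch = torch.randn(8, 4)
batch[4:] += 5.0

# momentum=None selects a cumulative moving average, so after a single
# update running_mean equals the mean of the batch the layer actually saw.
bn_full = nn.BatchNorm1d(4, momentum=None)
bn_sub = nn.BatchNorm1d(4, momentum=None)
bn_full.train()
bn_sub.train()

bn_full(batch)                # statistics from the whole batch
sub_batches = batch.chunk(2)  # assumption: mimics DataParallel's split over 2 GPUs
bn_sub(sub_batches[0])        # only the device[0] sub-batch updates the stats

print("full-batch running_mean:", bn_full.running_mean)
print("device[0]-only running_mean:", bn_sub.running_mean)
```

The second running mean reflects only the first half of the batch, which is exactly the mismatch that shows up at evaluation time.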

I was able to work around this by implementing my own custom BatchNorm2d module. My custom layer in custom_complex_layers.py (available: https://github.com/lla1dlaw/cvnn/blob/3d6c87c9ae14bdd049f7b26f967c8e23d23268af/torch_net/custom_complex_layers.py) solves the DataParallel problem by avoiding a separate, precompiled backend function and instead performing all of the normalization math with basic tensor operations directly in the forward method. The layer replica on each GPU computes the mean and covariance from only the sub-batch it receives, then immediately uses those local statistics to normalize the activations for that forward pass. This keeps the normalization correct for the data on each device, so the model trains stably. The running statistics I save for evaluation are still only an approximation based on the data seen by the main GPU, but this self-contained approach is what makes the layer compatible with the DataParallel wrapper.
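The approach described above (per-replica statistics computed with plain tensor operations inside `forward()`) can be sketched as follows, assuming complex-valued inputs of shape (N, C, H, W) and the 2x2 covariance whitening used for complex batch normalization in Trabelsi et al., "Deep Complex Networks". This is not the code from custom_complex_layers.py or from complexPyTorch: the class name `NaiveComplexBatchNorm2d`, the `(Wrr, Wii, Wri)` weight layout, the initial values, and the exponential-moving-average update are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn


class NaiveComplexBatchNorm2d(nn.Module):
    """Hypothetical complex BatchNorm2d built only from plain tensor ops.

    Each channel's (real, imag) pair is centered and whitened with the inverse
    square root of its 2x2 covariance matrix, computed from the batch that
    reaches forward() (under nn.DataParallel, the local sub-batch).
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Affine scale stored as (Wrr, Wii, Wri) per channel, plus a complex shift.
        self.weight = nn.Parameter(torch.tensor([2 ** -0.5, 2 ** -0.5, 0.0]).repeat(num_features, 1))
        self.bias = nn.Parameter(torch.zeros(num_features, 2))
        # Running statistics for eval mode: (mean_re, mean_im) and (Vrr, Vii, Vri).
        self.register_buffer("running_mean", torch.zeros(num_features, 2))
        self.register_buffer("running_covar", torch.tensor([1.0, 1.0, 0.0]).repeat(num_features, 1))

    def forward(self, x):
        # x: complex tensor of shape (N, C, H, W); statistics are per channel.
        dims = (0, 2, 3)

        def per_channel(v):
            # Reshape a (C,) tensor so it broadcasts against (N, C, H, W).
            return v.reshape(1, -1, 1, 1)

        re, im = x.real, x.imag
        if self.training:
            mean = torch.stack([re.mean(dims), im.mean(dims)], dim=1)
        else:
            mean = self.running_mean
        re = re - per_channel(mean[:, 0])
        im = im - per_channel(mean[:, 1])

        if self.training:
            covar = torch.stack(
                [(re * re).mean(dims), (im * im).mean(dims), (re * im).mean(dims)], dim=1
            )
            with torch.no_grad():
                # In-place EMA updates: under nn.DataParallel only the device[0]
                # replica's in-place buffer updates are kept.
                self.running_mean.lerp_(mean, self.momentum)
                self.running_covar.lerp_(covar, self.momentum)
        else:
            covar = self.running_covar

        vrr = covar[:, 0] + self.eps
        vii = covar[:, 1] + self.eps
        vri = covar[:, 2]

        # Closed-form inverse square root of V = [[vrr, vri], [vri, vii]]:
        # with s = sqrt(det V) and t = sqrt(trace V + 2s),
        # V^(-1/2) = [[vii + s, -vri], [-vri, vrr + s]] / (s * t).
        s = torch.sqrt(vrr * vii - vri * vri)
        t = torch.sqrt(vrr + vii + 2 * s)
        inv_st = 1.0 / (s * t)
        rrr, rii, rri = (vii + s) * inv_st, (vrr + s) * inv_st, -vri * inv_st
        x_re = per_channel(rrr) * re + per_channel(rri) * im
        x_im = per_channel(rri) * re + per_channel(rii) * im

        # Symmetric 2x2 affine transform followed by the complex shift.
        wrr, wii, wri = self.weight.unbind(1)
        out_re = per_channel(wrr) * x_re + per_channel(wri) * x_im + per_channel(self.bias[:, 0])
        out_im = per_channel(wri) * x_re + per_channel(wii) * x_im + per_channel(self.bias[:, 1])
        return torch.complex(out_re, out_im)
```

The in-place `lerp_` calls are a deliberate choice in this sketch: reassigning `self.running_mean = ...` inside `forward()` would be lost on a DataParallel replica, whereas in-place updates on the device[0] replica persist, which yields running statistics from the main GPU's sub-batch only.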

I am still fairly new to complex-valued neural networks, so there may well be bugs or incorrect implementations in my code, but this solution dramatically increased training speed through parallel hardware acceleration. The issue may also be specific to the complexPyTorch package I am using, but I thought it was worth bringing to your attention nonetheless.

Thank you so much for your efforts and your time!