Currently, the required-grad check happens in the gradient accumulation function, i.e. after the actual gradient has already been calculated.
Moving the check before the gradient calculation would save some operations. For example, when an input is data that does not require a gradient, the gradient with respect to every data point is still calculated and then thrown away. This can be very expensive, especially for high-dimensional data.
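To illustrate the difference, here is a minimal, hypothetical sketch (not this library's actual code) of the backward pass for an elementwise product `y = x * w`, where `x` is data and `w` is a parameter. `Tensor`, `accumulate`, `mul_backward_late` and `mul_backward_early` are made-up names for illustration; the "late" variant mirrors the current behaviour as described above, and the "early" variant checks `requires_grad` before doing any work.

```python
class Tensor:
    """Hypothetical minimal tensor: a list of floats plus gradient bookkeeping."""

    def __init__(self, data, requires_grad=False):
        self.data = list(data)
        self.requires_grad = requires_grad
        self.grad = None


def accumulate(t, g):
    """Add gradient g into t.grad, skipping tensors that do not require a gradient."""
    if not t.requires_grad:
        return  # in the late variant, g has already been computed for nothing
    if t.grad is None:
        t.grad = list(g)
    else:
        t.grad = [a + b for a, b in zip(t.grad, g)]


def mul_backward_late(x, w, grad_out, counter):
    """Backward of y = x * w; requires_grad is only checked at accumulation time."""
    grad_x = [g * wi for g, wi in zip(grad_out, w.data)]  # dy/dx = w
    grad_w = [g * xi for g, xi in zip(grad_out, x.data)]  # dy/dw = x
    counter["mults"] += len(grad_x) + len(grad_w)
    accumulate(x, grad_x)
    accumulate(w, grad_w)


def mul_backward_early(x, w, grad_out, counter):
    """Same backward, but skip the computation for inputs that need no gradient."""
    if x.requires_grad:
        grad_x = [g * wi for g, wi in zip(grad_out, w.data)]
        counter["mults"] += len(grad_x)
        accumulate(x, grad_x)
    if w.requires_grad:
        grad_w = [g * xi for g, xi in zip(grad_out, x.data)]
        counter["mults"] += len(grad_w)
        accumulate(w, grad_w)


if __name__ == "__main__":
    n = 1000
    for backward in (mul_backward_late, mul_backward_early):
        x = Tensor([1.0] * n)                      # data, no gradient needed
        w = Tensor([0.5] * n, requires_grad=True)  # parameter
        counter = {"mults": 0}
        backward(x, w, [1.0] * n, counter)
        print(backward.__name__, counter["mults"])
    # prints:
    # mul_backward_late 2000
    # mul_backward_early 1000
```

Both variants leave `w.grad` identical and `x.grad` as `None`; the early check only removes the multiplications whose result would be discarded, which is where the savings for large data inputs come from.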