diff --git a/NeoML/src/Dnn/Layers/ActivationLayers.cpp b/NeoML/src/Dnn/Layers/ActivationLayers.cpp index b097f5939..ff8c0fe28 100644 --- a/NeoML/src/Dnn/Layers/ActivationLayers.cpp +++ b/NeoML/src/Dnn/Layers/ActivationLayers.cpp @@ -230,28 +230,32 @@ CActivationDesc CLinearLayer::GetDesc() const void CLinearLayer::RunOnce() { - const int dataSize = outputBlobs[0]->GetDataSize(); - - if( inputBlobs[0]->GetDataType() == CT_Float ) { - linearRunOnce( inputBlobs[0]->GetData(), multiplier, freeTerm, dataSize, outputBlobs[0]->GetData() ); - } else { - linearRunOnce( inputBlobs[0]->GetData(), static_cast( multiplier ), - static_cast( freeTerm ), dataSize, outputBlobs[0]->GetData() ); + for( int i = 0; i < inputBlobs.Size(); ++i ) { + const int dataSize = outputBlobs[i]->GetDataSize(); + + if( inputBlobs[i]->GetDataType() == CT_Float ) { + linearRunOnce( inputBlobs[i]->GetData(), multiplier, freeTerm, dataSize, outputBlobs[i]->GetData() ); + } else { + linearRunOnce( inputBlobs[i]->GetData(), static_cast( multiplier ), + static_cast< int >( freeTerm ), dataSize, outputBlobs[i]->GetData() ); + } } } void CLinearLayer::BackwardOnce() { - CConstFloatHandle outputDiffPtr = outputDiffBlobs[0]->GetData(); - CFloatHandle inputDiffPtr = inputDiffBlobs[0]->GetData(); - int dataSize = outputDiffBlobs[0]->GetDataSize(); - - if( multiplier != 1.f ) { - CFloatHandleStackVar multiplierValue( MathEngine() ); - multiplierValue.SetValue( multiplier ); - MathEngine().VectorMultiply( outputDiffPtr, inputDiffPtr, dataSize, multiplierValue ); - } else if( outputDiffPtr != inputDiffPtr ) { - MathEngine().VectorCopy( inputDiffPtr, outputDiffPtr, dataSize ); + for( int i = 0; i < outputBlobs.Size(); ++i ) { + CConstFloatHandle outputDiffPtr = outputDiffBlobs[i]->GetData(); + CFloatHandle inputDiffPtr = inputDiffBlobs[i]->GetData(); + int dataSize = outputDiffBlobs[i]->GetDataSize(); + + if( multiplier != 1.f ) { + CFloatHandleStackVar multiplierValue( MathEngine() ); + multiplierValue.SetValue( multiplier ); + MathEngine().VectorMultiply( outputDiffPtr, inputDiffPtr, dataSize, multiplierValue ); + } else if( outputDiffPtr != inputDiffPtr ) { + MathEngine().VectorCopy( inputDiffPtr, outputDiffPtr, dataSize ); + } } }