diff --git a/platforms/hip/include/HipKernels.h b/platforms/hip/include/HipKernels.h index 0c317b9..c1adf84 100644 --- a/platforms/hip/include/HipKernels.h +++ b/platforms/hip/include/HipKernels.h @@ -360,6 +360,18 @@ class HipCalcCustomCVForceKernel : public CommonCalcCustomCVForceKernel { } }; +/** + * This kernel is invoked by ATMForce to calculate the forces acting on the system and the energy of the system. + */ +class HipCalcATMForceKernel : public CommonCalcATMForceKernel { +public: + HipCalcATMForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CommonCalcATMForceKernel(name, platform, cc) { + } + ComputeContext& getInnerComputeContext(ContextImpl& innerContext) { + return *reinterpret_cast(innerContext.getPlatformData())->contexts[0]; + } +}; + } // namespace OpenMM #endif /*OPENMM_HIPKERNELS_H_*/ diff --git a/platforms/hip/src/HipKernelFactory.cpp b/platforms/hip/src/HipKernelFactory.cpp index 9ef58a8..cc71294 100644 --- a/platforms/hip/src/HipKernelFactory.cpp +++ b/platforms/hip/src/HipKernelFactory.cpp @@ -141,5 +141,7 @@ KernelImpl* HipKernelFactory::createKernelImpl(std::string name, const Platform& return new CommonApplyMonteCarloBarostatKernel(name, platform, cu); if (name == RemoveCMMotionKernel::Name()) return new CommonRemoveCMMotionKernel(name, platform, cu); + if (name == CalcATMForceKernel::Name() ) + return new HipCalcATMForceKernel(name, platform, cu); throw OpenMMException((std::string("Tried to create kernel with illegal kernel name '")+name+"'").c_str()); } diff --git a/platforms/hip/src/HipPlatform.cpp b/platforms/hip/src/HipPlatform.cpp index 4ac15a9..c72264f 100644 --- a/platforms/hip/src/HipPlatform.cpp +++ b/platforms/hip/src/HipPlatform.cpp @@ -108,6 +108,7 @@ HipPlatform::HipPlatform() { registerKernelFactory(ApplyAndersenThermostatKernel::Name(), factory); registerKernelFactory(ApplyMonteCarloBarostatKernel::Name(), factory); registerKernelFactory(RemoveCMMotionKernel::Name(), factory); + registerKernelFactory(CalcATMForceKernel::Name(), factory); platformProperties.push_back(HipDeviceIndex()); platformProperties.push_back(HipDeviceName()); platformProperties.push_back(HipUseBlockingSync());