diff --git a/runtime/cudaq/utils/extension_point.h b/runtime/cudaq/utils/extension_point.h new file mode 100644 index 00000000000..a419c88c08e --- /dev/null +++ b/runtime/cudaq/utils/extension_point.h @@ -0,0 +1,212 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2022-2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +namespace cudaq { + +/// @brief A template class for implementing an extension point mechanism. +/// +/// This class provides a framework for registering and retrieving plugin-like +/// extensions. It allows dynamic creation of objects based on registered types. +/// +/// @tparam T The base type of the extensions. +/// @tparam CtorArgs Variadic template parameters for constructor arguments. +/// +/// How to use the extension_point class +/// +/// The extension_point class provides a mechanism for creating extensible +/// frameworks with plugin-like functionality. Here's how to use it: +/// +/// 1. Define your extension point: +/// Create a new class that inherits from cudaq::extension_point. +/// This class should declare pure virtual methods that extensions will +/// implement. +/// +/// @code +/// class MyExtensionPoint : public cudaq::extension_point { +/// public: +/// virtual std::string parrotBack(const std::string &msg) const = 0; +/// }; +/// @endcode +/// +/// 2. Implement concrete extensions: +/// Create classes that inherit from your extension point and implement its +/// methods. Use the CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION macro to define a +/// creator function. +/// +/// @code +/// class RepeatBackOne : public MyExtensionPoint { +/// public: +/// std::string parrotBack(const std::string &msg) const override { +/// return msg + " from RepeatBackOne."; +/// } +/// +/// CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne) +/// }; +/// @endcode +/// +/// 3. Register your extensions: +/// Use the CUDAQ_REGISTER_EXTENSION macro to register each extension. +/// +/// @code +/// CUDAQ_REGISTER_EXTENSION(RepeatBackOne) +/// @endcode +/// +/// 4. Use your extensions: +/// You can now create instances of your extensions, check registrations, and +/// more. +/// +/// @code +/// auto extension = MyExtensionPoint::get("RepeatBackOne"); +/// std::cout << extension->parrotBack("Hello") << std::endl; +/// +/// auto registeredTypes = MyExtensionPoint::get_registered(); +/// bool isRegistered = MyExtensionPoint::is_registered("RepeatBackOne"); +/// @endcode +/// +/// This approach allows for a flexible, extensible design where new +/// functionality can be added without modifying existing code. +template +class extension_point { + + /// Type alias for the creator function. + using CreatorFunction = std::function(CtorArgs...)>; + +protected: + /// @brief Get the registry of creator functions. + /// @return A reference to the static registry map. + /// See CUDAQ_INSTANTIATE_REGISTRY() macros below for sample implementations + /// that need to be included in C++ source files. + static std::unordered_map &get_registry(); + +public: + /// @brief Create an instance of a registered extension. + /// @param name The identifier of the registered extension. + /// @param args Constructor arguments for the extension. + /// @return A unique pointer to the created instance. + /// @throws std::runtime_error if the extension is not found. + static std::unique_ptr get(const std::string &name, CtorArgs... args) { + auto ®istry = get_registry(); + auto iter = registry.find(name); + if (iter == registry.end()) + throw std::runtime_error("Cannot find extension with name = " + name); + + return iter->second(std::forward(args)...); + } + + /// @brief Get a list of all registered extension names. + /// @return A vector of registered extension names. + static std::vector get_registered() { + std::vector names; + auto ®istry = get_registry(); + for (auto &[k, v] : registry) + names.push_back(k); + return names; + } + + /// @brief Check if an extension is registered. + /// @param name The identifier of the extension to check. + /// @return True if the extension is registered, false otherwise. + static bool is_registered(const std::string &name) { + auto ®istry = get_registry(); + return registry.find(name) != registry.end(); + } + virtual ~extension_point() = default; +}; + +/// @brief Macro for defining a creator function for an extension. +/// @param BASE The base class of the extension. +/// @param TYPE The derived class implementing the extension. +#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(BASE, TYPE) \ + static inline bool register_type() { \ + auto ®istry = get_registry(); \ + registry[TYPE::class_identifier] = TYPE::create; \ + return true; \ + } \ + static const bool registered_; \ + static inline const std::string class_identifier = #TYPE; \ + static std::unique_ptr create() { return std::make_unique(); } + +#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION_WITH_NAME(NAME, BASE, TYPE) \ + static inline bool register_type() { \ + auto ®istry = get_registry(); \ + registry[#NAME] = TYPE::create; \ + return true; \ + } \ + static const bool registered_; \ + static std::unique_ptr create() { return std::make_unique(); } + +/// @brief Macro for defining a custom creator function for an extension. +/// @param TYPE The class implementing the extension. +/// @param ... Custom implementation of the create function. +#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(TYPE, ...) \ + static inline bool register_type() { \ + auto ®istry = get_registry(); \ + registry[TYPE::class_identifier] = TYPE::create; \ + return true; \ + } \ + static const bool registered_; \ + static inline const std::string class_identifier = #TYPE; \ + __VA_ARGS__ + +#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(TYPE, NAME, ...) \ + static inline bool register_type() { \ + auto ®istry = TYPE::get_registry(); \ + registry.insert({NAME, TYPE::create}); \ + return true; \ + } \ + static const bool registered_; \ + static inline const std::string class_identifier = #TYPE; \ + __VA_ARGS__ + +/// @brief Macro for registering an extension type. +/// @param TYPE The class to be registered as an extension. +#define CUDAQ_REGISTER_EXTENSION(TYPE) \ + const bool TYPE::registered_ = TYPE::register_type(); + +/// In order to support building CUDA-Q libraries with g++ and building +/// application code with nvq++ (which uses clang++ under the hood), you must +/// implement the templated get_registry() function for every set of +/// extension_point. This *must* be done in a C++ file that is built +/// with the CUDA-Q libraries. +/// +/// Use this version of the helper macro if the only template argument to +/// extension_point<> is the derived class (with no additional creator args). +#define CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(FULL_TYPE_NAME) \ + template <> \ + std::unordered_map()>> & \ + cudaq::extension_point::get_registry() { \ + static std::unordered_map< \ + std::string, std::function()>> \ + registry; \ + return registry; \ + } + +/// Use this variadic version of the helper macro if there are additional +/// arguments for the creator function. +#define CUDAQ_INSTANTIATE_REGISTRY(FULL_TYPE_NAME, ...) \ + template <> \ + std::unordered_map< \ + std::string, \ + std::function(__VA_ARGS__)>> & \ + cudaq::extension_point::get_registry() { \ + static std::unordered_map< \ + std::string, \ + std::function(__VA_ARGS__)>> \ + registry; \ + return registry; \ + } + +} // namespace cudaq diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 8ff54e165b4..fa60c0821ac 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -493,3 +493,7 @@ if (CUDAQ_ENABLE_PYTHON) gtest_discover_tests(test_domains TEST_SUFFIX _Sampling PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python") endif() + +add_executable(test_extension_point extension/test_extension_point.cpp) +target_link_libraries(test_extension_point PRIVATE GTest::gtest_main cudaq) +gtest_discover_tests(test_extension_point) diff --git a/unittests/extension/test_extension_point.cpp b/unittests/extension/test_extension_point.cpp new file mode 100644 index 00000000000..324d92bea0b --- /dev/null +++ b/unittests/extension/test_extension_point.cpp @@ -0,0 +1,159 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "cudaq/utils/extension_point.h" + +#include + +namespace cudaq::testing { + +// Define a new extension point for the framework +class MyExtensionPoint : public cudaq::extension_point { +public: + virtual std::string parrotBack(const std::string &msg) const = 0; + virtual ~MyExtensionPoint() = default; +}; + +} // namespace cudaq::testing + +CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(cudaq::testing::MyExtensionPoint) + +namespace cudaq::testing { + +// Define a concrete realization of that extension point +class RepeatBackOne : public MyExtensionPoint { +public: + std::string parrotBack(const std::string &msg) const override { + return msg + " from RepeatBackOne."; + } + + // Extension must provide a creator function + CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne) +}; + +// Extensions must register themselves +CUDAQ_REGISTER_EXTENSION(RepeatBackOne) + +class RepeatBackTwo : public MyExtensionPoint { +public: + std::string parrotBack(const std::string &msg) const override { + return msg + " from RepeatBackTwo."; + } + CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackTwo) +}; +CUDAQ_REGISTER_EXTENSION(RepeatBackTwo) + +} // namespace cudaq::testing + +TEST(ExtensionPointTester, checkSimpleExtensionPoint) { + + auto registeredNames = cudaq::testing::MyExtensionPoint::get_registered(); + EXPECT_EQ(registeredNames.size(), 2); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackTwo") != registeredNames.end()); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackOne") != registeredNames.end()); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackThree") == registeredNames.end()); + + { + auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackOne"); + EXPECT_EQ(var->parrotBack("Hello World"), + "Hello World from RepeatBackOne."); + } + { + auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackTwo"); + EXPECT_EQ(var->parrotBack("Hello World"), + "Hello World from RepeatBackTwo."); + } +} + +namespace cudaq::testing { + +class MyExtensionPointWithArgs + : public cudaq::extension_point { +protected: + int i; + double d; + +public: + MyExtensionPointWithArgs(int i, double d) : i(i), d(d) {} + virtual std::tuple parrotBack() const = 0; + virtual ~MyExtensionPointWithArgs() = default; +}; + +} // namespace cudaq::testing + +CUDAQ_INSTANTIATE_REGISTRY(cudaq::testing::MyExtensionPointWithArgs, int, + double) + +namespace cudaq::testing { + +class RepeatBackOneWithArgs : public MyExtensionPointWithArgs { +public: + using MyExtensionPointWithArgs::MyExtensionPointWithArgs; + std::tuple parrotBack() const override { + return std::make_tuple(i, d, "RepeatBackOne"); + } + + CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION( + RepeatBackOneWithArgs, + static std::unique_ptr create(int i, double d) { + return std::make_unique(i, d); + }) +}; + +CUDAQ_REGISTER_EXTENSION(RepeatBackOneWithArgs) + +class RepeatBackTwoWithArgs : public MyExtensionPointWithArgs { +public: + using MyExtensionPointWithArgs::MyExtensionPointWithArgs; + std::tuple parrotBack() const override { + return std::make_tuple(i, d, "RepeatBackTwo"); + } + + CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION( + RepeatBackTwoWithArgs, + static std::unique_ptr create(int i, double d) { + return std::make_unique(i, d); + }) +}; + +CUDAQ_REGISTER_EXTENSION(RepeatBackTwoWithArgs) + +} // namespace cudaq::testing + +TEST(CoreTester, checkSimpleExtensionPointWithArgs) { + + auto registeredNames = + cudaq::testing::MyExtensionPointWithArgs::get_registered(); + EXPECT_EQ(registeredNames.size(), 2); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackTwoWithArgs") != registeredNames.end()); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackOneWithArgs") != registeredNames.end()); + EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(), + "RepeatBackThree") == registeredNames.end()); + + { + auto var = cudaq::testing::MyExtensionPointWithArgs::get( + "RepeatBackOneWithArgs", 5, 2.2); + auto [i, d, msg] = var->parrotBack(); + EXPECT_EQ(msg, "RepeatBackOne"); + EXPECT_EQ(i, 5); + EXPECT_NEAR(d, 2.2, 1e-2); + } + { + auto var = cudaq::testing::MyExtensionPointWithArgs::get( + "RepeatBackTwoWithArgs", 15, 12.2); + auto [i, d, msg] = var->parrotBack(); + EXPECT_EQ(msg, "RepeatBackTwo"); + EXPECT_EQ(i, 15); + EXPECT_NEAR(d, 12.2, 1e-2); + } +}