Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions runtime/cudaq/utils/extension_point.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/****************************************************************-*- C++ -*-****
* Copyright (c) 2022-2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#pragma once

#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>

namespace cudaq {

/// @brief A template class for implementing an extension point mechanism.
///
/// This class provides a framework for registering and retrieving plugin-like
/// extensions. It allows dynamic creation of objects based on registered types.
///
/// @tparam T The base type of the extensions.
/// @tparam CtorArgs Variadic template parameters for constructor arguments.
///
/// How to use the extension_point class
///
/// The extension_point class provides a mechanism for creating extensible
/// frameworks with plugin-like functionality. Here's how to use it:
///
/// 1. Define your extension point:
/// Create a new class that inherits from cudaq::extension_point<YourClass>.
/// This class should declare pure virtual methods that extensions will
/// implement.
///
/// @code
/// class MyExtensionPoint : public cudaq::extension_point<MyExtensionPoint> {
/// public:
/// virtual std::string parrotBack(const std::string &msg) const = 0;
/// };
/// @endcode
///
/// 2. Implement concrete extensions:
/// Create classes that inherit from your extension point and implement its
/// methods. Use the CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION macro to define a
/// creator function.
///
/// @code
/// class RepeatBackOne : public MyExtensionPoint {
/// public:
/// std::string parrotBack(const std::string &msg) const override {
/// return msg + " from RepeatBackOne.";
/// }
///
/// CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne)
/// };
/// @endcode
///
/// 3. Register your extensions:
/// Use the CUDAQ_REGISTER_EXTENSION macro to register each extension.
///
/// @code
/// CUDAQ_REGISTER_EXTENSION(RepeatBackOne)
/// @endcode
///
/// 4. Use your extensions:
/// You can now create instances of your extensions, check registrations, and
/// more.
///
/// @code
/// auto extension = MyExtensionPoint::get("RepeatBackOne");
/// std::cout << extension->parrotBack("Hello") << std::endl;
///
/// auto registeredTypes = MyExtensionPoint::get_registered();
/// bool isRegistered = MyExtensionPoint::is_registered("RepeatBackOne");
/// @endcode
///
/// This approach allows for a flexible, extensible design where new
/// functionality can be added without modifying existing code.
template <typename T, typename... CtorArgs>
class extension_point {

/// Type alias for the creator function.
using CreatorFunction = std::function<std::unique_ptr<T>(CtorArgs...)>;

protected:
/// @brief Get the registry of creator functions.
/// @return A reference to the static registry map.
/// See CUDAQ_INSTANTIATE_REGISTRY() macros below for sample implementations
/// that need to be included in C++ source files.
static std::unordered_map<std::string, CreatorFunction> &get_registry();

public:
/// @brief Create an instance of a registered extension.
/// @param name The identifier of the registered extension.
/// @param args Constructor arguments for the extension.
/// @return A unique pointer to the created instance.
/// @throws std::runtime_error if the extension is not found.
static std::unique_ptr<T> get(const std::string &name, CtorArgs... args) {
auto &registry = get_registry();
auto iter = registry.find(name);
if (iter == registry.end())
throw std::runtime_error("Cannot find extension with name = " + name);

return iter->second(std::forward<CtorArgs>(args)...);
}

/// @brief Get a list of all registered extension names.
/// @return A vector of registered extension names.
static std::vector<std::string> get_registered() {
std::vector<std::string> names;
auto &registry = get_registry();
for (auto &[k, v] : registry)
names.push_back(k);
return names;
}

/// @brief Check if an extension is registered.
/// @param name The identifier of the extension to check.
/// @return True if the extension is registered, false otherwise.
static bool is_registered(const std::string &name) {
auto &registry = get_registry();
return registry.find(name) != registry.end();
}
virtual ~extension_point() = default;
};

/// @brief Macro for defining a creator function for an extension.
/// @param BASE The base class of the extension.
/// @param TYPE The derived class implementing the extension.
#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(BASE, TYPE) \
static inline bool register_type() { \
auto &registry = get_registry(); \
registry[TYPE::class_identifier] = TYPE::create; \
return true; \
} \
static const bool registered_; \
static inline const std::string class_identifier = #TYPE; \
static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); }

#define CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION_WITH_NAME(NAME, BASE, TYPE) \
static inline bool register_type() { \
auto &registry = get_registry(); \
registry[#NAME] = TYPE::create; \
return true; \
} \
static const bool registered_; \
static std::unique_ptr<BASE> create() { return std::make_unique<TYPE>(); }

/// @brief Macro for defining a custom creator function for an extension.
/// @param TYPE The class implementing the extension.
/// @param ... Custom implementation of the create function.
#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(TYPE, ...) \
static inline bool register_type() { \
auto &registry = get_registry(); \
registry[TYPE::class_identifier] = TYPE::create; \
return true; \
} \
static const bool registered_; \
static inline const std::string class_identifier = #TYPE; \
__VA_ARGS__

#define CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION_WITH_NAME(TYPE, NAME, ...) \
static inline bool register_type() { \
auto &registry = TYPE::get_registry(); \
registry.insert({NAME, TYPE::create}); \
return true; \
} \
static const bool registered_; \
static inline const std::string class_identifier = #TYPE; \
__VA_ARGS__

/// @brief Macro for registering an extension type.
/// @param TYPE The class to be registered as an extension.
#define CUDAQ_REGISTER_EXTENSION(TYPE) \
const bool TYPE::registered_ = TYPE::register_type();

/// In order to support building CUDA-Q libraries with g++ and building
/// application code with nvq++ (which uses clang++ under the hood), you must
/// implement the templated get_registry() function for every set of
/// extension_point<Args..>. This *must* be done in a C++ file that is built
/// with the CUDA-Q libraries.
///
/// Use this version of the helper macro if the only template argument to
/// extension_point<> is the derived class (with no additional creator args).
#define CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(FULL_TYPE_NAME) \
template <> \
std::unordered_map<std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>()>> & \
cudaq::extension_point<FULL_TYPE_NAME>::get_registry() { \
static std::unordered_map< \
std::string, std::function<std::unique_ptr<FULL_TYPE_NAME>()>> \
registry; \
return registry; \
}

/// Use this variadic version of the helper macro if there are additional
/// arguments for the creator function.
#define CUDAQ_INSTANTIATE_REGISTRY(FULL_TYPE_NAME, ...) \
template <> \
std::unordered_map< \
std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> & \
cudaq::extension_point<FULL_TYPE_NAME, __VA_ARGS__>::get_registry() { \
static std::unordered_map< \
std::string, \
std::function<std::unique_ptr<FULL_TYPE_NAME>(__VA_ARGS__)>> \
registry; \
return registry; \
}

} // namespace cudaq
4 changes: 4 additions & 0 deletions unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,7 @@ if (CUDAQ_ENABLE_PYTHON)
gtest_discover_tests(test_domains
TEST_SUFFIX _Sampling PROPERTIES ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python")
endif()

add_executable(test_extension_point extension/test_extension_point.cpp)
target_link_libraries(test_extension_point PRIVATE GTest::gtest_main cudaq)
gtest_discover_tests(test_extension_point)
159 changes: 159 additions & 0 deletions unittests/extension/test_extension_point.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/****************************************************************-*- C++ -*-****
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "cudaq/utils/extension_point.h"

#include <gtest/gtest.h>

namespace cudaq::testing {

// Define a new extension point for the framework
class MyExtensionPoint : public cudaq::extension_point<MyExtensionPoint> {
public:
virtual std::string parrotBack(const std::string &msg) const = 0;
virtual ~MyExtensionPoint() = default;
};

} // namespace cudaq::testing

CUDAQ_INSTANTIATE_REGISTRY_NO_ARGS(cudaq::testing::MyExtensionPoint)

namespace cudaq::testing {

// Define a concrete realization of that extension point
class RepeatBackOne : public MyExtensionPoint {
public:
std::string parrotBack(const std::string &msg) const override {
return msg + " from RepeatBackOne.";
}

// Extension must provide a creator function
CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackOne)
};

// Extensions must register themselves
CUDAQ_REGISTER_EXTENSION(RepeatBackOne)

class RepeatBackTwo : public MyExtensionPoint {
public:
std::string parrotBack(const std::string &msg) const override {
return msg + " from RepeatBackTwo.";
}
CUDAQ_ADD_EXTENSION_CREATOR_FUNCTION(MyExtensionPoint, RepeatBackTwo)
};
CUDAQ_REGISTER_EXTENSION(RepeatBackTwo)

} // namespace cudaq::testing

TEST(ExtensionPointTester, checkSimpleExtensionPoint) {

auto registeredNames = cudaq::testing::MyExtensionPoint::get_registered();
EXPECT_EQ(registeredNames.size(), 2);
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackTwo") != registeredNames.end());
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackOne") != registeredNames.end());
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackThree") == registeredNames.end());

{
auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackOne");
EXPECT_EQ(var->parrotBack("Hello World"),
"Hello World from RepeatBackOne.");
}
{
auto var = cudaq::testing::MyExtensionPoint::get("RepeatBackTwo");
EXPECT_EQ(var->parrotBack("Hello World"),
"Hello World from RepeatBackTwo.");
}
}

namespace cudaq::testing {

class MyExtensionPointWithArgs
: public cudaq::extension_point<MyExtensionPointWithArgs, int, double> {
protected:
int i;
double d;

public:
MyExtensionPointWithArgs(int i, double d) : i(i), d(d) {}
virtual std::tuple<int, double, std::string> parrotBack() const = 0;
virtual ~MyExtensionPointWithArgs() = default;
};

} // namespace cudaq::testing

CUDAQ_INSTANTIATE_REGISTRY(cudaq::testing::MyExtensionPointWithArgs, int,
double)

namespace cudaq::testing {

class RepeatBackOneWithArgs : public MyExtensionPointWithArgs {
public:
using MyExtensionPointWithArgs::MyExtensionPointWithArgs;
std::tuple<int, double, std::string> parrotBack() const override {
return std::make_tuple(i, d, "RepeatBackOne");
}

CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(
RepeatBackOneWithArgs,
static std::unique_ptr<MyExtensionPointWithArgs> create(int i, double d) {
return std::make_unique<RepeatBackOneWithArgs>(i, d);
})
};

CUDAQ_REGISTER_EXTENSION(RepeatBackOneWithArgs)

class RepeatBackTwoWithArgs : public MyExtensionPointWithArgs {
public:
using MyExtensionPointWithArgs::MyExtensionPointWithArgs;
std::tuple<int, double, std::string> parrotBack() const override {
return std::make_tuple(i, d, "RepeatBackTwo");
}

CUDAQ_ADD_EXTENSION_CUSTOM_CREATOR_FUNCTION(
RepeatBackTwoWithArgs,
static std::unique_ptr<MyExtensionPointWithArgs> create(int i, double d) {
return std::make_unique<RepeatBackTwoWithArgs>(i, d);
})
};

CUDAQ_REGISTER_EXTENSION(RepeatBackTwoWithArgs)

} // namespace cudaq::testing

TEST(CoreTester, checkSimpleExtensionPointWithArgs) {

auto registeredNames =
cudaq::testing::MyExtensionPointWithArgs::get_registered();
EXPECT_EQ(registeredNames.size(), 2);
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackTwoWithArgs") != registeredNames.end());
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackOneWithArgs") != registeredNames.end());
EXPECT_TRUE(std::find(registeredNames.begin(), registeredNames.end(),
"RepeatBackThree") == registeredNames.end());

{
auto var = cudaq::testing::MyExtensionPointWithArgs::get(
"RepeatBackOneWithArgs", 5, 2.2);
auto [i, d, msg] = var->parrotBack();
EXPECT_EQ(msg, "RepeatBackOne");
EXPECT_EQ(i, 5);
EXPECT_NEAR(d, 2.2, 1e-2);
}
{
auto var = cudaq::testing::MyExtensionPointWithArgs::get(
"RepeatBackTwoWithArgs", 15, 12.2);
auto [i, d, msg] = var->parrotBack();
EXPECT_EQ(msg, "RepeatBackTwo");
EXPECT_EQ(i, 15);
EXPECT_NEAR(d, 12.2, 1e-2);
}
}
Loading