Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
734b992
Add schema parser and related functionality for Torch compatibility
gouzil Feb 16, 2026
9a0654e
support keyword specification
gouzil Feb 16, 2026
26b319d
clean code
gouzil Feb 16, 2026
15b3ffc
fix hash_combine
gouzil Feb 17, 2026
642371a
add license
gouzil Feb 17, 2026
1feb656
torch_library_test move use WITH_GPU
gouzil Feb 17, 2026
2be5bb2
add `Duplicate keyword` test
gouzil Feb 18, 2026
c191ce0
2025 -> 2026
gouzil Feb 21, 2026
d91263a
add license
gouzil Feb 21, 2026
02ef81b
add more test
gouzil Feb 21, 2026
028b22a
add more test
gouzil Feb 22, 2026
18ffdb0
add more test
gouzil Feb 22, 2026
cda5bb5
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Feb 22, 2026
9dc6724
add more test
gouzil Feb 23, 2026
6676b1c
fix kwargs mapping error
gouzil Feb 24, 2026
8cec652
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Feb 25, 2026
eb4d421
support `Type.isSubtypeOf`, `SchemaArgType` and add torch schema
gouzil Mar 2, 2026
b961da6
add license
gouzil Mar 2, 2026
92f5fd2
clean include
gouzil Mar 2, 2026
474e437
fix: add cmakelist
gouzil Mar 2, 2026
bb0be80
ignore test
gouzil Mar 2, 2026
fafb22c
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Mar 16, 2026
7c9cb61
fix review
gouzil Mar 16, 2026
f2b07b7
fix the AliasInfo hash function to support unordered containers and a…
gouzil Mar 22, 2026
ff036b2
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Mar 22, 2026
ac42022
clean glog
gouzil Mar 24, 2026
d10cf91
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Mar 28, 2026
594330d
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Apr 4, 2026
f620ef7
Modify the header comment format of the file and remove the redundant…
gouzil Apr 4, 2026
f4936a9
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Apr 4, 2026
5cad663
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Apr 5, 2026
6e22c7c
Refactor license comments in compatibility headers and source files a…
gouzil Apr 6, 2026
b5e17a0
Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche…
gouzil Apr 7, 2026
d49a95a
[Cpp API Compatibility] Update FunctionSchema references to c10::Func…
gouzil Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions paddle/fluid/pybind/torch_compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,11 @@ inline FunctionArgs OperationInvoker::convert_args_kwargs_to_function_args(
function_args.add_arg(std::move(value));
}

for (auto item : kwargs) {
py::str key = item.first.cast<py::str>();
py::object value_obj = item.second.cast<py::object>();

torch::IValue value = to_ivalue(value_obj);
function_args.add_arg(std::move(value));
for (const auto& item : kwargs) {
std::string key = py::cast<std::string>(item.first);
torch::arg keyword(std::move(key));
keyword = to_ivalue(item.second);
function_args.add_arg(std::move(keyword));
}

return function_args;
Expand Down Expand Up @@ -248,12 +247,18 @@ class CustomClassProxyInstance {
if (ClassRegistry::instance().has_method(qualified_name_, method_name)) {
return py::cpp_function(
[this, method_name](py::args args, py::kwargs kwargs) -> py::object {
FunctionArgs converted =
OperationInvoker::convert_args_kwargs_to_function_args(args,
kwargs);
FunctionArgs function_args;
function_args.add_arg(instance_); // this pointer
for (auto arg :
OperationInvoker::convert_args_kwargs_to_function_args(
args, kwargs)) {
function_args.add_arg(std::move(arg));
for (size_t i = 0; i < converted.size(); ++i) {
function_args.add_arg(converted.get_value(i));
}
for (const auto& [name, value] : converted.named_args()) {
torch::arg keyword(name);
keyword = value;
function_args.add_arg(std::move(keyword));
}

auto result = ClassRegistry::instance().call_method_with_args(
Expand Down
156 changes: 156 additions & 0 deletions paddle/phi/api/include/compat/ATen/core/alias_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考 torch 的代码记得添加声明

// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

参考 ivalue.h

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考 #78590 改一下声明

另外,paddlecodec 那边目前卡在哪里了?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外,paddlecodec 那边目前卡在哪里了?

c++ 部分的还有一些 torch::tensor 不支持的,python 部分正在

PFCCLab/paddlecodec#3

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要全部恢复,这不是这个 PR 需要考虑的问题,只需要考虑这个 PR 需要考虑的部分即可

这个 PR 合入之后 PFCCLab/paddlecodec#3 就可以合入的话,那我觉得这个 PR 目前已经达到合入标准了

PFCCLab/paddlecodec#3 依赖 #78521 吗?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gouzil 调试呢,通知处理这块可能会有些抖动,一些信息可能重复处理


// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

#pragma once

#include <ostream>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace c10 {
/**
* class AliasInfo
*
* Data structure to hold aliasing information for an `Argument`. They can be
* nested to represent aliasing information on contained types.
*
* There is a `beforeSet` which describes the aliasing information before the
* operator executes, and an `afterSet` that describes aliasing info
* after execution.
*/
class AliasInfo {
 public:
  AliasInfo() = default;
  // Builds the before/after alias sets directly from the given ordered sets.
  AliasInfo(bool is_write,
            const std::set<std::string>& before_qual_strings,
            const std::set<std::string>& after_qual_strings)
      : beforeSets_(before_qual_strings.begin(), before_qual_strings.end()),
        afterSets_(after_qual_strings.begin(), after_qual_strings.end()),
        isWrite_(is_write) {}

  // Whether the operator may write to the annotated value.
  bool isWrite() const { return isWrite_; }

  // Alias sets the value may belong to before the operator executes.
  const std::unordered_set<std::string>& beforeSets() const {
    return beforeSets_;
  }

  // Alias sets the value may belong to after the operator executes.
  const std::unordered_set<std::string>& afterSets() const {
    return afterSets_;
  }

  // Alias info for the contained types of the type. For an annotation on
  // List[T], beforeSets()/afterSets() describe the list itself, while
  // containedTypes()[0] describes the sets the list's members may be in.
  void addContainedType(AliasInfo aliasInfo) {
    containedTypes_.push_back(std::move(aliasInfo));
  }
  const std::vector<AliasInfo>& containedTypes() const {
    return containedTypes_;
  }

 private:
  std::unordered_set<std::string> beforeSets_;
  std::unordered_set<std::string> afterSets_;
  std::vector<AliasInfo> containedTypes_;
  bool isWrite_ = false;
};

// Two AliasInfo values are equal when every observable property matches:
// write flag, before/after alias sets, and contained-type annotations.
inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
  if (lhs.isWrite() != rhs.isWrite()) {
    return false;
  }
  if (lhs.beforeSets() != rhs.beforeSets()) {
    return false;
  }
  if (lhs.afterSets() != rhs.afterSets()) {
    return false;
  }
  return lhs.containedTypes() == rhs.containedTypes();
}

// this does match the way things are represented in the schema
// Streams the schema-style rendering of the alias annotation, e.g. "(a|b!)"
// or "(a -> b)" when the before and after sets differ.
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
  const auto emitJoined = [&out](const std::unordered_set<std::string>& sets) {
    bool isFirst = true;
    for (const auto& name : sets) {
      if (!isFirst) {
        out << '|';
      }
      isFirst = false;
      out << name;
    }
  };

  out << '(';
  emitJoined(aliasInfo.beforeSets());
  if (aliasInfo.isWrite()) {
    out << '!';
  }
  // Only show the "-> afterSets" part when execution changes the sets.
  if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
    out << " -> ";
    emitJoined(aliasInfo.afterSets());
  }
  out << ')';
  return out;
}
} // namespace c10

// Folds `rhs` into the running hash `lhs` using the boost::hash_combine
// recipe (golden-ratio constant plus shifts to spread the bits).
// NOTE: the result is order dependent — hash_combine(a, b) generally
// differs from hash_combine(b, a) — so it must not be used to combine
// element hashes of unordered containers.
// constexpr/noexcept so it can be used in compile-time hashing, and
// [[nodiscard]] because discarding the combined hash is always a bug.
[[nodiscard]] inline constexpr std::size_t hash_combine(
    std::size_t lhs, std::size_t rhs) noexcept {
  lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
  return lhs;
}

namespace std {
template <>
struct hash<c10::AliasInfo> {
  size_t operator()(const c10::AliasInfo& aliasInfo) const {
    // XOR is commutative, so folding an unordered_set with it yields the
    // same value regardless of iteration order; the order-dependent
    // hash_combine would not be safe here.
    const auto xorFoldStrings =
        [](const std::unordered_set<std::string>& strings) {
          size_t seed = 0;
          for (const auto& s : strings) {
            seed ^= std::hash<std::string>()(s);
          }
          return seed;
        };

    size_t hashValue = std::hash<bool>()(aliasInfo.isWrite());
    hashValue =
        hash_combine(hashValue, xorFoldStrings(aliasInfo.beforeSets()));
    hashValue = hash_combine(hashValue, xorFoldStrings(aliasInfo.afterSets()));
    // containedTypes() is an ordered vector, so hash_combine is fine here.
    for (const auto& contained : aliasInfo.containedTypes()) {
      hashValue =
          hash_combine(hashValue, std::hash<c10::AliasInfo>()(contained));
    }
    return hashValue;
  }
};
}  // namespace std
201 changes: 201 additions & 0 deletions paddle/phi/api/include/compat/ATen/core/function_schema.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有这个

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有一些其他的文件

// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

#include "ATen/core/function_schema.h"

namespace c10 {

namespace {

// Alias set name that may overlap with any other alias set.
constexpr char kWildcardAliasSet[] = "*";

// Human-readable name for a SchemaArgType, used in error messages.
const char* schemaArgTypeName(SchemaArgType type) {
  switch (type) {
    case SchemaArgType::input:
      return "input";
    case SchemaArgType::output:
      return "output";
    default:
      return "unknown";
  }
}

// True when the two alias sets can refer to overlapping memory: either
// they share a set name, or either side contains the wildcard set.
bool aliasSetsMayOverlap(const std::unordered_set<std::string>& lhs,
                         const std::unordered_set<std::string>& rhs) {
  if (lhs.empty() || rhs.empty()) {
    return false;
  }
  if (lhs.count(kWildcardAliasSet) > 0 || rhs.count(kWildcardAliasSet) > 0) {
    return true;
  }
  for (const auto& aliasSet : lhs) {
    if (rhs.count(aliasSet) > 0) {
      return true;
    }
  }
  return false;
}

// Resolves `argument` against `schema`, throwing when its index is out of
// range for the selected (input or output) argument list.
const Argument& getSchemaArgumentOrThrow(const FunctionSchema& schema,
                                         const SchemaArgument& argument) {
  const auto& argumentList = schema.getCorrectList(argument);
  TORCH_CHECK(argument.index < argumentList.size(),
              "Schema ",
              schemaArgTypeName(argument.type),
              " index ",
              argument.index,
              " is out of bounds for size ",
              argumentList.size());
  return argumentList.at(argument.index);
}

// Recursive overlap check that also descends into contained types (e.g. the
// element annotation of List[T]); when `bidirectional` is set, rhs's
// contained types are checked against lhs as well.
bool aliasInfoMayContainAlias(const AliasInfo& lhs,
                              const AliasInfo& rhs,
                              bool bidirectional) {
  if (aliasSetsMayOverlap(lhs.afterSets(), rhs.afterSets())) {
    return true;
  }
  for (const auto& containedLhs : lhs.containedTypes()) {
    if (aliasInfoMayContainAlias(containedLhs, rhs, /*bidirectional=*/true)) {
      return true;
    }
  }
  if (bidirectional) {
    for (const auto& containedRhs : rhs.containedTypes()) {
      if (aliasInfoMayContainAlias(lhs, containedRhs, /*bidirectional=*/true)) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

// Renders one argument in textual-schema form: "<type> <name>" plus, when a
// default exists, " = <default>".
std::ostream& operator<<(std::ostream& out, const Argument& arg) {
  out << arg.type()->str() << " " << arg.name();
  if (arg.default_value()) {
    // NOTE(review): this streams the object returned by default_value()
    // directly (not a dereferenced contained value) — assumes that wrapper
    // type has an operator<<; confirm against the Argument API.
    out << " = " << arg.default_value();
  }
  return out;
}

// Renders the whole schema as "(args...) -> ret" or "(args...) -> (rets...)";
// a single return is printed bare, any other count is parenthesized.
std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
  const auto& args = schema.arguments();
  out << "(";
  for (size_t i = 0; i < args.size(); ++i) {
    if (i > 0) {
      out << ", ";
    }
    out << args[i];
  }
  if (schema.is_vararg()) {
    // "..." needs a separator only when at least one argument preceded it.
    if (!args.empty()) {
      out << ", ";
    }
    out << "...";
  }
  out << ")";

  out << " -> ";

  const auto& rets = schema.returns();
  if (rets.size() == 1) {
    out << rets[0];
  } else {
    out << "(";
    for (size_t i = 0; i < rets.size(); ++i) {
      if (i > 0) {
        out << ", ";
      }
      out << rets[i];
    }
    out << ")";
  }

  return out;
}

std::optional<int> FunctionSchema::argumentIndexWithName(
const std::string& name) const {
for (size_t i = 0; i < arguments_.size(); ++i) {
if (arguments_[i].name() == name) {
return static_cast<int>(i);
}
}
return std::nullopt;
}

// Selects the argument list that `argument` indexes into: arguments() for
// inputs, returns() for outputs. Asserts on any other SchemaArgType.
//
// Fix: the previous if/if/assert shape let control fall off the end of a
// value-returning function whenever TORCH_INTERNAL_ASSERT is not marked
// noreturn (UB, and a -Wreturn-type warning). Asserting up front and then
// selecting guarantees every path returns.
const std::vector<Argument>& FunctionSchema::getCorrectList(
    const SchemaArgument& argument) const {
  TORCH_INTERNAL_ASSERT(argument.type == SchemaArgType::input ||
                            argument.type == SchemaArgType::output,
                        "Could not match argument type");
  return argument.type == SchemaArgType::input ? arguments() : returns();
}

// True when the referenced argument carries any alias annotation at all.
bool FunctionSchema::is_aliasing(const SchemaArgument& argument) const {
  return getSchemaArgumentOrThrow(*this, argument).alias_info() != nullptr;
}

bool FunctionSchema::is_mutable(const SchemaArgument& argument) const {
const auto& arg = getSchemaArgumentOrThrow(*this, argument);
return arg.alias_info() != nullptr && arg.alias_info()->isWrite();
}

bool FunctionSchema::is_mutable(const std::string& name) const {
const auto index = argumentIndexWithName(name);
TORCH_CHECK(
index.has_value(), "Tried to test mutability of nonexistent name ", name);
return is_mutable({SchemaArgType::input, static_cast<size_t>(*index)});
}

bool FunctionSchema::may_alias(const SchemaArgument& lhs,
const SchemaArgument& rhs) const {
const auto& lhs_arg = getSchemaArgumentOrThrow(*this, lhs);
const auto& rhs_arg = getSchemaArgumentOrThrow(*this, rhs);
const auto* lhs_alias = lhs_arg.alias_info();
const auto* rhs_alias = rhs_arg.alias_info();
if (lhs_alias == nullptr || rhs_alias == nullptr) {
return false;
}
return aliasSetsMayOverlap(lhs_alias->afterSets(), rhs_alias->afterSets());
}

bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs,
const SchemaArgument& rhs,
bool bidirectional) const {
const auto& lhs_arg = getSchemaArgumentOrThrow(*this, lhs);
const auto& rhs_arg = getSchemaArgumentOrThrow(*this, rhs);
const auto* lhs_alias = lhs_arg.alias_info();
const auto* rhs_alias = rhs_arg.alias_info();
if (lhs_alias == nullptr || rhs_alias == nullptr) {
return false;
}
return aliasInfoMayContainAlias(*lhs_alias, *rhs_alias, bidirectional);
}

} // namespace c10
Loading
Loading