【Hackathon 10th Spring No.2】Add schema parser and related functionality for Torch compatibility #77938
Merged: SigureMo merged 34 commits into PaddlePaddle:develop from gouzil:feat/add_schema_parser_torch_compatibility on Apr 8, 2026
Commits
734b992 Add schema parser and related functionality for Torch compatibility (gouzil)
9a0654e support keyword specification (gouzil)
26b319d clean code (gouzil)
15b3ffc fix hash_combine (gouzil)
642371a add license (gouzil)
1feb656 torch_library_test move use WITH_GPU (gouzil)
2be5bb2 add `Duplicate keyword` test (gouzil)
c191ce0 2025 -> 2026 (gouzil)
d91263a add license (gouzil)
02ef81b add more test (gouzil)
028b22a add more test (gouzil)
18ffdb0 add more test (gouzil)
cda5bb5 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
9dc6724 add more test (gouzil)
6676b1c fix kwargs mapping error (gouzil)
8cec652 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
eb4d421 support `Type.isSubtypeOf`、`SchemaArgType` and add torch schema (gouzil)
b961da6 add license (gouzil)
92f5fd2 clean include (gouzil)
474e437 fix: add cmakelist (gouzil)
bb0be80 ignore test (gouzil)
fafb22c Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
7c9cb61 fix review (gouzil)
f2b07b7 fix the AliasInfo hash function to support unordered containers and a… (gouzil)
ff036b2 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
ac42022 clean glog (gouzil)
d10cf91 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
594330d Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
f620ef7 Modify the header comment format of the file and remove the redundant… (gouzil)
f4936a9 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
5cad663 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
6e22c7c Refactor license comments in compatibility headers and source files a… (gouzil)
b5e17a0 Merge branch 'develop' of github.com:gouzil/Paddle into feat/add_sche… (gouzil)
d49a95a [Cpp API Compatibility] Update FunctionSchema references to c10::Func… (gouzil)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| // Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| // The file has been adapted from pytorch project | ||
| // Licensed under BSD-style license - | ||
| // https://github.com/pytorch/pytorch/blob/main/LICENSE | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstddef> | ||
| #include <functional> | ||
| #include <ostream> | ||
| #include <set> | ||
| #include <string> | ||
| #include <unordered_set> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| namespace c10 { | ||
| /** | ||
| * class AliasInfo | ||
| * | ||
| * Data structure to hold aliasing information for an `Argument`. They can be | ||
| * nested to represent aliasing information on contained types. | ||
| * | ||
| * There is a `beforeSet` which describes the aliasing information before the | ||
| * operator executes, and an `afterSet` that describes aliasing info | ||
| * after execution. | ||
| */ | ||
| class AliasInfo { | ||
| public: | ||
| AliasInfo() = default; | ||
| AliasInfo(bool is_write, | ||
| const std::set<std::string>& before_qual_strings, | ||
| const std::set<std::string>& after_qual_strings) | ||
| : isWrite_(is_write) { | ||
| for (const auto& s : before_qual_strings) { | ||
| beforeSets_.insert(s); | ||
| } | ||
| for (const auto& s : after_qual_strings) { | ||
| afterSets_.insert(s); | ||
| } | ||
| } | ||
|
|
||
| bool isWrite() const { return isWrite_; } | ||
|
|
||
| const std::unordered_set<std::string>& beforeSets() const { | ||
| return beforeSets_; | ||
| } | ||
|
|
||
| const std::unordered_set<std::string>& afterSets() const { | ||
| return afterSets_; | ||
| } | ||
|
|
||
| // the alias info for the contained types of the type | ||
| // e.g. if this is an annotation on List[T], `sets` refers to | ||
| // the alias sets that the list may be in | ||
| // while containedTypes()[0] refers to the sets that members of the list | ||
| // may be in | ||
| void addContainedType(AliasInfo aliasInfo) { | ||
| containedTypes_.push_back(std::move(aliasInfo)); | ||
| } | ||
| const std::vector<AliasInfo>& containedTypes() const { | ||
| return containedTypes_; | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_set<std::string> beforeSets_; | ||
| std::unordered_set<std::string> afterSets_; | ||
| std::vector<AliasInfo> containedTypes_; | ||
| bool isWrite_ = false; | ||
| }; | ||
|
|
||
| inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { | ||
| return lhs.isWrite() == rhs.isWrite() && | ||
| lhs.beforeSets() == rhs.beforeSets() && | ||
| lhs.afterSets() == rhs.afterSets() && | ||
| lhs.containedTypes() == rhs.containedTypes(); | ||
| } | ||
|
|
||
| // this does match the way things are represented in the schema | ||
| inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { | ||
| out << '('; | ||
| bool first = true; | ||
| for (const auto& set : aliasInfo.beforeSets()) { | ||
| if (first) { | ||
| first = false; | ||
| } else { | ||
| out << '|'; | ||
| } | ||
| out << set; | ||
| } | ||
| if (aliasInfo.isWrite()) { | ||
| out << '!'; | ||
| } | ||
| if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { | ||
| out << " -> "; | ||
| first = true; | ||
| for (const auto& set : aliasInfo.afterSets()) { | ||
| if (first) { | ||
| first = false; | ||
| } else { | ||
| out << '|'; | ||
| } | ||
| out << set; | ||
| } | ||
| } | ||
| out << ')'; | ||
| return out; | ||
| } | ||
| } // namespace c10 | ||
|
|
||
| inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { | ||
| lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); | ||
| return lhs; | ||
| } | ||
|
|
||
| namespace std { | ||
| template <> | ||
| struct hash<c10::AliasInfo> { | ||
| size_t operator()(const c10::AliasInfo& aliasInfo) const { | ||
| auto hash = std::hash<bool>()(aliasInfo.isWrite()); | ||
|
|
||
| // NOTE: for unordered_set hashes, we couldn't use hash_combine | ||
| // because hash_combine is order dependent. Instead, we choose to | ||
| // use XOR as the combining function as XOR is commutative. | ||
| size_t before_set_hash_seed = 0; | ||
| for (auto& e : aliasInfo.beforeSets()) { | ||
| auto symbol_hash = std::hash<std::string>()(e); | ||
| before_set_hash_seed = before_set_hash_seed ^ symbol_hash; | ||
| } | ||
| size_t after_set_hash_seed = 0; | ||
| for (auto& e : aliasInfo.afterSets()) { | ||
| auto symbol_hash = std::hash<std::string>()(e); | ||
| after_set_hash_seed = after_set_hash_seed ^ symbol_hash; | ||
| } | ||
|
|
||
| hash = hash_combine(hash, before_set_hash_seed); | ||
| hash = hash_combine(hash, after_set_hash_seed); | ||
| for (auto& e : aliasInfo.containedTypes()) { | ||
| auto contained_type_hash = std::hash<c10::AliasInfo>()(e); | ||
| hash = hash_combine(hash, contained_type_hash); | ||
| } | ||
| return hash; | ||
| } | ||
| }; | ||
| } // namespace std | ||
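
The `std::hash<c10::AliasInfo>` specialization above combines two strategies: the elements of each `unordered_set` are folded with XOR, because iteration order over an unordered container is unspecified, and the resulting per-field hashes are then mixed with the order-sensitive `hash_combine`. The block below is a minimal standalone sketch of that idea, not the PR's code. The names `AliasSet`, `combine_ordered`, `hash_unordered`, `hash_alias_sets`, and `check_order_independence` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>

// Hypothetical alias used only in this sketch.
using AliasSet = std::unordered_set<std::string>;

// Order-sensitive combiner using the same formula as the diff's hash_combine.
inline std::size_t combine_ordered(std::size_t lhs, std::size_t rhs) {
  lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
  return lhs;
}

// XOR-fold of the element hashes. XOR is commutative and associative, so the
// result does not depend on the order in which the set is iterated.
inline std::size_t hash_unordered(const AliasSet& set) {
  std::size_t seed = 0;
  for (const auto& element : set) {
    seed ^= std::hash<std::string>()(element);
  }
  return seed;
}

// Hypothetical stand-in for hashing the (isWrite, beforeSets, afterSets)
// triple: fields are mixed in a fixed order, sets are folded order-free.
inline std::size_t hash_alias_sets(bool is_write,
                                   const AliasSet& before,
                                   const AliasSet& after) {
  std::size_t hash = std::hash<bool>()(is_write);
  hash = combine_ordered(hash, hash_unordered(before));
  hash = combine_ordered(hash, hash_unordered(after));
  return hash;
}

// Two sets built with different insertion orders must hash identically.
inline void check_order_independence() {
  AliasSet forward;
  forward.insert("a");
  forward.insert("b");
  AliasSet reversed;
  reversed.insert("b");
  reversed.insert("a");
  assert(hash_unordered(forward) == hash_unordered(reversed));
  assert(hash_alias_sets(true, forward, reversed) ==
         hash_alias_sets(true, reversed, forward));
}
```

A plain XOR fold mixes bits less thoroughly than `hash_combine`, which is presumably an accepted trade-off here in exchange for order independence; the sketch makes no claim about collision quality.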
paddle/phi/api/include/compat/ATen/core/function_schema.cpp
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| // Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
|
Member: This one too.
Member: There are some other files as well. |
||
| // The file has been adapted from pytorch project | ||
| // Licensed under BSD-style license - | ||
| // https://github.com/pytorch/pytorch/blob/main/LICENSE | ||
|
|
||
| #include "ATen/core/function_schema.h" | ||
|
|
||
| namespace c10 { | ||
|
|
||
| namespace { | ||
|
|
||
| constexpr char kWildcardAliasSet[] = "*"; | ||
|
|
||
| const char* schemaArgTypeName(SchemaArgType type) { | ||
| if (type == SchemaArgType::input) { | ||
| return "input"; | ||
| } | ||
| if (type == SchemaArgType::output) { | ||
| return "output"; | ||
| } | ||
| return "unknown"; | ||
| } | ||
|
|
||
| bool aliasSetsMayOverlap(const std::unordered_set<std::string>& lhs, | ||
| const std::unordered_set<std::string>& rhs) { | ||
| if (lhs.empty() || rhs.empty()) { | ||
| return false; | ||
| } | ||
| if (lhs.count(kWildcardAliasSet) > 0 || rhs.count(kWildcardAliasSet) > 0) { | ||
| return true; | ||
| } | ||
| for (const auto& set : lhs) { | ||
| if (rhs.count(set) > 0) { | ||
| return true; | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| const Argument& getSchemaArgumentOrThrow(const FunctionSchema& schema, | ||
| const SchemaArgument& argument) { | ||
| const auto& args = schema.getCorrectList(argument); | ||
| TORCH_CHECK(argument.index < args.size(), | ||
| "Schema ", | ||
| schemaArgTypeName(argument.type), | ||
| " index ", | ||
| argument.index, | ||
| " is out of bounds for size ", | ||
| args.size()); | ||
| return args.at(argument.index); | ||
| } | ||
|
|
||
| bool aliasInfoMayContainAlias(const AliasInfo& lhs, | ||
| const AliasInfo& rhs, | ||
| bool bidirectional) { | ||
| if (aliasSetsMayOverlap(lhs.afterSets(), rhs.afterSets())) { | ||
| return true; | ||
| } | ||
|
|
||
| for (const auto& child : lhs.containedTypes()) { | ||
| if (aliasInfoMayContainAlias(child, rhs, /*bidirectional=*/true)) { | ||
| return true; | ||
| } | ||
| } | ||
|
|
||
| if (!bidirectional) { | ||
| return false; | ||
| } | ||
| for (const auto& child : rhs.containedTypes()) { | ||
| if (aliasInfoMayContainAlias(lhs, child, /*bidirectional=*/true)) { | ||
| return true; | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| std::ostream& operator<<(std::ostream& out, const Argument& arg) { | ||
| out << arg.type()->str() << " " << arg.name(); | ||
| if (arg.default_value()) { | ||
| out << " = " << arg.default_value(); | ||
| } | ||
| return out; | ||
| } | ||
|
|
||
| std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { | ||
| out << "("; | ||
| bool first = true; | ||
| for (const auto& arg : schema.arguments()) { | ||
| if (!first) { | ||
| out << ", "; | ||
| } | ||
| out << arg; | ||
| first = false; | ||
| } | ||
| if (schema.is_vararg()) { | ||
| if (!first) { | ||
| out << ", "; | ||
| } | ||
| out << "..."; | ||
| } | ||
| out << ")"; | ||
|
|
||
| out << " -> "; | ||
|
|
||
| if (schema.returns().size() == 1) { | ||
| out << schema.returns()[0]; | ||
| } else { | ||
| out << "("; | ||
| first = true; | ||
| for (const auto& ret : schema.returns()) { | ||
| if (!first) { | ||
| out << ", "; | ||
| } | ||
| out << ret; | ||
| first = false; | ||
| } | ||
| out << ")"; | ||
| } | ||
|
|
||
| return out; | ||
| } | ||
|
|
||
| std::optional<int> FunctionSchema::argumentIndexWithName( | ||
| const std::string& name) const { | ||
| for (size_t i = 0; i < arguments_.size(); ++i) { | ||
| if (arguments_[i].name() == name) { | ||
| return static_cast<int>(i); | ||
| } | ||
| } | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| const std::vector<Argument>& FunctionSchema::getCorrectList( | ||
| const SchemaArgument& argument) const { | ||
| if (argument.type == SchemaArgType::input) { | ||
| return arguments(); | ||
| } | ||
| if (argument.type == SchemaArgType::output) { | ||
| return returns(); | ||
| } | ||
| TORCH_INTERNAL_ASSERT(false, "Could not match argument type"); | ||
| } | ||
|
|
||
| bool FunctionSchema::is_aliasing(const SchemaArgument& argument) const { | ||
| const auto& arg = getSchemaArgumentOrThrow(*this, argument); | ||
| return arg.alias_info() != nullptr; | ||
| } | ||
|
|
||
| bool FunctionSchema::is_mutable(const SchemaArgument& argument) const { | ||
| const auto& arg = getSchemaArgumentOrThrow(*this, argument); | ||
| return arg.alias_info() != nullptr && arg.alias_info()->isWrite(); | ||
| } | ||
|
|
||
| bool FunctionSchema::is_mutable(const std::string& name) const { | ||
| const auto index = argumentIndexWithName(name); | ||
| TORCH_CHECK( | ||
| index.has_value(), "Tried to test mutability of nonexistent name ", name); | ||
| return is_mutable({SchemaArgType::input, static_cast<size_t>(*index)}); | ||
| } | ||
|
|
||
| bool FunctionSchema::may_alias(const SchemaArgument& lhs, | ||
| const SchemaArgument& rhs) const { | ||
| const auto& lhs_arg = getSchemaArgumentOrThrow(*this, lhs); | ||
| const auto& rhs_arg = getSchemaArgumentOrThrow(*this, rhs); | ||
| const auto* lhs_alias = lhs_arg.alias_info(); | ||
| const auto* rhs_alias = rhs_arg.alias_info(); | ||
| if (lhs_alias == nullptr || rhs_alias == nullptr) { | ||
| return false; | ||
| } | ||
| return aliasSetsMayOverlap(lhs_alias->afterSets(), rhs_alias->afterSets()); | ||
| } | ||
|
|
||
| bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, | ||
| const SchemaArgument& rhs, | ||
| bool bidirectional) const { | ||
| const auto& lhs_arg = getSchemaArgumentOrThrow(*this, lhs); | ||
| const auto& rhs_arg = getSchemaArgumentOrThrow(*this, rhs); | ||
| const auto* lhs_alias = lhs_arg.alias_info(); | ||
| const auto* rhs_alias = rhs_arg.alias_info(); | ||
| if (lhs_alias == nullptr || rhs_alias == nullptr) { | ||
| return false; | ||
| } | ||
| return aliasInfoMayContainAlias(*lhs_alias, *rhs_alias, bidirectional); | ||
| } | ||
|
|
||
| } // namespace c10 | ||
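
The alias queries in `function_schema.cpp` reduce to two rules: two alias sets may overlap if both are non-empty and either one contains the wildcard `*` or they share a name, and a nested annotation may contain an alias if that overlap holds for the annotation itself or for any of its contained types. The sketch below restates those rules on a simplified stand-in type so they can be exercised in isolation. `AliasSet`, `Info`, `sets_may_overlap`, and `may_contain` are hypothetical names for this illustration; the recursion mirrors the structure of `aliasInfoMayContainAlias` in the diff, including its use of `true` for the recursive calls.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical alias used only in this sketch.
using AliasSet = std::unordered_set<std::string>;

// Simplified stand-in for c10::AliasInfo: only the "after" sets and the
// annotations of contained types (for example, the elements of a list).
struct Info {
  AliasSet after;
  std::vector<Info> contained;
};

// Same rule as aliasSetsMayOverlap in the diff: empty sets never overlap,
// the wildcard "*" overlaps any non-empty set, otherwise look for a shared name.
inline bool sets_may_overlap(const AliasSet& lhs, const AliasSet& rhs) {
  if (lhs.empty() || rhs.empty()) {
    return false;
  }
  if (lhs.count("*") > 0 || rhs.count("*") > 0) {
    return true;
  }
  for (const auto& name : lhs) {
    if (rhs.count(name) > 0) {
      return true;
    }
  }
  return false;
}

// Mirrors the recursion in the diff: check the top level, then the contained
// types of lhs, and only when bidirectional is set the contained types of rhs.
inline bool may_contain(const Info& lhs, const Info& rhs, bool bidirectional) {
  if (sets_may_overlap(lhs.after, rhs.after)) {
    return true;
  }
  for (const auto& child : lhs.contained) {
    if (may_contain(child, rhs, /*bidirectional=*/true)) {
      return true;
    }
  }
  if (!bidirectional) {
    return false;
  }
  for (const auto& child : rhs.contained) {
    if (may_contain(lhs, child, /*bidirectional=*/true)) {
      return true;
    }
  }
  return false;
}
```

With a container annotated `b` whose elements are annotated `a`, and a separate value annotated `a`, `may_contain(container, value, false)` is true, while `may_contain(value, container, false)` is false because the one-directional query only descends into the left-hand side.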
When adapting code from torch, remember to add the attribution notice. See ivalue.h for reference.

Update the notice following #78590.
Also, where is paddlecodec currently blocked?

On the C++ side there are still a few things that torch::tensor does not support; the Python side is in progress in PFCCLab/paddlecodec#3.

There is no need to restore everything; that is not this PR's concern. Only handle what falls within this PR's scope.
If PFCCLab/paddlecodec#3 can be merged once this PR lands, then I think this PR already meets the bar for merging.
Does PFCCLab/paddlecodec#3 depend on #78521?

Yes, it does: PFCCLab/paddlecodec@c348bb8

@gouzil Debugging at the moment; notification handling may be a bit unstable, and some messages may be processed more than once.