Skip to content
This repository was archived by the owner on Apr 1, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from test.util import TVMTest
import torch
import torch_tvm
import torch.nn.functional as F


class TestCore(TVMTest):
Expand Down Expand Up @@ -135,5 +136,32 @@ def dropout_inference(a, b, c):
str(tvm_graph_inference.graph_for(input_a, input_b, input_c)), \
"dropout must be removed during inference."

@TVMTest.given(
    shape=TVMTest.rand_shape(rank=2, min_dim=4),
    out_features=TVMTest.rand_int(3, 6),
)
def test_fuse_single_node(self, shape, out_features):
    """Check that a lone F.linear op is lowered to TVM correctly.

    Two graphs are exercised:
      1. A single-node graph (just F.linear) — verifies single-op lowering.
      2. F.linear followed by an add — verifies that multi-node fusion
         still works after the single-node path was added.

    shape is a random rank-2 input shape; out_features sizes the weight.
    """
    # Random operands matching F.linear's contract:
    #   input: (N, in_features), weight: (out, in), bias: (out,)
    inp = torch.rand(shape)
    weight = torch.rand(out_features, shape[1])
    bias = torch.rand(out_features)

    # Single-node graph: nothing to fuse with, so this exercises
    # lowering of an isolated op.
    def linear(a, b, c):
        return F.linear(a, b, c)

    ref_out, tvm_out = self.runBoth(linear, inp, weight, bias)
    assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)

    # Multi-node graph: the trailing add gives the fuser something to
    # merge, guarding against regressions in the fusion path.
    def linearSum(a, b, c):
        return F.linear(a, b, c) + 2.0

    ref_out, tvm_out = self.runBoth(linearSum, inp, weight, bias)
    assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)




# Allow running this test module directly (`python test_core.py`);
# the body must be indented under the guard to be valid Python.
if __name__ == "__main__":
    unittest.main()
26 changes: 21 additions & 5 deletions torch_tvm/fusion_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ bool canHandle(Block* block, AliasDb& aliasDb) {
GRAPH_DEBUG("Failed cond " #cond "\n"); \
return c10::nullopt; \
}

// Attempt to lower a single node into its own TVM fusion subgraph.
// On success, returns the newly created subgraph node (kind == getTVMSymbol())
// that replaced `node` in the graph. On failure, returns c10::nullopt —
// each REQ macro below bails out with nullopt and logs the failed
// condition via GRAPH_DEBUG.
c10::optional<Node*> tryLower(Node* node, AliasDb& aliasDb) {
  GRAPH_DEBUG("Trying to lower node ", node->kind().toQualString(), ":\n");
  // Already converted so return no change
  REQ(node->kind() != getTVMSymbol() && !node->hasAttribute(attr::Subgraph));
  // Only ops the TVM backend knows how to compile are eligible.
  REQ(canHandle(node, aliasDb));

  // For a non-mutating node, refuse to lower if something else writes to
  // its outputs — moving it into a subgraph could reorder those writes.
  // NOTE(review): mirrors the alias checks in tryMerge; confirm the
  // mutable case is intentionally exempt from this check.
  if (!aliasDb.isMutable(node)) {
    REQ(!aliasDb.hasOutputWriters(node));
  }
  // proceed to convert current node to TVM
  node = SubgraphUtils::createSingletonSubgraph(node, getTVMSymbol());
  return node;
}

c10::optional<Node*> tryMerge(
Node* consumer,
Node* producer,
Expand All @@ -61,7 +76,7 @@ c10::optional<Node*> tryMerge(

// Symbolic checks
REQ(canHandle(producer, aliasDb));
REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol()));
REQ(consumer->kind() == getTVMSymbol());

// Alias checks
// Requirement:
Expand All @@ -83,10 +98,6 @@ c10::optional<Node*> tryMerge(
}
}

if (!consumer->hasAttribute(attr::Subgraph) &&
consumer->kind() != getTVMSymbol()) {
consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol());
}
if (producer->kind() == prim::Constant) {
auto& subgraph = consumer->g(attr::Subgraph);
Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* {
Expand All @@ -107,6 +118,11 @@ std::pair<graph_node_list::iterator, bool> scanNode(
Block* block) {
auto inputs = sortReverseTopological(consumer->inputs(), block);
for (auto input : inputs) {
if(auto group = tryLower(consumer, aliasDb)) {
// we successfully lowered,
// rescan the new group for merging opportunities
return {group.value()->reverseIterator(), true};
}
if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
// we successfully merged, so the new group's `inputs` may have
// changed. So rescan the new group for more merging opportunities.
Expand Down