Skip to content

Commit 541e8fa

Browse files
Add from_v1.
1 parent 80f4efb commit 541e8fa

13 files changed

Lines changed: 165 additions & 43 deletions

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
#include "utils/bidict/algorithms/bidict_from_enumerating.h"
77
#include "utils/containers/map_values.h"
88
#include "utils/containers/transform.h"
9+
#include "utils/graph/digraph/algorithms/get_topological_ordering.h"
10+
#include "utils/graph/digraph/digraph.h"
11+
#include "utils/graph/digraph/directed_edge.dtg.h"
12+
#include "utils/graph/instances/adjacency_digraph.h"
13+
#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h"
914
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h"
15+
#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h"
16+
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h"
1017
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"
1118
#include "utils/graph/node/algorithms.h"
1219

@@ -50,6 +57,56 @@ V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> to_v1(
5057
return to_v1_including_node_numbering(g).first;
5158
}
5259

60+
template <typename NodeLabel, typename OutputLabel, typename SlotName>
61+
LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> from_v1(
62+
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const &v1) {
63+
// Build incoming-edge map
64+
std::unordered_map<nonnegative_int, std::vector<V1GraphEdge<SlotName>>>
65+
incoming;
66+
for (nonnegative_int const &n : v1.graph.nodes) {
67+
incoming[n] = {};
68+
}
69+
for (V1GraphEdge<SlotName> const &e : v1.graph.edges) {
70+
incoming[e.dstNode].push_back(e);
71+
}
72+
73+
// Build a DiGraph with V1 indices as Node raw_uids to get topological order
74+
DiGraph dg = DiGraph::create<AdjacencyDiGraph>();
75+
for (nonnegative_int const &n : v1.graph.nodes) {
76+
dg.add_node_unsafe(Node{static_cast<size_t>(n.unwrap_nonnegative())});
77+
}
78+
for (V1GraphEdge<SlotName> const &e : v1.graph.edges) {
79+
dg.add_edge(DirectedEdge{
80+
Node{static_cast<size_t>(e.srcNode.unwrap_nonnegative())},
81+
Node{static_cast<size_t>(e.dstNode.unwrap_nonnegative())}});
82+
}
83+
84+
auto g = LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>::
85+
template create<UnorderedSetLabelledOpenKwargDataflowGraph<NodeLabel,
86+
OutputLabel,
87+
int,
88+
SlotName>>();
89+
90+
std::unordered_map<nonnegative_int, Node> node_map;
91+
for (Node const &topo_node : get_topological_ordering(dg)) {
92+
nonnegative_int v1_idx{static_cast<size_t>(topo_node.raw_uid)};
93+
94+
std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> inputs;
95+
for (V1GraphEdge<SlotName> const &e : incoming.at(v1_idx)) {
96+
inputs.emplace(
97+
e.dstSlot,
98+
KwargDataflowOutput<SlotName>{node_map.at(e.srcNode), e.srcSlot});
99+
}
100+
101+
KwargNodeAddedResult<SlotName> result = g.add_node(
102+
v1.node_labels.at(v1_idx), inputs, v1.output_labels.at(v1_idx));
103+
104+
node_map.emplace(v1_idx, result.node);
105+
}
106+
107+
return g;
108+
}
109+
53110
} // namespace FlexFlow
54111

55112
#endif

lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
namespace FlexFlow {
99

1010
V1ComputationGraph to_v1(ComputationGraph const &);
11+
ComputationGraph from_v1(V1ComputationGraph const &);
1112

1213
std::pair<V1ComputationGraph, bidict<nonnegative_int, layer_guid_t>>
1314
to_v1_including_node_numbering(ComputationGraph const &);

lib/pcg/include/pcg/file_format/v1/v1_mapped_operator_task_group.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
namespace FlexFlow {
88

99
V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &);
10+
MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &);
1011

1112
} // namespace FlexFlow
1213

lib/pcg/include/pcg/file_format/v1/v1_mapped_parallel_computation_graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
namespace FlexFlow {
88

99
V1MappedParallelComputationGraph to_v1(MappedParallelComputationGraph const &);
10+
MappedParallelComputationGraph
11+
from_v1(V1MappedParallelComputationGraph const &);
1012

1113
} // namespace FlexFlow
1214

lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
namespace FlexFlow {
88

99
V1ParallelComputationGraph to_v1(ParallelComputationGraph const &);
10+
ParallelComputationGraph from_v1(V1ParallelComputationGraph const &);
1011

1112
} // namespace FlexFlow
1213

lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,7 @@ template std::pair<
1818
template V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> to_v1(
1919
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> const &);
2020

21+
template LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> from_v1(
22+
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const &);
23+
2124
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ V1ComputationGraph to_v1(ComputationGraph const &g) {
1010
};
1111
}
1212

13+
ComputationGraph from_v1(V1ComputationGraph const &v1) {
14+
return ComputationGraph{
15+
from_v1(v1.raw_graph),
16+
};
17+
}
18+
1319
std::pair<V1ComputationGraph, bidict<nonnegative_int, layer_guid_t>>
1420
to_v1_including_node_numbering(ComputationGraph const &cg) {
1521
std::pair<

lib/pcg/src/pcg/file_format/v1/v1_mapped_operator_task_group.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ V1MappedOperatorTaskGroup to_v1(MappedOperatorTaskGroup const &g) {
66
return V1MappedOperatorTaskGroup{g.get_shard_bindings()};
77
}
88

9+
MappedOperatorTaskGroup from_v1(V1MappedOperatorTaskGroup const &v1) {
10+
return MappedOperatorTaskGroup{v1.shard_bindings};
11+
}
12+
913
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/v1_mapped_parallel_computation_graph.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,13 @@ V1MappedParallelComputationGraph
1414
};
1515
}
1616

17+
MappedParallelComputationGraph
18+
from_v1(V1MappedParallelComputationGraph const &v1) {
19+
return MappedParallelComputationGraph{
20+
from_v1(v1.pcg),
21+
map_values(v1.mapped_tasks,
22+
[](V1MappedOperatorTaskGroup const &g) { return from_v1(g); }),
23+
};
24+
}
25+
1726
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) {
1010
};
1111
}
1212

13+
ParallelComputationGraph from_v1(V1ParallelComputationGraph const &v1) {
14+
return ParallelComputationGraph{
15+
from_v1(v1.raw_graph),
16+
};
17+
}
18+
1319
} // namespace FlexFlow

0 commit comments

Comments
 (0)