Skip to content

Commit 51bbf63

Browse files
Chillee authored and pytorchmergebot committed
Improved legalize_graph pass in FX (#82874)
Pull Request resolved: #82874 Approved by: https://github.com/jamesr66a
1 parent 4f255db commit 51bbf63

File tree

2 files changed

+31
-40
lines changed

2 files changed

+31
-40
lines changed

functorch/functorch/_src/partitioners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,12 @@ def classify_nodes(joint_module):
277277

278278
pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward] # noqa: E501
279279
if compiler == "inductor":
280-
pointwise_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone] # noqa: E501
280+
pointwise_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy] # noqa: E501
281281
misc_ops = [aten.to, aten.type_as, operator.getitem]
282282

283283
reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax] # noqa: E501
284284
if compiler == "inductor":
285-
reduction_ops += [prims.var, prims.sum, aten.var]
285+
reduction_ops += [prims.var, prims.sum, aten.var, aten.std]
286286

287287
# not recomputed by default since these are kinda expensive/hard to fuse into
288288
# norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward] # noqa: E501

torch/fx/passes/tools_common.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Tuple, Union, Dict, Any, Set, Mapping
2+
import collections
23
from dataclasses import dataclass
34

45
import torch
@@ -209,7 +210,7 @@ def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
209210

210211

211212
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically sorted
    order of its input GraphModule, so that this order is restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module in-place sorted
    """
    # Kahn's algorithm: count, for every node, how many of its inputs have
    # not yet been emitted into the sorted graph.
    pending_deps = {node: 0 for node in gm.graph.nodes}
    sorted_graph = torch.fx.Graph()
    for node in gm.graph.nodes:
        for consumer in node.users:
            pending_deps[consumer] += 1

    # Seed the worklist with every node that has no unfulfilled dependencies.
    ready: collections.deque = collections.deque(
        node for node in gm.graph.nodes if pending_deps[node] == 0
    )

    # Maps original nodes to their copies in the new graph, so copied nodes
    # can rewire their inputs to the already-copied producers.
    remapped: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Emit ready nodes one at a time; each emission may unblock its consumers.
    while ready:
        original = ready.popleft()
        remapped[original] = sorted_graph.node_copy(original, lambda n: remapped[n])
        for consumer in original.users:
            pending_deps[consumer] -= 1
            if pending_deps[consumer] == 0:
                ready.append(consumer)

    # If the new graph's size is not as large as the old one, then there must be
    # a cycle (i.e. some node's dependencies were not satisfied.)
    if len(sorted_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in pending_deps if pending_deps[node] != 0]}")
    gm.graph = sorted_graph
    return gm

0 commit comments

Comments
 (0)