import torch
print(torch.__version__)
import torch._dynamo
import torch.fx
import torch._dynamo.backends.inductor
import torch._inductor.config
torch._inductor.config.trace.enabled = True
from verify_custom_cpp_ops import my_sigmoid
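
# NOTE: `my_sigmoid` is assumed to be the custom sigmoid kernel registered by
# `verify_custom_cpp_ops` and is expected to match `torch.sigmoid` numerically.
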
def print_backend(gm: torch.fx.GraphModule, input):
    # TorchDynamo backend that prints the captured FX graph and returns it
    # unchanged, so the graph runs without any further compilation.
    gm.graph.print_tabular()
    return gm

def pattern(x):
    return torch.sigmoid(x)


# def replacement(x):
#     return my_sigmoid(x)
replacement = torch.fx.symbolic_trace(my_sigmoid)
replacement.graph.print_tabular()
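
# The `pattern` / `replacement` pair above matches the shape expected by FX's
# subgraph rewriter. A sketch of that alternative (not exercised in this
# script, and assuming `my_sigmoid` remains symbolically traceable) would be:
#
#     from torch.fx import subgraph_rewriter
#     traced = torch.fx.symbolic_trace(func)   # `func` is defined below
#     subgraph_rewriter.replace_pattern(traced, pattern, replacement)
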
def replace_sigmoid(gm: torch.fx.GraphModule):
    graph = gm.graph
    for node in graph.nodes:
        if node.op == 'call_function' and node.target == torch.sigmoid:
            with graph.inserting_after(node):
                # Insert a new `call_function` node that calls `my_sigmoid`.
                new_node = graph.call_function(
                    my_sigmoid, args=tuple(node.all_input_nodes))
                # Redirect every use of the original `torch.sigmoid` node to
                # the value produced by the new `my_sigmoid` node.
                node.replace_all_uses_with(new_node)
            graph.erase_node(node)
    graph.lint()
    gm.recompile()
    return gm

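# Two TorchDynamo backends built on `replace_sigmoid`: `replace_pattern_backend`
# returns the rewritten GraphModule as-is (so it runs without further
# compilation), while `replace_pattern_backend_with_inductor` hands the
# rewritten graph to Inductor via `compile_fx`.
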
def replace_pattern_backend(gm: torch.fx.GraphModule, input):
    gm.graph.print_tabular()
    gm = replace_sigmoid(gm)
    gm.graph.print_tabular()
    return gm

def replace_pattern_backend_with_inductor(gm: torch.fx.GraphModule, input):
    gm = replace_sigmoid(gm)
    from torch._inductor.compile_fx import compile_fx
    gm.graph.print_tabular()
    optimized_forward = compile_fx(gm, example_inputs_=input)
    return optimized_forward
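
# Two matching toy models: `func` uses the stock torch.sigmoid, while
# `func_with_custom_op` calls `my_sigmoid` in the same place, so their outputs
# and input gradients should agree if the custom op reproduces torch.sigmoid.
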
def func(x):
    x = torch.nn.functional.relu(x)
    x = torch.add(x, x)
    x = torch.sigmoid(x)
    x = torch.nn.functional.gelu(x)
    x = x * x
    return x

def func_with_custom_op(x):
    x = torch.nn.functional.relu(x)
    x = torch.add(x, x)
    x = my_sigmoid(x)
    x = torch.nn.functional.gelu(x)
    x = x * x
    return x

def copy_tensor(tensor):
    # Create an independent leaf tensor with the same values and the same
    # requires_grad flag, so every experiment accumulates its own .grad.
    another_tensor = torch.rand_like(tensor, requires_grad=tensor.requires_grad)
    another_tensor.data.copy_(tensor.data)
    return another_tensor

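# The rest of the script pushes copies of the same random input through each
# configuration and prints the output and the input gradient so they can be
# compared against the eager ground truth.
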
print('-' * 10)
print('Ground truth')
input_tensor = torch.randn(10, device='cuda', requires_grad=True)
output = func(input_tensor)
output.sum().backward()
print(output)
print(input_tensor.grad)
# Show how Dynamo handles the custom op: it turns out Dynamo simply breaks the
# graph when it encounters the custom op.
print('-' * 10)
print('Dynamo on graph with custom op without compile')
func_with_custom_triton_op_print = torch.compile(func_with_custom_op, backend=print_backend)
input_tensor_dynamo = copy_tensor(input_tensor)
output = func_with_custom_triton_op_print(input_tensor_dynamo)
output.sum().backward()
print(output)
print(input_tensor_dynamo.grad)
print('-' * 10)
print('Inductor on graph without custom op')
func_without_custom_op_inductor = torch.compile(func)
input_tensor_inductor = copy_tensor(input_tensor)
output = func_without_custom_op_inductor(input_tensor_inductor)
output.sum().backward()
print(output)
print(input_tensor_inductor.grad)
print('-' * 10)
print('Compile graph with custom op')
# Compile the function that calls the custom op directly, using the default
# (Inductor) backend.
torch._dynamo.reset()
func_with_custom_triton_op_inductor = torch.compile(func_with_custom_op)
input_tensor1 = copy_tensor(input_tensor)
output = func_with_custom_triton_op_inductor(input_tensor1)
output.sum().backward()
print(output)
print(input_tensor1.grad)
print('-' * 10)
print('Print graph with replaced op')
# Graph-replacement backend: check in the printed graphs that torch.sigmoid
# has been swapped for my_sigmoid.
torch._dynamo.reset()
func_replace_pattern = torch.compile(func, backend=replace_pattern_backend)
input_tensor2 = copy_tensor(input_tensor)
output = func_replace_pattern(input_tensor2)
print(output)
output.sum().backward()
print(input_tensor2.grad)
# Graph replacement followed by Inductor compilation of the rewritten graph.
print('-' * 10)
print('Replacement with inductor')
torch._dynamo.reset()
input_tensor3 = copy_tensor(input_tensor)
func_replace_pattern_inductor = torch.compile(func, backend=replace_pattern_backend_with_inductor)
print(input_tensor3.requires_grad)
output = func_replace_pattern_inductor(input_tensor3)
print(output)
output.sum().backward()
print(input_tensor3.grad)