ppe.compile
Enable forward only and custom decompositions
#740
Conversation
Force-pushed 85aeb01 to 0a8f2e6
else f"input_{i - len(self._parameter_names)}"
)

# Remove the outputs
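The fragment above names the graph placeholders: those covering module parameters keep their parameter names, and the remaining ones fall back to `input_0`, `input_1`, and so on. A minimal sketch of that naming scheme (the helper name is hypothetical, not from the PR):

```python
def name_placeholders(parameter_names, n_placeholders):
    # Hypothetical sketch: the first placeholders correspond to the
    # module parameters and keep their names; any extra placeholders
    # are the actual graph inputs and get "input_<k>" names, where
    # k restarts at 0 after the parameters.
    names = []
    for i in range(n_placeholders):
        if i < len(parameter_names):
            names.append(parameter_names[i])
        else:
            names.append(f"input_{i - len(parameter_names)}")
    return names
```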
I'd like some detailed comments about the 'output_node' being deleted here. I believe it's an intermediate variable for the gradient calculation, but is deleting just one sufficient?
added comments explaining it!
node for node in fwd_graph.graph.nodes if node.op == "output"
][0]
outputs = pytree.tree_flatten(output_node.args)[0]
fwd_graph.graph.erase_node(output_node)
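The fragment above locates the single `output` node of the traced forward graph, flattens the values it returns, and erases the node so the graph can be rewritten. A toy sketch of that step, using a stand-in `Node` class instead of `torch.fx` (all names here are illustrative, not the PR's actual code):

```python
class Node:
    # Toy stand-in for a torch.fx node: just an op kind and its args.
    def __init__(self, op, args=()):
        self.op = op
        self.args = args

def pop_output_node(nodes):
    # An fx graph has exactly one "output" node; find it, collect the
    # values it returns, then drop it from the graph (this `remove`
    # stands in for fwd_graph.graph.erase_node(output_node)).
    output_node = [n for n in nodes if n.op == "output"][0]
    outputs = list(output_node.args)
    nodes.remove(output_node)
    return outputs
```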
It would be great if we could verify via tests whether or not the intended node has been deleted. However, if it proves difficult, I don't think there's a need to force its addition.
Yeah! That's a great idea, thanks!!
added a check in the test!
Force-pushed dd4ae46 to b60e7b5
Force-pushed cba2d6a to b710645
/test
for n, b in module.named_buffers():
    parameters_and_buffers.append(b)
    names.append(n)
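The loop above collects the module's buffers (and, presumably, its parameters in a sibling loop not shown) into parallel `names` / `parameters_and_buffers` lists. A toy sketch of that collection, with plain iterables standing in for `module.named_parameters()` and `module.named_buffers()`:

```python
def collect(named_parameters, named_buffers):
    # Parameters first, then buffers, keeping names and values in
    # parallel lists so indices line up with the graph placeholders.
    parameters_and_buffers, names = [], []
    for n, p in named_parameters:
        parameters_and_buffers.append(p)
        names.append(n)
    for n, b in named_buffers:
        parameters_and_buffers.append(b)
        names.append(n)
    return names, parameters_and_buffers
```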
for i_node in primal_inputs:
    bwd_outs.append(
        bwd_graph.call_function(
            torch.ones, (i_node.meta.get("tensor_meta").shape,)
nit:
It seems more natural to output zeros.
done!
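After this change the placeholder gradients are emitted as zeros matching each primal input's shape (rather than ones). A toy sketch of that idea, using nested lists instead of tensors (a stand-in for the graph emitting `torch.zeros(shape)` per primal input):

```python
def zeros_like_shape(shape):
    # Build an all-zeros nested structure of the given shape; in the
    # real graph this would be a call_function node for torch.zeros.
    if not shape:
        return 0.0
    return [zeros_like_shape(shape[1:]) for _ in range(shape[0])]
```

Zeros are the natural choice for a forward-only graph: they signal "no gradient" instead of an arbitrary nonzero value.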
LGTM
/test
In some cases we want to generate a forward-only graph using the AOT lightweight tracing.
It may also be necessary to avoid the default decompositions, to prevent some operators from being decomposed and thereby losing important information.
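To illustrate why skipping the default decompositions matters: once a fused operator is decomposed, a backend can no longer recognize it and match it to a fused kernel. A toy sketch of that effect (the table contents and function names are illustrative, not the actual PyTorch decomposition tables):

```python
# Toy decomposition table: maps a fused op to its decomposed ops.
# Example: "addmm" computes bias + (mat1 @ mat2) in one fused call;
# decomposed, it becomes a separate matmul and add.
DEFAULT_DECOMPOSITIONS = {"addmm": ["mm", "add"]}

def trace(ops, decompositions):
    # With a non-empty table, fused ops are replaced by their parts;
    # with an empty table (this PR's custom-decompositions path),
    # the fused op survives in the traced graph.
    out = []
    for op in ops:
        out.extend(decompositions.get(op, [op]))
    return out
```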