
ppe.compile Enable forward only and custom decompositions #740

Merged: 9 commits, Dec 6, 2023

Conversation

emcastillo (Contributor):

In some cases we want to generate a forward-only graph using the AOT lightweight tracing.
It may also be necessary to override the default decompositions, to prevent certain operators from being decomposed and losing important information.
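To illustrate the decomposition point, here is a toy sketch only — the `trace_op` helper and the `addmm` rule below are made up for illustration, not the ppe.compile or PyTorch API. A decomposition table maps a high-level op to primitive ops; tracing with a custom (here, empty) table keeps the original op as a single node, preserving information the decomposed form would lose.

```python
# Hypothetical decomposition table: expands addmm into mm + add,
# which drops the fact that the three inputs formed one fused op.
DEFAULT_DECOMPOSITIONS = {
    "addmm": lambda bias, a, b: [("mm", (a, b)), ("add", ("mm_out", bias))],
}

def trace_op(op, args, decompositions):
    """Return the list of primitive nodes recorded for one traced op."""
    if op in decompositions:
        return decompositions[op](*args)
    return [(op, args)]  # keep the op intact as a single node

# With the default table, addmm disappears from the graph...
assert trace_op("addmm", ("bias", "a", "b"), DEFAULT_DECOMPOSITIONS) == [
    ("mm", ("a", "b")),
    ("add", ("mm_out", "bias")),
]
# ...with custom (empty) decompositions it is preserved.
assert trace_op("addmm", ("bias", "a", "b"), {}) == [("addmm", ("bias", "a", "b"))]
```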

else f"input_{i - len(self._parameter_names)}"
)

# Remove the outputs
Member:

I want some detailed comments about the 'output_node' that is being deleted here. I believe it's an intermediate variable for gradient calculation, but is it sufficient to just delete one?

emcastillo (Author):

added comments explaining it!

output_node = [
    node for node in fwd_graph.graph.nodes if node.op == "output"
][0]
outputs = pytree.tree_flatten(output_node.args)[0]
fwd_graph.graph.erase_node(output_node)
Member:

It would be great if we could verify via tests whether or not the intended node has been deleted. However, if it proves difficult, I don't think there's a need to force its addition.

emcastillo (Author):

yeah! that's a great idea! thanks!!

emcastillo (Author):

added a check in the test!
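The removal and the suggested check can be sketched with a toy FX-style graph. The `Node` and `Graph` classes below are hypothetical stand-ins for `torch.fx` objects, and a plain `list()` stands in for `pytree.tree_flatten`; this is a sketch of the idea, not the actual implementation.

```python
class Node:
    def __init__(self, op, args=()):
        self.op, self.args = op, args

class Graph:
    def __init__(self, nodes):
        self.nodes = list(nodes)

    def erase_node(self, node):
        self.nodes.remove(node)

# A traced forward graph: the single "output" node lists the values the
# graph returns, including intermediates saved for the backward pass.
g = Graph([
    Node("call_function"),
    Node("call_function"),
    Node("output", args=("y", "saved_activation")),
])

output_node = [n for n in g.nodes if n.op == "output"][0]
outputs = list(output_node.args)  # stands in for pytree.tree_flatten
g.erase_node(output_node)

# The check suggested in review: the output node is gone, but its
# flattened outputs were captured first so a new output node can be built.
assert outputs == ["y", "saved_activation"]
assert all(n.op != "output" for n in g.nodes)
```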

@emcastillo (Author):

/test

@linshokaku linshokaku self-requested a review December 6, 2023 05:42
Comment on lines +200 to +202
for n, b in module.named_buffers():
parameters_and_buffers.append(b)
names.append(n)
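A minimal sketch of what this loop does, using a hypothetical `Module` class standing in for `torch.nn.Module`: buffers (e.g. BatchNorm running statistics) are graph inputs just like parameters, so they are appended to the same flat list alongside their names.

```python
class Module:
    # Toy stand-in mimicking torch.nn.Module's named_parameters/named_buffers.
    def __init__(self, params, buffers):
        self._params, self._buffers = params, buffers

    def named_parameters(self):
        return list(self._params.items())

    def named_buffers(self):
        return list(self._buffers.items())

module = Module({"weight": [1.0]}, {"running_mean": [0.0]})

parameters_and_buffers, names = [], []
for n, p in module.named_parameters():
    parameters_and_buffers.append(p)
    names.append(n)
for n, b in module.named_buffers():  # the loop from the diff above
    parameters_and_buffers.append(b)
    names.append(n)

assert names == ["weight", "running_mean"]
assert parameters_and_buffers == [[1.0], [0.0]]
```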

for i_node in primal_inputs:
    bwd_outs.append(
        bwd_graph.call_function(
            torch.ones, (i_node.meta.get("tensor_meta").shape,)
        )
    )
Member:

nit:
It seems more natural to output zeros.

emcastillo (Author):

done!
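A hedged sketch of the nit, with a toy `placeholder_grads` helper standing in for `torch.zeros`/`torch.ones`: in forward-only mode the backward graph's outputs for the primal inputs are placeholders that no real backward pass ever produces, so filling them with zeros (the conventional "no gradient" value) is less misleading than ones if they are ever inspected.

```python
def placeholder_grads(shapes, fill=0.0):
    # One flat buffer per input shape; stands in for torch.zeros(shape).
    return [[fill] * n for n in shapes]

# One dummy "gradient" per primal input, seeded with zeros.
bwd_outs = placeholder_grads([2, 3])
assert bwd_outs == [[0.0, 0.0], [0.0, 0.0, 0.0]]
```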

@linshokaku (Member) left a comment:

LGTM

@linshokaku (Member):

/test

@linshokaku linshokaku enabled auto-merge December 6, 2023 07:57
@linshokaku linshokaku merged commit 73082f6 into pfnet:master Dec 6, 2023
6 checks passed
@linshokaku linshokaku added this to the v0.7.5 milestone Dec 8, 2023
Labels: cat:enhancement (New feature or request)
Projects: None yet
2 participants