Skip to content

Commit

Permalink
Merge pull request #265 from arogozhnikov/autoregister_ops_in_torchdy…
Browse files Browse the repository at this point in the history
…namo

automatically register torch ops in torchdynamo
  • Loading branch information
arogozhnikov authored Jul 8, 2023
2 parents 1f0cf20 + 474af4e commit ef4028c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ def __init__(self):
import torch

self.torch = torch
# importing would register operations in torch._dynamo for torch.compile
from . import _torch_specific # noqa

def is_appropriate_type(self, tensor):
return isinstance(tensor, self.torch.Tensor)
Expand Down
15 changes: 14 additions & 1 deletion einops/_torch_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
a number of additional moves is needed.
Design of main operations (dynamic resolution by lookup) is unlikely
to be implemented by torch.jit.script, but torch.compile seems to work completely fine.
to be implemented by torch.jit.script,
but torch.compile seems to work with operations just fine.
"""
import warnings
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -98,10 +99,14 @@ def apply_for_scriptable_torch(


def allow_ops_in_compiled_graph():
if hasattr(torch, "__version__") and torch.__version__[0] < "2":
# torch._dynamo and torch.compile appear in pytorch 2.0
return
try:
from torch._dynamo import allow_in_graph
except ImportError:
warnings.warn("allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0", ImportWarning)
return

from .einops import rearrange, reduce, repeat, einsum
from .packing import pack, unpack
Expand All @@ -112,3 +117,11 @@ def allow_ops_in_compiled_graph():
allow_in_graph(einsum)
allow_in_graph(pack)
allow_in_graph(unpack)

# CF: https://github.com/pytorch/pytorch/blob/2df939aacac68e9621fbd5d876c78d86e72b41e2/torch/_dynamo/__init__.py#L222
global _ops_were_registered_in_torchdynamo
_ops_were_registered_in_torchdynamo = True


# module import automatically registers ops in torchdynamo
allow_ops_in_compiled_graph()

0 comments on commit ef4028c

Please sign in to comment.