Skip to content

Commit

Permalink
Minor fixes to PT2 export path: enum typo and max_seq_len (pytorch#1343)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Khuu authored Nov 5, 2024
1 parent 9480258 commit 54455a3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
6 changes: 3 additions & 3 deletions torchchat/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def export_for_server(
dynamic_shapes=dynamic_shapes,
options=options,
)

if package:
from torch._inductor.package import package_aoti
path = package_aoti(output_path, path)

print(f"The generated packaged model can be found at: {path}")
return path

Expand Down Expand Up @@ -382,7 +382,7 @@ def main(args):

if builder_args.max_seq_length is None:
if (
output_dso_path is not None
(output_dso_path is not None or output_aoti_package_path is not None)
and not builder_args.dynamic_shapes
):
print("Setting max_seq_length to 300 for DSO export.")
Expand Down
26 changes: 17 additions & 9 deletions torchchat/utils/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -77,31 +77,39 @@ def unpack_packed_weights(
def set_backend(dso, pte, aoti_package):
global active_builder_args_dso
global active_builder_args_pte
global active_builder_args_aoti_package
active_builder_args_dso = dso
active_builder_args_aoti_package = aoti_package
active_builder_args_pte = pte


class _Backend(Enum):
AOTI = (0,)
AOTI = 0
EXECUTORCH = 1


def _active_backend() -> _Backend:
def _active_backend() -> Optional[_Backend]:
global active_builder_args_dso
global active_builder_args_aoti_package
global active_builder_args_pte

# eager == aoti, which is when backend has not been explicitly set
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
return True
args = (
active_builder_args_dso,
active_builder_args_pte,
active_builder_args_aoti_package,
)

# Return None, as default
if not any(args):
return None

if active_builder_args_pte and active_builder_args_aoti_package:
# Catch more than one arg
if sum(map(bool, args)) > 1:
raise RuntimeError(
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
"Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!"
)

return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH
return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI


def use_aoti_backend() -> bool:
Expand Down

0 comments on commit 54455a3

Please sign in to comment.