From bb44955fd0c63a07d87bf68fc999c780dfd8b446 Mon Sep 17 00:00:00 2001
From: Seunghoon Lee
Date: Mon, 24 Jun 2024 14:22:22 +0900
Subject: [PATCH] zluda vqa florence

---
 modules/vqa.py           |  3 ++-
 modules/zluda.py         | 18 +++---------------
 modules/zluda_hijacks.py | 28 ++++++++++++++++++++++++++++
 3 files changed, 33 insertions(+), 16 deletions(-)
 create mode 100644 modules/zluda_hijacks.py

diff --git a/modules/vqa.py b/modules/vqa.py
index 5a17c091e..a704f98ec 100644
--- a/modules/vqa.py
+++ b/modules/vqa.py
@@ -2,6 +2,7 @@
 import transformers
 from PIL import Image
 from modules import shared, devices
+from modules.zluda import is_zluda
 
 
 processor = None
@@ -129,7 +130,7 @@ def moondream(question: str, image: Image.Image, repo: str = None):
 def florence(question: str, image: Image.Image, repo: str = None):
     global processor, model, loaded # pylint: disable=global-statement
     from installer import install, installed
-    if not installed('flash_attn', quiet=True):
+    if not installed('flash_attn', quiet=True) and not is_zluda(devices.device):
         install('flash_attn')
     if model is None or loaded != repo:
         model = transformers.AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
diff --git a/modules/zluda.py b/modules/zluda.py
index e244c245d..ceae8648f 100644
--- a/modules/zluda.py
+++ b/modules/zluda.py
@@ -1,4 +1,3 @@
-import os
 import sys
 from typing import Union
 import torch
@@ -6,17 +5,13 @@
 import onnxruntime as ort
 from modules import shared, devices
 from modules.onnx_impl.execution_providers import available_execution_providers, ExecutionProvider
+from modules.zluda_hijacks import do_hijack
 
 
 PLATFORM = sys.platform
 do_nothing = lambda _: None # pylint: disable=unnecessary-lambda-assignment
 
 
-def _join_rocm_home(*paths) -> str:
-    from torch.utils.cpp_extension import ROCM_HOME
-    return os.path.join(ROCM_HOME, *paths)
-
-
 def is_zluda(device: DeviceLikeType):
     try:
         device = torch.device(device)
@@ -42,16 +37,9 @@ def initialize_zluda():
     if not devices.cuda_ok or not is_zluda(device):
         return
 
-    torch.version.hip = "5.7"
-    sys.platform = ""
-    from torch.utils import cpp_extension
-    sys.platform = PLATFORM
-    cpp_extension.IS_WINDOWS = PLATFORM == "win32"
-    cpp_extension.IS_MACOS = False
-    cpp_extension.IS_LINUX = sys.platform.startswith('linux')
-    cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
+    do_hijack()
 
-    if cpp_extension.IS_WINDOWS:
+    if PLATFORM == "win32":
         torch.backends.cudnn.enabled = False
         torch.backends.cuda.enable_flash_sdp(False)
         torch.backends.cuda.enable_flash_sdp = do_nothing
diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py
new file mode 100644
index 000000000..eef4aab14
--- /dev/null
+++ b/modules/zluda_hijacks.py
@@ -0,0 +1,28 @@
+import os
+import sys
+import torch
+
+
+_topk = torch.topk
+def topk(tensor: torch.Tensor, *args, **kwargs):
+    device = tensor.device
+    values, indices = _topk(tensor.cpu(), *args, **kwargs)
+    return torch.return_types.topk((values.to(device), indices.to(device),))
+
+
+def _join_rocm_home(*paths) -> str:
+    from torch.utils.cpp_extension import ROCM_HOME
+    return os.path.join(ROCM_HOME, *paths)
+
+
+def do_hijack():
+    torch.version.hip = "5.7"
+    torch.topk = topk
+    platform = sys.platform
+    sys.platform = ""
+    from torch.utils import cpp_extension
+    sys.platform = platform
+    cpp_extension.IS_WINDOWS = platform == "win32"
+    cpp_extension.IS_MACOS = False
+    cpp_extension.IS_LINUX = platform.startswith('linux')
+    cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
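
Reviewer note (not part of the patch): a minimal usage sketch of the new module,
assuming SD.Next's module layout and a ZLUDA-backed "cuda" device. do_hijack()
replaces torch.topk with a wrapper that runs on the CPU and moves the results
back to the tensor's original device (evidently a workaround for topk not
working under ZLUDA); it also temporarily blanks sys.platform while importing
torch.utils.cpp_extension and overrides its platform flags and _join_rocm_home
so the ROCm extension path can be used on Windows.

    import torch
    from modules.zluda_hijacks import do_hijack  # module added by this patch

    do_hijack()                              # installs the topk wrapper and cpp_extension shims
    x = torch.randn(8, device="cuda")        # "cuda" maps to the ZLUDA device here
    values, indices = torch.topk(x, k=3)     # computed on CPU, returned on x.device
    assert values.device == x.device and indices.device == x.device

The vqa.py change skips installing flash_attn when the active device is ZLUDA,
presumably because flash-attn cannot be built against ZLUDA, so Florence should
load with transformers' default attention implementation instead.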