diff --git a/CHANGELOG.md b/CHANGELOG.md index 7803165ff..3ff218b11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,14 +12,12 @@ - select from *scripts -> pixelsmith* - [Hunyuan Video](https://github.com/Tencent/HunyuanVideo) LoRA support - example: -- ZLUDA v3.8.6 - - new runtime compiler implementation: complex types, JIT are now available +- ZLUDA v3.8.7 + - new runtime compiler implementation: complex types, JIT are now available + - fast Fourier transform (FFT) is implemented - experimental BLASLt support via nightly build - set `ZLUDA_NIGHTLY=1` to install nightly ZLUDA: newer torch such as 2.4.x (default) and 2.5.x are now available - requirements: unofficial hipBLASLt - - **Important** - - the support for older version of ZLUDA will be removed in next release - - see [upgrading ZLUDA](https://github.com/vladmandic/automatic/wiki/ZLUDA#upgrading-zluda) for more information - **Logging**: - reverted enable debug by default - updated [debug wiki](https://github.com/vladmandic/automatic/wiki/debug) diff --git a/installer.py b/installer.py index 49bcf8b77..5b6e4455a 100644 --- a/installer.py +++ b/installer.py @@ -588,7 +588,7 @@ def install_rocm_zluda(): from modules import zluda_installer zluda_installer.set_default_agent(device) try: - if args.reinstall: + if args.reinstall or zluda_installer.is_old_zluda(): zluda_installer.uninstall() zluda_installer.install() except Exception as e: diff --git a/modules/zluda_hijacks.py b/modules/zluda_hijacks.py index 01ca4dec8..872127cf1 100644 --- a/modules/zluda_hijacks.py +++ b/modules/zluda_hijacks.py @@ -1,5 +1,5 @@ import torch -from modules import zluda_installer, rocm +from modules import rocm _topk = torch.topk @@ -9,32 +9,6 @@ def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-buil return torch.return_types.topk((values.to(device), indices.to(device),)) -_fft_fftn = torch.fft.fftn -def fft_fftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - 
return _fft_fftn(input.cpu(), *args, **kwargs).to(input.device) - - -_fft_ifftn = torch.fft.ifftn -def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - return _fft_ifftn(input.cpu(), *args, **kwargs).to(input.device) - - -_fft_rfftn = torch.fft.rfftn -def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin - return _fft_rfftn(input.cpu(), *args, **kwargs).to(input.device) - - -def jit_script(f, *_, **__): # experiment / provide dummy graph - f.graph = torch._C.Graph() # pylint: disable=protected-access - return f - - def do_hijack(): torch.version.hip = rocm.version torch.topk = topk - torch.fft.fftn = fft_fftn - torch.fft.ifftn = fft_ifftn - torch.fft.rfftn = fft_rfftn - - if not zluda_installer.get_blaslt_enabled(): - torch.jit.script = jit_script diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index d6f4e51a5..8888a56f4 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -12,9 +12,11 @@ DLL_MAPPING = { 'cublas.dll': 'cublas64_11.dll', 'cusparse.dll': 'cusparse64_11.dll', + 'cufft.dll': 'cufft64_10.dll', + 'cufftw.dll': 'cufftw64_10.dll', 'nvrtc.dll': 'nvrtc64_112_0.dll', } -HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll'] +HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',] ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',) path = os.path.abspath(os.environ.get('ZLUDA', '.zluda')) @@ -28,12 +30,16 @@ def set_default_agent(agent: rocm.Agent): default_agent = agent +def is_old_zluda() -> bool: # ZLUDA<3.8.7 + return not os.path.exists(os.path.join(path, "cufftw.dll")) + + def install() -> None: if os.path.exists(path): return platform = "windows" - commit = os.environ.get("ZLUDA_HASH", "d60bddbc870827566b3d2d417e00e1d2d8acc026") + commit = os.environ.get("ZLUDA_HASH", "c4994b3093e02231339d22e12be08418b2af781f") if nightly: platform = "nightly-" + platform 
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-{platform}-rocm{rocm.version[0]}-amd64.zip', '_zluda') @@ -80,11 +86,6 @@ def load() -> None: os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1" os.environ["ZLUDA_NVRTC_LIB"] = os.path.join([v for v in site.getsitepackages() if v.endswith("site-packages")][0], "torch", "lib", "nvrtc64_112_0.dll") - try: # for compatibility. will be removed in next release - ctypes.windll.LoadLibrary(f'hiprtc{"".join([v.zfill(2) for v in rocm.version.split(".")])}.dll') - except Exception: - pass - for v in HIPSDK_TARGETS: ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v)) for v in ZLUDA_TARGETS: