Skip to content

Commit

Permalink
Clean up ZLUDA redirection helper
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Feb 4, 2022
1 parent 2753d95 commit 164c172
Showing 1 changed file with 61 additions and 89 deletions.
150 changes: 61 additions & 89 deletions zluda_redirect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
extern crate detours_sys;
extern crate winapi;

use std::{
collections::HashMap,
ffi::{c_void, CStr},
mem, ptr, slice, usize,
};
use std::{ffi::c_void, mem, ptr, slice, usize};

use detours_sys::{
DetourAttach, DetourRestoreAfterWith, DetourTransactionAbort, DetourTransactionBegin,
Expand All @@ -18,6 +14,7 @@ use winapi::{
shared::minwindef::{BOOL, LPVOID},
um::{
handleapi::{CloseHandle, INVALID_HANDLE_VALUE},
libloaderapi::GetModuleFileNameW,
minwinbase::LPSECURITY_ATTRIBUTES,
processthreadsapi::{
CreateProcessA, GetCurrentProcessId, GetCurrentThreadId, OpenThread, ResumeThread,
Expand All @@ -32,15 +29,12 @@ use winapi::{
};
use winapi::{
shared::minwindef::{DWORD, FALSE, HMODULE, TRUE},
um::{
libloaderapi::{GetModuleHandleA, LoadLibraryExA},
winnt::LPCSTR,
},
um::{libloaderapi::LoadLibraryExA, winnt::LPCSTR},
};
use winapi::{
shared::minwindef::{FARPROC, HINSTANCE},
um::{
libloaderapi::{GetModuleFileNameA, GetProcAddress},
libloaderapi::GetProcAddress,
processthreadsapi::{CreateProcessAsUserW, CreateProcessW},
winbase::{CreateProcessWithLogonW, CreateProcessWithTokenW},
winnt::{DLL_PROCESS_ATTACH, DLL_PROCESS_DETACH, HANDLE, LPCWSTR},
Expand Down Expand Up @@ -158,15 +152,6 @@ unsafe extern "system" fn ZludaGetProcAddress_NoRedirect(
hModule: HMODULE,
lpProcName: LPCSTR,
) -> FARPROC {
if let Some(detour_guard) = &DETOUR_STATE {
if hModule != ptr::null_mut() && detour_guard.nvcuda_module == hModule {
let proc_name = CStr::from_ptr(lpProcName);
return match detour_guard.overriden_cuda_fns.get(proc_name) {
Some((original_fn, _)) => mem::transmute::<*mut c_void, _>(*original_fn),
None => ptr::null_mut(),
};
}
}
GetProcAddress(hModule, lpProcName)
}

Expand Down Expand Up @@ -384,8 +369,6 @@ struct DetourDetachGuard {
suspended_threads: Vec<*mut c_void>,
// First element is the original fn, second is the new fn
overriden_non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
nvcuda_module: HMODULE,
overriden_cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
}

impl DetourDetachGuard {
Expand All @@ -394,17 +377,11 @@ impl DetourDetachGuard {
// first element in the pair, because somehow otherwise original functions
// also get overriden, so for example ZludaLoadLibraryExW ends calling
// itself recursively until stack overflow exception occurs
unsafe fn detour_functions<'a>(
nvcuda_module: HMODULE,
non_cuda_fns: Vec<(*mut *mut c_void, *mut c_void)>,
cuda_fns: HashMap<&'static CStr, (*mut c_void, *mut c_void)>,
) -> Option<Self> {
unsafe fn new<'a>() -> Option<Self> {
let mut result = DetourDetachGuard {
state: DetourUndoState::DoNothing,
suspended_threads: Vec::new(),
overriden_non_cuda_fns: non_cuda_fns,
nvcuda_module,
overriden_cuda_fns: cuda_fns,
overriden_non_cuda_fns: Vec::new(),
};
if DetourTransactionBegin() != NO_ERROR as i32 {
return None;
Expand All @@ -419,6 +396,19 @@ impl DetourDetachGuard {
}
}
result.overriden_non_cuda_fns.extend_from_slice(&[
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
ZludaLoadLibraryExA as _,
),
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
(
&mut CREATE_PROCESS_A as *mut _ as _,
ZludaCreateProcessA as _,
Expand All @@ -440,12 +430,7 @@ impl DetourDetachGuard {
ZludaCreateProcessWithTokenW as _,
),
]);
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied().chain(
result
.overriden_cuda_fns
.values_mut()
.map(|(original_ptr, new_ptr)| (original_ptr as *mut _, *new_ptr)),
) {
for (original_fn, new_fn) in result.overriden_non_cuda_fns.iter().copied() {
if DetourAttach(original_fn, new_fn) != NO_ERROR as i32 {
return None;
}
Expand Down Expand Up @@ -659,23 +644,10 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
if DetourRestoreAfterWith() == FALSE {
return FALSE;
}
if !initialize_current_module_name(instDLL) {
if !initialize_globals(instDLL) {
return FALSE;
}
match get_zluda_dlls_paths() {
Some((nvcuda_path, nvml_path)) => {
ZLUDA_PATH_UTF8 = Some(nvcuda_path);
ZLUDA_ML_PATH_UTF8 = Some(nvml_path);
ZLUDA_PATH_UTF16 = std::str::from_utf8_unchecked(nvcuda_path)
.encode_utf16()
.collect::<Vec<_>>();
ZLUDA_ML_PATH_UTF16 = std::str::from_utf8_unchecked(nvml_path)
.encode_utf16()
.collect::<Vec<_>>();
}
None => return FALSE,
}
match detour_already_loaded_nvcuda() {
match DetourDetachGuard::new() {
Some(g) => {
DETOUR_STATE = Some(g);
TRUE
Expand All @@ -692,55 +664,55 @@ unsafe extern "system" fn DllMain(instDLL: HINSTANCE, dwReason: u32, _: *const u
}
}

#[must_use]
unsafe fn initialize_current_module_name(current_module: HINSTANCE) -> bool {
let mut name = vec![0; 128 as usize];
unsafe fn initialize_globals(current_module: HINSTANCE) -> bool {
let mut module_name = vec![0; 128 as usize];
loop {
let size = GetModuleFileNameA(
let size = GetModuleFileNameW(
current_module,
name.as_mut_ptr() as *mut _,
name.len() as u32,
module_name.as_mut_ptr(),
module_name.len() as u32,
);
if size == 0 {
return false;
}
if size < name.len() as u32 {
name.truncate(size as usize);
CURRENT_MODULE_FILENAME = name;
return true;
if size < module_name.len() as u32 {
module_name.truncate(size as usize);
module_name.push(0);
CURRENT_MODULE_FILENAME = String::from_utf16_lossy(&module_name).into_bytes();
break;
}
name.resize(name.len() * 2, 0);
module_name.resize(module_name.len() * 2, 0);
}
if !load_global_string(
&PAYLOAD_NVML_GUID,
&mut ZLUDA_ML_PATH_UTF8,
&mut ZLUDA_ML_PATH_UTF16,
) {
return false;
}
if !load_global_string(
&PAYLOAD_NVCUDA_GUID,
&mut ZLUDA_PATH_UTF8,
&mut ZLUDA_PATH_UTF16,
) {
return false;
}
true
}

#[must_use]
unsafe fn detour_already_loaded_nvcuda() -> Option<DetourDetachGuard> {
let nvcuda_mod = GetModuleHandleA(b"nvcuda\0".as_ptr() as _);
let detour_functions = vec![
(
&mut LOAD_LIBRARY_A as *mut _ as *mut *mut c_void,
ZludaLoadLibraryA as *mut c_void,
),
(&mut LOAD_LIBRARY_W as *mut _ as _, ZludaLoadLibraryW as _),
(
&mut LOAD_LIBRARY_EX_A as *mut _ as _,
ZludaLoadLibraryExA as _,
),
(
&mut LOAD_LIBRARY_EX_W as *mut _ as _,
ZludaLoadLibraryExW as _,
),
];
DetourDetachGuard::detour_functions(nvcuda_mod, detour_functions, HashMap::new())
}

fn get_zluda_dlls_paths() -> Option<(&'static [u8], &'static [u8])> {
match get_payload(&PAYLOAD_NVCUDA_GUID) {
None => None,
Some(nvcuda_payload) => match get_payload(&PAYLOAD_NVML_GUID) {
None => return None,
Some(nvml_payload) => return Some((nvcuda_payload, nvml_payload)),
},
fn load_global_string(
guid: &detours_sys::GUID,
utf8_path: &mut Option<&'static [u8]>,
utf16_path: &mut Vec<u16>,
) -> bool {
if let Some(payload) = get_payload(guid) {
*utf8_path = Some(payload);
*utf16_path = unsafe { std::str::from_utf8_unchecked(payload) }
.encode_utf16()
.collect::<Vec<_>>();
true
} else {
false
}
}

Expand Down

0 comments on commit 164c172

Please sign in to comment.