diff options
Diffstat (limited to 'zluda/src/impl/d3d11.rs')
-rw-r--r-- | zluda/src/impl/d3d11.rs | 63 |
1 files changed, 48 insertions, 15 deletions
diff --git a/zluda/src/impl/d3d11.rs b/zluda/src/impl/d3d11.rs index ce69f08..37edd69 100644 --- a/zluda/src/impl/d3d11.rs +++ b/zluda/src/impl/d3d11.rs @@ -1,19 +1,25 @@ -use cl3::d3d11::{CL_D3D11_DXGI_ADAPTER_KHR, CL_PREFERRED_DEVICES_FOR_D3D11_KHR};
+use cl3::d3d11::{
+ CL_CONTEXT_D3D11_DEVICE_KHR, CL_D3D11_DXGI_ADAPTER_KHR, CL_PREFERRED_DEVICES_FOR_D3D11_KHR,
+};
use cl3::device::*;
use cl3::ext::{CL_CONTEXT_PLATFORM, CL_DEVICE_TYPE_GPU, CL_MEM_READ_WRITE};
use cl3::platform::*;
use cl3::{ext::CL_PLATFORM_VENDOR, platform::get_platform_info};
use cuda_types::*;
use hip_runtime_sys::{hipDevice_t, hipGraphicsResource_t};
+use lazy_static::lazy_static;
+use rustc_hash::FxHashMap;
use std::ffi::{c_void, CStr};
+use std::mem::ManuallyDrop;
+use std::sync::Mutex;
use std::{mem, ptr};
use windows::Win32::Graphics::Direct3D11::{
- ID3D11Buffer, ID3D11Resource, ID3D11Texture2D, ID3D11Texture3D,
+ ID3D11Buffer, ID3D11Device, ID3D11Resource, ID3D11Texture2D, ID3D11Texture3D,
};
+use windows::Win32::Graphics::Dxgi::IDXGIAdapter;
static mut PLATFORM: *mut c_void = ptr::null_mut();
static mut DEVICE: *mut c_void = ptr::null_mut();
-static mut CONTEXT: *mut c_void = ptr::null_mut();
pub(crate) unsafe fn get_device(
p_cuda_device: *mut hipDevice_t,
@@ -30,12 +36,14 @@ pub(crate) unsafe fn get_device( .unwrap();
let mut dev = mem::zeroed();
let mut dev_count = mem::zeroed();
+ let p_adapter = p_adapter.cast();
+ let p_adapter = IDXGIAdapter::from_raw_borrowed(&p_adapter).unwrap();
assert_eq!(
0,
clGetDeviceIDsFromD3D11KHR(
PLATFORM,
CL_D3D11_DXGI_ADAPTER_KHR,
- p_adapter.cast(),
+ p_adapter.as_raw(),
CL_PREFERRED_DEVICES_FOR_D3D11_KHR,
1,
&mut dev,
@@ -44,10 +52,6 @@ pub(crate) unsafe fn get_device( );
assert_eq!(dev_count, 1);
DEVICE = dev;
- let properties = [CL_CONTEXT_PLATFORM, PLATFORM as _, 0];
- let context =
- cl3::context::create_context(&[dev], properties.as_ptr(), None, ptr::null_mut()).unwrap();
- CONTEXT = context;
*p_cuda_device = 0;
Ok(())
}
@@ -62,8 +66,10 @@ pub(crate) unsafe fn register_resource( if flags != 0 {
panic!()
}
- let resource = mem::transmute::<_, &ID3D11Resource>(p_d3_dresource);
+ let p_d3_dresource = p_d3_dresource.cast_mut();
+ let resource = ID3D11Resource::from_raw_borrowed(&p_d3_dresource).unwrap();
let mem = if let Ok(buffer) = resource.cast::<ID3D11Buffer>() {
+ let device = buffer.GetDevice().unwrap();
let clCreateFromD3D11BufferKHR =
mem::transmute::<_, cl3::d3d11::clCreateFromD3D11BufferKHR_fn>(
cl3::ext::clGetExtensionFunctionAddressForPlatform(
@@ -74,14 +80,15 @@ pub(crate) unsafe fn register_resource( .unwrap();
let mut err = 0;
let mem = clCreateFromD3D11BufferKHR(
- CONTEXT,
+ get_cl_context(device),
CL_MEM_READ_WRITE,
- (&buffer as *const ID3D11Buffer).cast_mut().cast(),
+ buffer.as_raw(),
&mut err,
);
assert_eq!(err, 0);
mem
} else if let Ok(tex_2d) = resource.cast::<ID3D11Texture2D>() {
+ let device = tex_2d.GetDevice().unwrap();
let clCreateFromD3D11Texture2DKHR =
mem::transmute::<_, cl3::d3d11::clCreateFromD3D11Texture2DKHR_fn>(
cl3::ext::clGetExtensionFunctionAddressForPlatform(
@@ -92,15 +99,16 @@ pub(crate) unsafe fn register_resource( .unwrap();
let mut err = 0;
let mem = clCreateFromD3D11Texture2DKHR(
- CONTEXT,
+ get_cl_context(device),
CL_MEM_READ_WRITE,
- (&tex_2d as *const ID3D11Texture2D).cast_mut().cast(),
+ tex_2d.as_raw(),
0,
&mut err,
);
assert_eq!(err, 0);
mem
} else if let Ok(tex_3d) = resource.cast::<ID3D11Texture3D>() {
+ let device = tex_3d.GetDevice().unwrap();
let clCreateFromD3D11Texture3DKHR =
mem::transmute::<_, cl3::d3d11::clCreateFromD3D11Texture3DKHR_fn>(
cl3::ext::clGetExtensionFunctionAddressForPlatform(
@@ -111,9 +119,9 @@ pub(crate) unsafe fn register_resource( .unwrap();
let mut err = 0;
let mem = clCreateFromD3D11Texture3DKHR(
- CONTEXT,
+ get_cl_context(device),
CL_MEM_READ_WRITE,
- (&tex_3d as *const ID3D11Texture3D).cast_mut().cast(),
+ tex_3d.as_raw(),
0,
&mut err,
);
@@ -149,3 +157,28 @@ unsafe fn initialize_opencl() { }
panic!()
}
+
+lazy_static! {
+ static ref DEVICE_TO_CONTEXT: Mutex<FxHashMap<usize, usize>> = Mutex::new(FxHashMap::default());
+}
+
+unsafe fn get_cl_context(dev: ID3D11Device) -> *mut c_void {
+ let mut map = DEVICE_TO_CONTEXT.lock().unwrap();
+ match map.entry(dev.as_raw() as usize) {
+ std::collections::hash_map::Entry::Occupied(entry) => (*entry.get()) as *mut c_void,
+ std::collections::hash_map::Entry::Vacant(entry) => {
+ let properties = [
+ CL_CONTEXT_PLATFORM,
+ PLATFORM as _,
+ CL_CONTEXT_D3D11_DEVICE_KHR as _,
+ dev.as_raw() as _,
+ 0,
+ ];
+ let context =
+ cl3::context::create_context(&[DEVICE], properties.as_ptr(), None, ptr::null_mut())
+ .unwrap();
+ entry.insert(context as usize);
+ context
+ }
+ }
+}
|