aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src')
-rw-r--r--zluda/src/cuda.rs8
-rw-r--r--zluda/src/impl/context.rs13
-rw-r--r--zluda/src/impl/dark_api.rs30
-rw-r--r--zluda/src/impl/device.rs2
-rw-r--r--zluda/src/impl/memory.rs12
-rw-r--r--zluda/src/impl/texobj.rs4
-rw-r--r--zluda/src/impl/texref.rs3
7 files changed, 58 insertions, 14 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs
index c16a751..eebf6e9 100644
--- a/zluda/src/cuda.rs
+++ b/zluda/src/cuda.rs
@@ -89,6 +89,7 @@ cuda_function_declarations!(
cuModuleGetTexRef,
cuMemGetInfo_v2,
cuMemAlloc_v2,
+ cuMemAllocHost_v2,
cuMemAllocManaged,
cuMemAllocPitch_v2,
cuMemFree_v2,
@@ -633,6 +634,13 @@ mod definitions {
memory::alloc(dptr, bytesize)
}
+ pub(crate) unsafe fn cuMemAllocHost_v2(
+ pp: *mut *mut ::std::os::raw::c_void,
+ bytesize: usize,
+ ) -> hipError_t {
+ hipMemAllocHost(pp, bytesize)
+ }
+
pub(crate) unsafe fn cuMemAllocManaged(
dev_ptr: *mut hipDeviceptr_t,
size: usize,
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index d1b3e7b..ab2dbfc 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -92,6 +92,9 @@ impl ContextData {
let mut primary_ctx_data = mutex_over_primary_ctx_data
.lock()
.map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ if primary_ctx_data.ref_count == 0 {
+ return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED);
+ }
fn_(&mut primary_ctx_data.mutable)
}
ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => {
@@ -104,6 +107,7 @@ impl ContextData {
}
pub(crate) struct ContextInnerMutable {
+ pub(crate) allocations: FxHashSet<*mut c_void>,
pub(crate) streams: FxHashSet<*mut stream::Stream>,
pub(crate) modules: FxHashSet<*mut module::Module>,
// Field below is here to support CUDA Driver Dark API
@@ -113,6 +117,7 @@ pub(crate) struct ContextInnerMutable {
impl ContextInnerMutable {
pub(crate) fn new() -> Self {
ContextInnerMutable {
+ allocations: FxHashSet::default(),
streams: FxHashSet::default(),
modules: FxHashSet::default(),
local_storage: FxHashMap::default(),
@@ -240,7 +245,13 @@ pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Re
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
}
- //let ctx = LiveCheck::as_result(ctx)?;
+ let ctx = LiveCheck::as_result(ctx)?;
+ if let ContextVariant::Primary(ref primary) = ctx.variant {
+ let primary = primary.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ if primary.ref_count == 0 {
+ return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
+ }
+ }
//TODO: query device for properties roughly matching CUDA API version
*version = 3020;
Ok(())
diff --git a/zluda/src/impl/dark_api.rs b/zluda/src/impl/dark_api.rs
index 08ffa17..aa23f97 100644
--- a/zluda/src/impl/dark_api.rs
+++ b/zluda/src/impl/dark_api.rs
@@ -62,6 +62,27 @@ impl CudaDarkApi for CudaDarkApiZluda {
device::primary_ctx_get(pctx, hip_dev).into_cuda()
}
+ unsafe extern "system" fn primary_context_create_with_flags(
+ dev: CUdevice,
+ flags: u32,
+ ) -> CUresult {
+ unsafe fn primary_context_create_with_flags_impl(
+ dev: CUdevice,
+ flags: u32,
+ ) -> Result<(), CUresult> {
+ let hip_dev = FromCuda::from_cuda(dev);
+ device::primary_ctx(hip_dev, |ctx, _| {
+ if ctx.ref_count > 0 {
+ return Err(CUresult::CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE);
+ }
+ ctx.ref_count = 1;
+ ctx.flags = flags;
+ Ok(())
+ })?
+ }
+ primary_context_create_with_flags_impl(dev, flags).into_cuda()
+ }
+
unsafe extern "system" fn get_module_from_cubin_ex1(
module: *mut cuda_types::CUmodule,
fatbinc_wrapper: *const zluda_dark_api::FatbincWrapper,
@@ -439,7 +460,7 @@ impl CudaDarkApi for CudaDarkApiZluda {
unsafe extern "system" fn get_hip_stream(
stream: CUstream,
) -> CudaResult<*const std::os::raw::c_void> {
- let cuda_object: *mut LiveCheck<stream::StreamData> = stream as *mut stream::Stream;
+ let cuda_object = stream as *mut stream::Stream;
stream::as_hip_stream(cuda_object)
.map(|ptr| ptr as *const _)
.into()
@@ -453,13 +474,6 @@ impl CudaDarkApi for CudaDarkApiZluda {
*is_wrapped = 0;
CUresult::CUDA_SUCCESS
}
-
- unsafe extern "system" fn primary_context_create_with_flags(
- dev: CUdevice,
- flags: u32,
- ) -> CUresult {
- todo!()
- }
}
unsafe fn with_context_or_current<T>(
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index c7e8190..3cc5b83 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -513,7 +513,7 @@ unsafe fn primary_ctx_get_or_retain(
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx = primary_ctx(hip_dev, |ctx, raw_ctx| {
- if increment_refcount || ctx.ref_count == 0 {
+ if increment_refcount {
ctx.ref_count += 1;
}
Ok(raw_ctx.cast_mut())
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index 41840b9..d8226e5 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -1,7 +1,7 @@
use super::stream::Stream;
use super::{hipfix, stream};
use crate::hip_call_cuda;
-use crate::r#impl::{memcpy2d_from_cuda, GLOBAL_STATE};
+use crate::r#impl::{context, memcpy2d_from_cuda, GLOBAL_STATE};
use cuda_types::*;
use hip_runtime_sys::*;
use std::{mem, ptr};
@@ -12,8 +12,14 @@ pub(crate) unsafe fn alloc(dptr: *mut hipDeviceptr_t, mut bytesize: usize) -> Re
}
let zero_buffers = GLOBAL_STATE.get()?.zero_buffers;
bytesize = hipfix::alloc_round_up(bytesize);
- let mut ptr = mem::zeroed();
- hip_call_cuda!(hipMalloc(&mut ptr, bytesize));
+ let ptr = context::with_current(|ctx| {
+ ctx.with_inner_mut(|mutable| {
+ let mut ptr = mem::zeroed();
+ hip_call_cuda!(hipMalloc(&mut ptr, bytesize));
+ mutable.allocations.insert(ptr);
+ Ok(ptr)
+ })
+ })???;
if zero_buffers {
hip_call_cuda!(hipMemsetD32(hipDeviceptr_t(ptr), 0, bytesize / 4));
}
diff --git a/zluda/src/impl/texobj.rs b/zluda/src/impl/texobj.rs
index 21eb453..a26918a 100644
--- a/zluda/src/impl/texobj.rs
+++ b/zluda/src/impl/texobj.rs
@@ -14,6 +14,8 @@ pub(crate) unsafe fn create(
return hipError_t::hipErrorInvalidValue;
}
hipfix::array::with_resource_desc(p_res_desc, |p_res_desc| {
- hipTexObjectCreate(p_tex_object, p_res_desc, p_tex_desc, p_res_view_desc)
+ let mut p_tex_desc = *p_tex_desc;
+ p_tex_desc.maxAnisotropy = 0;
+ hipTexObjectCreate(p_tex_object, p_res_desc, &p_tex_desc, p_res_view_desc)
})
}
diff --git a/zluda/src/impl/texref.rs b/zluda/src/impl/texref.rs
index 307b5ba..1984774 100644
--- a/zluda/src/impl/texref.rs
+++ b/zluda/src/impl/texref.rs
@@ -109,6 +109,9 @@ unsafe fn reset(tex_ref: *mut textureReference) -> Result<(), CUresult> {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let mut res_desc = mem::zeroed();
+ if (*tex_ref).textureObject == ptr::null_mut() {
+ return Ok(());
+ }
hip_call_cuda!(hipGetTextureObjectResourceDesc(
&mut res_desc,
(*tex_ref).textureObject