aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/memory.rs
blob: 5db647271fa3a114be3244c150eed22d690895cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use super::{stream, CUresult, GlobalState};
use std::{ffi::c_void, mem};

pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
    let ptr = GlobalState::lock_current_context(|ctx| {
        let dev = unsafe { &mut *ctx.device };
        Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
    })??;
    unsafe { *dptr = ptr };
    Ok(())
}

/// Copies `bytesize` bytes from `src` to `dst`.
///
/// The copy is appended to a command list taken from the legacy stream
/// (`stream::CU_STREAM_LEGACY`) and that list is then handed to the stream
/// queue's `execute_and_synchronize`.
///
/// # Errors
///
/// Returns the error from locking the stream, or the first error produced
/// while building or executing the command list (converted to `CUresult`).
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
    GlobalState::lock_stream(
        stream::CU_STREAM_LEGACY,
        |legacy| -> Result<(), CUresult> {
            let list = legacy.command_list()?;
            // The pointers are forwarded as-is; their validity is the
            // caller's responsibility.
            unsafe { list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
            legacy.queue.execute_and_synchronize(list)?;
            Ok(())
        },
    )?
}

/// Releases the allocation at `ptr` through the current context's device
/// (`l0_context.mem_free`).
///
/// # Errors
///
/// * `CUDA_ERROR_INVALID_VALUE` whenever the current context cannot be
///   locked, regardless of the underlying reason.
/// * Any error reported by `mem_free`, converted to `CUresult`.
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
    let released = GlobalState::lock_current_context(|ctx| -> Result<(), CUresult> {
        // NOTE(review): assumes `ctx.device` stays valid while the context
        // lock is held — confirm against `GlobalState`.
        let device = unsafe { &mut *ctx.device };
        device.l0_context.mem_free(ptr)?;
        Ok(())
    });
    match released {
        // The lock succeeded: report whatever `mem_free` returned.
        Ok(free_result) => free_result,
        // Every lock failure is collapsed into "invalid value".
        Err(_) => Err(CUresult::CUDA_ERROR_INVALID_VALUE),
    }
}

pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
    GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
        let cmd_list = stream.command_list()?;
        unsafe {
            cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
        }?;
        stream.queue.execute_and_synchronize(cmd_list)?;
        Ok::<_, CUresult>(())
    })?
}

/// Fills `n` consecutive bytes starting at `dst` with `uc`.
///
/// The fill is appended to a command list taken from the legacy stream
/// (`stream::CU_STREAM_LEGACY`) and that list is then handed to the stream
/// queue's `execute_and_synchronize`.
///
/// # Errors
///
/// Returns the error from locking the stream, or the first error produced
/// while building or executing the command list (converted to `CUresult`).
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
    // One byte per element, so this product is always exactly `n`.
    let fill_len = mem::size_of::<u8>() * n;
    GlobalState::lock_stream(
        stream::CU_STREAM_LEGACY,
        |legacy| -> Result<(), CUresult> {
            let list = legacy.command_list()?;
            // `dst` is forwarded as-is; its validity is the caller's
            // responsibility.
            unsafe { list.append_memory_fill_unsafe(dst, &uc, fill_len, None, &mut [])? };
            legacy.queue.execute_and_synchronize(list)?;
            Ok(())
        },
    )?
}

#[cfg(test)]
mod test {
    use super::super::test::CudaDriverFns;
    use super::super::CUresult;
    use std::ptr;

    cuda_driver_test!(alloc_without_ctx);

    /// Allocating with no context created must fail with
    /// `CUDA_ERROR_INVALID_CONTEXT` and leave the output pointer null.
    fn alloc_without_ctx<T: CudaDriverFns>() {
        assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
        let request = std::mem::size_of::<usize>();
        let mut dptr = ptr::null_mut();
        let status = T::cuMemAlloc_v2(&mut dptr, request);
        assert_eq!(status, CUresult::CUDA_ERROR_INVALID_CONTEXT);
        assert_eq!(dptr, ptr::null_mut());
    }

    cuda_driver_test!(alloc_with_ctx);

    /// Allocating inside a freshly created context must succeed and yield a
    /// non-null pointer.
    fn alloc_with_ctx<T: CudaDriverFns>() {
        assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
        let mut context = ptr::null_mut();
        let created = T::cuCtxCreate_v2(&mut context, 0, 0);
        assert_eq!(created, CUresult::CUDA_SUCCESS);
        let request = std::mem::size_of::<usize>();
        let mut dptr = ptr::null_mut();
        let status = T::cuMemAlloc_v2(&mut dptr, request);
        assert_eq!(status, CUresult::CUDA_SUCCESS);
        assert_ne!(dptr, ptr::null_mut());
        assert_eq!(T::cuCtxDestroy_v2(context), CUresult::CUDA_SUCCESS);
    }

    cuda_driver_test!(free_without_ctx);

    /// Freeing an allocation after its context has been destroyed must fail
    /// with `CUDA_ERROR_INVALID_VALUE`.
    fn free_without_ctx<T: CudaDriverFns>() {
        assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
        let mut context = ptr::null_mut();
        let created = T::cuCtxCreate_v2(&mut context, 0, 0);
        assert_eq!(created, CUresult::CUDA_SUCCESS);
        let request = std::mem::size_of::<usize>();
        let mut dptr = ptr::null_mut();
        let status = T::cuMemAlloc_v2(&mut dptr, request);
        assert_eq!(status, CUresult::CUDA_SUCCESS);
        assert_ne!(dptr, ptr::null_mut());
        assert_eq!(T::cuCtxDestroy_v2(context), CUresult::CUDA_SUCCESS);
        // The context is gone at this point, so the free must be rejected.
        let freed = T::cuMemFree_v2(dptr);
        assert_eq!(freed, CUresult::CUDA_ERROR_INVALID_VALUE);
    }
}