aboutsummaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/test.rs
blob: 0ad625bf175092d5ed656cf2ff13a4232948ba5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#![allow(non_snake_case)]

use crate::r#impl as notcuda;
use crate::r#impl::CUresult;
use crate::{cuda::CUuuid, r#impl::Encuda};
use ::std::{
    ffi::c_void,
    os::raw::{c_int, c_uint},
};
use cuda_driver_sys as cuda;

/// Generates two `#[test]` functions from one generic test body `$func`:
///
/// * `<func>_notcuda` runs `$func` instantiated with `NotCuda`;
/// * `<func>_cuda` runs `$func` instantiated with `Cuda`.
///
/// `$func` must be a function that takes no arguments, accepts one type
/// parameter, and returns `()`. It is presumably bounded by `CudaDriverFns`
/// (TODO confirm at call sites).
///
/// `paste!` is invoked by its bare name, so it is resolved where the macro is
/// expanded: callers need `paste::paste` in scope.
///
/// Paths are spelled with `$crate` rather than `crate`. As an exported macro,
/// it must keep referring to this crate regardless of the expansion site.
#[macro_export]
macro_rules! cuda_driver_test {
    ($func:ident) => {
        paste! {
            #[test]
            fn [<$func _notcuda>]() {
                $func::<$crate::r#impl::test::NotCuda>()
            }

            #[test]
            fn [<$func _cuda>]() {
                $func::<$crate::r#impl::test::Cuda>()
            }
        }
    };
}

/// The subset of the CUDA driver API exercised by the driver tests.
///
/// Two backends implement it: `NotCuda`, which calls this crate's own
/// implementation, and `Cuda`, which calls the real driver through
/// `cuda_driver_sys`. A single generic test body can therefore be run
/// against both; `cuda_driver_test!` generates one test per backend.
///
/// Each method is named after the driver entry point it stands in for.
/// NOTE(review): the methods take raw pointers but are not `unsafe fn`.
/// Callers must supply pointers that are valid for whatever access the
/// backend performs.
pub trait CudaDriverFns {
    // ----- Initialization -------------------------------------------------

    /// Stand-in for `cuInit`.
    fn cuInit(flags: c_uint) -> CUresult;

    // ----- Devices --------------------------------------------------------

    /// Stand-in for `cuDeviceGetUuid`: writes the UUID of device `dev`
    /// through `uuid`.
    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;

    /// Stand-in for `cuDevicePrimaryCtxGetState`: reports the primary-context
    /// flags and activity of device `dev` through the two out-pointers.
    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;

    // ----- Contexts -------------------------------------------------------

    /// Stand-in for `cuCtxCreate_v2`: creates a context on `dev` and stores
    /// its handle through `pctx`.
    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult;

    /// Stand-in for `cuCtxDestroy_v2`.
    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult;

    /// Stand-in for `cuCtxPopCurrent_v2`: the popped handle is stored through
    /// `pctx`.
    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult;

    /// Stand-in for `cuCtxGetApiVersion`: the version of `ctx` is stored
    /// through `version`.
    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult;

    /// Stand-in for `cuCtxGetCurrent`: the current handle is stored through
    /// `pctx`.
    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult;

    // ----- Memory ---------------------------------------------------------

    /// Stand-in for `cuMemAlloc_v2`: allocates `bytesize` bytes and stores the
    /// resulting device pointer through `dptr`.
    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
}

/// Test backend that routes every call to this crate's own implementation
/// (`crate::r#impl`, imported as `notcuda` in this file).
pub struct NotCuda();

impl CudaDriverFns for NotCuda {
    fn cuInit(_flags: c_uint) -> CUresult {
        // Every test is expected to begin with an empty context stack.
        // `flags` is not forwarded because `notcuda::init` takes no arguments.
        assert!(notcuda::context::is_context_stack_empty());
        notcuda::init().encuda()
    }

    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::context::create_v2(pctx as *mut _, flags, device).encuda()
    }

    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
        // `destroy_v2` already returns a `CUresult`; no `encuda()` needed.
        notcuda::context::destroy_v2(ctx as *mut _)
    }

    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
        // Returned as-is: `pop_current_v2` yields a `CUresult` directly.
        notcuda::context::pop_current_v2(pctx as *mut _)
    }

    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
        // Returned as-is: `get_api_version` yields a `CUresult` directly.
        notcuda::context::get_api_version(ctx as *mut _, version)
    }

    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
        notcuda::context::get_current(pctx as *mut _).encuda()
    }

    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
        // Returned as-is: `alloc_v2` yields a `CUresult` directly.
        notcuda::memory::alloc_v2(dptr as *mut _, bytesize)
    }

    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::device::get_uuid(uuid, device).encuda()
    }

    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::device::primary_ctx_get_state(device, flags, active).encuda()
    }
}

/// Test backend that forwards every call to the real CUDA driver through
/// `cuda_driver_sys`. Each driver status is re-wrapped as this crate's
/// `CUresult`.
pub struct Cuda();

impl CudaDriverFns for Cuda {
    fn cuInit(flags: c_uint) -> CUresult {
        // SAFETY: FFI call whose only argument is a by-value integer.
        let status = unsafe { cuda::cuInit(flags) };
        CUresult(status as c_uint)
    }

    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
        // SAFETY: `pctx` is forwarded unchecked; this safe method relies on
        // its caller passing a pointer the driver may write a handle through.
        let status = unsafe { cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) };
        CUresult(status as c_uint)
    }

    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
        // SAFETY: `ctx` is forwarded unchecked; the caller supplies the handle.
        let status = unsafe { cuda::cuCtxDestroy_v2(ctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
        // SAFETY: `pctx` is forwarded unchecked; the caller vouches for it.
        let status = unsafe { cuda::cuCtxPopCurrent_v2(pctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
        // SAFETY: `ctx` and `version` are forwarded unchecked; the caller
        // vouches for both.
        let status = unsafe { cuda::cuCtxGetApiVersion(ctx as *mut _, version) };
        CUresult(status as c_uint)
    }

    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
        // SAFETY: `pctx` is forwarded unchecked; the caller vouches for it.
        let status = unsafe { cuda::cuCtxGetCurrent(pctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
        // SAFETY: `dptr` is forwarded unchecked; the caller vouches for it.
        let status = unsafe { cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) };
        CUresult(status as c_uint)
    }

    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
        // SAFETY: `uuid` is forwarded unchecked. NOTE(review): the pointer cast
        // assumes `crate::cuda::CUuuid` is layout-compatible with
        // `cuda_driver_sys`'s `CUuuid`; confirm.
        let status = unsafe { cuda::cuDeviceGetUuid(uuid as *mut _, dev) };
        CUresult(status as c_uint)
    }

    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
        // SAFETY: `flags` and `active` are forwarded unchecked; the caller
        // vouches for both.
        let status = unsafe { cuda::cuDevicePrimaryCtxGetState(dev, flags, active) };
        CUresult(status as c_uint)
    }
}