1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
#![allow(non_snake_case)]
use crate::r#impl as notcuda;
use crate::r#impl::CUresult;
use crate::{cuda::CUuuid, r#impl::Encuda};
use ::std::{
ffi::c_void,
os::raw::{c_int, c_uint},
};
use cuda_driver_sys as cuda;
/// Generates two `#[test]` functions from one generic test function.
///
/// Given an identifier `foo` naming a function generic over a single type
/// parameter, this expands to:
/// - `foo_notcuda`, which calls `foo::<crate::r#impl::test::NotCuda>()`, and
/// - `foo_cuda`, which calls `foo::<crate::r#impl::test::Cuda>()`.
///
/// The expansion uses `paste!` unqualified to build the new function names,
/// so the invoking module must have the `paste` macro in scope.
#[macro_export]
macro_rules! cuda_driver_test {
    ($test_fn:ident) => {
        paste! {
            #[test]
            fn [<$test_fn _notcuda>]() {
                $test_fn::<crate::r#impl::test::NotCuda>()
            }

            #[test]
            fn [<$test_fn _cuda>]() {
                $test_fn::<crate::r#impl::test::Cuda>()
            }
        }
    };
}
/// The subset of CUDA driver entry points exercised by these tests.
///
/// Two implementations exist in this module. [`NotCuda`] routes each call to
/// this crate's own implementation (`crate::r#impl`). [`Cuda`] calls the real
/// driver through `cuda_driver_sys`. Writing a test generically over this
/// trait lets the same test body run against both.
///
/// Each method mirrors the driver function of the same name; consult the
/// CUDA Driver API reference for argument semantics.
pub trait CudaDriverFns {
    // --- Initialization -------------------------------------------------

    /// Counterpart of the driver's `cuInit`.
    fn cuInit(flags: c_uint) -> CUresult;

    // --- Device queries -------------------------------------------------

    /// Counterpart of the driver's `cuDeviceGetUuid`.
    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;

    /// Counterpart of the driver's `cuDevicePrimaryCtxGetState`.
    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;

    // --- Context management ---------------------------------------------

    /// Counterpart of the driver's `cuCtxCreate_v2`.
    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult;

    /// Counterpart of the driver's `cuCtxDestroy_v2`.
    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult;

    /// Counterpart of the driver's `cuCtxPopCurrent_v2`.
    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult;

    /// Counterpart of the driver's `cuCtxGetApiVersion`.
    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult;

    /// Counterpart of the driver's `cuCtxGetCurrent`.
    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult;

    // --- Memory ---------------------------------------------------------

    /// Counterpart of the driver's `cuMemAlloc_v2`.
    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
}
/// Test backend that routes every driver call to this crate's own
/// implementation in `crate::r#impl` (imported here as `notcuda`).
///
/// Raw `c_void` pointers are cast to whatever pointer types the `notcuda`
/// helpers expect. Device ordinals are wrapped in `notcuda::device::Index`.
/// Helpers whose result is not already a `CUresult` are converted with
/// `Encuda::encuda`. The rest already return `CUresult` and are passed
/// through unchanged.
pub struct NotCuda();

impl CudaDriverFns for NotCuda {
    fn cuInit(_flags: c_uint) -> CUresult {
        // Tests must call this before any context has been pushed.
        // The flags value is intentionally not forwarded to `notcuda::init`.
        assert!(notcuda::context::is_context_stack_empty());
        notcuda::init().encuda()
    }

    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::context::create_v2(pctx as *mut _, flags, device).encuda()
    }

    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
        notcuda::context::destroy_v2(ctx as *mut _)
    }

    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
        notcuda::context::pop_current_v2(pctx as *mut _)
    }

    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
        notcuda::context::get_api_version(ctx as *mut _, version)
    }

    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
        notcuda::context::get_current(pctx as *mut _).encuda()
    }

    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
        notcuda::memory::alloc_v2(dptr as *mut _, bytesize)
    }

    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::device::get_uuid(uuid, device).encuda()
    }

    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
        let device = notcuda::device::Index(dev);
        notcuda::device::primary_ctx_get_state(device, flags, active).encuda()
    }
}
/// Test backend that forwards every call to the real CUDA driver through
/// `cuda_driver_sys`.
///
/// The driver's status enum is converted to this crate's `CUresult` newtype
/// by casting it to `c_uint`.
pub struct Cuda();

impl CudaDriverFns for Cuda {
    fn cuInit(flags: c_uint) -> CUresult {
        // SAFETY: FFI call taking only a by-value integer argument.
        let status = unsafe { cuda::cuInit(flags) };
        CUresult(status as c_uint)
    }

    fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
        // SAFETY: `pctx` is forwarded unchanged and is not validated here.
        // The calling test must supply a valid out-pointer.
        let status = unsafe { cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) };
        CUresult(status as c_uint)
    }

    fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
        // SAFETY: `ctx` is forwarded unchanged. The calling test supplies the handle.
        let status = unsafe { cuda::cuCtxDestroy_v2(ctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
        // SAFETY: `pctx` is forwarded unchanged and is not validated here.
        let status = unsafe { cuda::cuCtxPopCurrent_v2(pctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
        // SAFETY: both pointers are forwarded unchanged and are not validated here.
        let status = unsafe { cuda::cuCtxGetApiVersion(ctx as *mut _, version) };
        CUresult(status as c_uint)
    }

    fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
        // SAFETY: `pctx` is forwarded unchanged and is not validated here.
        let status = unsafe { cuda::cuCtxGetCurrent(pctx as *mut _) };
        CUresult(status as c_uint)
    }

    fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
        // SAFETY: `dptr` is forwarded unchanged and is not validated here.
        let status = unsafe { cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) };
        CUresult(status as c_uint)
    }

    fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
        // SAFETY: `uuid` is reinterpreted as the driver's UUID pointer type.
        // This assumes the two `CUuuid` definitions share a layout.
        let status = unsafe { cuda::cuDeviceGetUuid(uuid as *mut _, dev) };
        CUresult(status as c_uint)
    }

    fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
        // SAFETY: both out-pointers are forwarded unchanged and are not validated here.
        let status = unsafe { cuda::cuDevicePrimaryCtxGetState(dev, flags, active) };
        CUresult(status as c_uint)
    }
}
|