aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs485
1 files changed, 485 insertions, 0 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
new file mode 100644
index 0000000..57c7156
--- /dev/null
+++ b/cuda_base/src/lib.rs
@@ -0,0 +1,485 @@
+extern crate proc_macro;
+
+use std::collections::hash_map;
+use std::iter;
+
+use proc_macro::TokenStream;
+use proc_macro2::Span;
+use quote::{format_ident, quote, ToTokens};
+use rustc_hash::{FxHashMap, FxHashSet};
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::token::Brace;
+use syn::visit_mut::VisitMut;
+use syn::{
+ bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident,
+ Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments,
+ PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr,
+};
+
+const CUDA_RS: &'static str = include_str! {"cuda.rs"};
+
+// This macro copies cuda.rs as-is with some changes:
+// * All function declarations are filtered out
+// * CUdeviceptr_v2 is redefined from `unsigned long long` to `*void`
+// * `extern "C"` gets replaced by `extern "system"`
+// * CUuuid_st is redefined to use uchar instead of char
+#[proc_macro]
+pub fn cuda_type_declarations(_: TokenStream) -> TokenStream {
+ let mut cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
+ cuda_module.items = cuda_module
+ .items
+ .into_iter()
+ .filter_map(|item| match item {
+ Item::ForeignMod(_) => None,
+ Item::Struct(mut struct_) => {
+ if "CUdeviceptr_v2" == struct_.ident.to_string() {
+ match &mut struct_.fields {
+ Fields::Unnamed(ref mut fields) => {
+ fields.unnamed[0].ty =
+ absolute_path_to_mut_ptr(&["std", "os", "raw", "c_void"])
+ }
+ _ => unreachable!(),
+ }
+ } else if "CUuuid_st" == struct_.ident.to_string() {
+ match &mut struct_.fields {
+ Fields::Named(ref mut fields) => match fields.named[0].ty {
+ Type::Array(TypeArray { ref mut elem, .. }) => {
+ *elem = Box::new(Type::Path(TypePath {
+ qself: None,
+ path: segments_to_path(&["std", "os", "raw", "c_uchar"]),
+ }))
+ }
+ _ => unreachable!(),
+ },
+ _ => panic!(),
+ }
+ }
+ Some(Item::Struct(struct_))
+ }
+ i => Some(i),
+ })
+ .collect::<Vec<_>>();
+ syn::visit_mut::visit_file_mut(&mut FixAbi, &mut cuda_module);
+ cuda_module.into_token_stream().into()
+}
+
+fn segments_to_path(path: &[&'static str]) -> Path {
+ let mut segments = Punctuated::new();
+ for ident in path {
+ let ident = PathSegment {
+ ident: Ident::new(ident, Span::call_site()),
+ arguments: PathArguments::None,
+ };
+ segments.push(ident);
+ }
+ Path {
+ leading_colon: Some(Token![::](Span::call_site())),
+ segments,
+ }
+}
+
+fn absolute_path_to_mut_ptr(path: &[&'static str]) -> Type {
+ Type::Ptr(TypePtr {
+ star_token: Token![*](Span::call_site()),
+ const_token: None,
+ mutability: Some(Token![mut](Span::call_site())),
+ elem: Box::new(Type::Path(TypePath {
+ qself: None,
+ path: segments_to_path(path),
+ })),
+ })
+}
+
+struct FixAbi;
+
+impl VisitMut for FixAbi {
+ fn visit_abi_mut(&mut self, i: &mut Abi) {
+ if let Some(ref mut name) = i.name {
+ *name = LitStr::new("system", Span::call_site());
+ }
+ }
+}
+
+// This macro accepts following arguments:
+// * `type_path`: path to the module with type definitions (in the module tree)
+// * `normal_macro`: ident for a normal macro
+// * `override_macro`: ident for an override macro
+// * `override_fns`: list of override functions
+// Then macro goes through every function in rust.rs, and for every fn `foo`:
+// * if `foo` is contained in `override_fns` then pass it into `override_macro`
+// * if `foo` is not contained in `override_fns` pass it to `normal_macro`
+// Both `override_macro` and `normal_macro` expect this format:
+// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult)
+// Additionally, it does a fixup of CUDA types so they get prefixed with `type_path`
+#[proc_macro]
+pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(tokens as FnDeclInput);
+ let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
+ let override_fns = input
+ .override_fns
+ .iter()
+ .map(ToString::to_string)
+ .collect::<FxHashSet<_>>();
+ cuda_module
+ .items
+ .into_iter()
+ .filter_map(|item| match item {
+ Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() {
+ ForeignItem::Fn(ForeignItemFn {
+ sig:
+ Signature {
+ ident,
+ inputs,
+ output,
+ ..
+ },
+ ..
+ }) => {
+ let path = if override_fns.contains(&ident.to_string()) {
+ &input.override_macro
+ } else {
+ &input.normal_macro
+ }
+ .clone();
+ let inputs = inputs
+ .into_iter()
+ .map(|fn_arg| match fn_arg {
+ FnArg::Typed(mut pat_type) => {
+ pat_type.ty =
+ prepend_cuda_path_to_type(&input.type_path, pat_type.ty);
+ FnArg::Typed(pat_type)
+ }
+ _ => unreachable!(),
+ })
+ .collect::<Punctuated<_, Token![,]>>();
+ let output = match output {
+ ReturnType::Type(_, type_) => type_,
+ ReturnType::Default => unreachable!(),
+ };
+ let type_path = input.type_path.clone();
+ let tokens = quote! {
+ "system" fn #ident(#inputs) -> #type_path :: #output
+ };
+ Some(Item::Macro(ItemMacro {
+ attrs: Vec::new(),
+ ident: None,
+ mac: Macro {
+ path,
+ bang_token: Token![!](Span::call_site()),
+ delimiter: MacroDelimiter::Brace(Brace {
+ span: Span::call_site(),
+ }),
+ tokens,
+ },
+ semi_token: None,
+ }))
+ }
+ _ => unreachable!(),
+ },
+ _ => None,
+ })
+ .map(Item::into_token_stream)
+ .collect::<proc_macro2::TokenStream>()
+ .into()
+}
+
+fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
+ match *type_ {
+ Type::Path(mut type_path) => {
+ type_path.path = prepend_cuda_path_to_path(base_path, type_path.path);
+ Box::new(Type::Path(type_path))
+ }
+ Type::Ptr(mut type_ptr) => {
+ type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem);
+ Box::new(Type::Ptr(type_ptr))
+ }
+ _ => unreachable!(),
+ }
+}
+
+fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
+ if path.leading_colon.is_some() {
+ return path;
+ }
+ if path.segments.len() == 1 {
+ let ident = path.segments[0].ident.to_string();
+ if ident.starts_with("CU") || ident.starts_with("cu") {
+ let mut base_path = base_path.clone();
+ base_path.segments.extend(path.segments);
+ return base_path;
+ }
+ }
+ path
+}
+
+struct FnDeclInput {
+ type_path: Path,
+ normal_macro: Path,
+ override_macro: Path,
+ override_fns: Punctuated<Ident, Token![,]>,
+}
+
+impl Parse for FnDeclInput {
+ fn parse(input: ParseStream) -> syn::Result<Self> {
+ let type_path = input.parse::<Path>()?;
+ input.parse::<Token![,]>()?;
+ let normal_macro = input.parse::<Path>()?;
+ input.parse::<Token![,]>()?;
+ let override_macro = input.parse::<Path>()?;
+ input.parse::<Token![,]>()?;
+ let override_fns_content;
+ bracketed!(override_fns_content in input);
+ let override_fns = override_fns_content.parse_terminated(Ident::parse)?;
+ Ok(Self {
+ type_path,
+ normal_macro,
+ override_macro,
+ override_fns,
+ })
+ }
+}
+
+// This trait accepts following parameters:
+// * `type_path`: path to the module with type definitions (in the module tree)
+// * `trait_`: name of the trait to be derived
+// * `ignore_structs`: bracketed list of types to ignore
+// * `ignore_fns`: bracketed list of fns to ignore
+#[proc_macro]
+pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(tokens as DeriveDisplayInput);
+ let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
+ let mut derive_state = DeriveDisplayState::new(input);
+ cuda_module
+ .items
+ .into_iter()
+ .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i))
+ .collect::<proc_macro2::TokenStream>()
+ .into()
+}
+
+fn cuda_derive_display_trait_for_item(
+ state: &mut DeriveDisplayState,
+ item: Item,
+) -> Option<proc_macro2::TokenStream> {
+ let path_prefix = &state.type_path;
+ let path_prefix_iter = iter::repeat(&path_prefix);
+ let trait_ = &state.trait_;
+ let trait_iter = iter::repeat(&state.trait_);
+ match item {
+ Item::Const(_) => None,
+ Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() {
+ ForeignItem::Fn(ForeignItemFn {
+ sig: Signature { ident, inputs, .. },
+ ..
+ }) => {
+ if state.ignore_fns.contains(&ident) {
+ return None;
+ }
+ let inputs = inputs
+ .into_iter()
+ .map(|fn_arg| match fn_arg {
+ FnArg::Typed(mut pat_type) => {
+ pat_type.ty = prepend_cuda_path_to_type(path_prefix, pat_type.ty);
+ FnArg::Typed(pat_type)
+ }
+ _ => unreachable!(),
+ })
+ .collect::<Vec<_>>();
+ let inputs_iter = inputs.iter();
+ let mut arg_name_iter = inputs.iter().map(|fn_arg| match fn_arg {
+ FnArg::Typed(PatType { pat, .. }) => pat,
+ _ => unreachable!(),
+ });
+ let fn_name = format_ident!("write_{}", ident);
+ Some(match arg_name_iter.next() {
+ Some(first_arg_name) => quote! {
+ pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> {
+ writer.write_all(concat!("(", stringify!(#first_arg_name), ": ").as_bytes())?;
+ CudaDisplay::write(&#first_arg_name, writer)?;
+ #(
+ writer.write_all(b", ")?;
+ writer.write_all(concat!(stringify!(#arg_name_iter), ": ").as_bytes())?;
+ CudaDisplay::write(&#arg_name_iter, writer)?;
+ )*
+ writer.write_all(b")")
+ }
+ },
+ None => quote! {
+ pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ writer.write_all(b"()")
+ }
+ },
+ })
+ }
+ _ => unreachable!(),
+ },
+ Item::Impl(mut item_impl) => {
+ let enum_ = match *(item_impl.self_ty) {
+ Type::Path(mut path) => path.path.segments.pop().unwrap().into_value().ident,
+ _ => unreachable!(),
+ };
+ let variant_ = match item_impl.items.pop().unwrap() {
+ syn::ImplItem::Const(item_const) => item_const.ident,
+ _ => unreachable!(),
+ };
+ state.record_enum_variant(enum_, variant_);
+ None
+ }
+ Item::Struct(item_struct) => {
+ let item_struct_name = item_struct.ident.to_string();
+ if state.ignore_structs.contains(&item_struct.ident) {
+ return None;
+ }
+ if item_struct_name.ends_with("_enum") {
+ let enum_ = &item_struct.ident;
+ let enum_iter = iter::repeat(&item_struct.ident);
+ let variants = state.enums.get(&item_struct.ident).unwrap().iter();
+ Some(quote! {
+ impl #trait_ for #path_prefix :: #enum_ {
+ fn write(&self, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ match self {
+ #(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)*
+ _ => write!(writer, "{}", self.0)
+ }
+ }
+ }
+ })
+ } else {
+ let struct_ = &item_struct.ident;
+ let (first_field, rest_of_fields) = match item_struct.fields {
+ Fields::Named(fields) => {
+ let mut all_idents = fields.named.into_iter().filter_map(|f| {
+ let f_ident = f.ident.unwrap();
+ let name = f_ident.to_string();
+ if name.starts_with("reserved") || name == "_unused" {
+ None
+ } else {
+ Some(f_ident)
+ }
+ });
+ let first = match all_idents.next() {
+ Some(f) => f,
+ None => return None,
+ };
+ (first, all_idents)
+ }
+ _ => return None,
+ };
+ Some(quote! {
+ impl #trait_ for #path_prefix :: #struct_ {
+ fn write(&self, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
+ #trait_::write(&self.#first_field, writer)?;
+ #(
+ writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
+ #trait_iter::write(&self.#rest_of_fields, writer)?;
+ )*
+ writer.write_all(b" }")
+ }
+ }
+ })
+ }
+ }
+ Item::Type(item_type) => match *(item_type.ty) {
+ Type::Ptr(_) => {
+ let type_ = item_type.ident;
+ Some(quote! {
+ impl #trait_ for #path_prefix :: #type_ {
+ fn write(&self, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", *self)
+ }
+ }
+ })
+ }
+ Type::Path(type_path) => {
+ if type_path.path.leading_colon.is_some() {
+ let option_seg = type_path.path.segments.last().unwrap();
+ if option_seg.ident == "Option" {
+ match &option_seg.arguments {
+ PathArguments::AngleBracketed(generic) => match generic.args[0] {
+ syn::GenericArgument::Type(Type::BareFn(_)) => {
+ let type_ = &item_type.ident;
+ return Some(quote! {
+ impl #trait_ for #path_prefix :: #type_ {
+ fn write(&self, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
+ write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
+ }
+ }
+ });
+ }
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ }
+ }
+ }
+ None
+ }
+ _ => unreachable!(),
+ },
+ Item::Union(_) => None,
+ Item::Use(_) => None,
+ _ => unreachable!(),
+ }
+}
+
+struct DeriveDisplayState {
+ type_path: Path,
+ trait_: Path,
+ ignore_structs: FxHashSet<Ident>,
+ ignore_fns: FxHashSet<Ident>,
+ enums: FxHashMap<Ident, Vec<Ident>>,
+}
+
+impl DeriveDisplayState {
+ fn new(input: DeriveDisplayInput) -> Self {
+ DeriveDisplayState {
+ type_path: input.type_path,
+ trait_: input.trait_,
+ ignore_structs: input.ignore_structs.into_iter().collect(),
+ ignore_fns: input.ignore_fns.into_iter().collect(),
+ enums: Default::default(),
+ }
+ }
+
+ fn record_enum_variant(&mut self, enum_: Ident, variant: Ident) {
+ match self.enums.entry(enum_) {
+ hash_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().push(variant);
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(vec![variant]);
+ }
+ }
+ }
+}
+
+struct DeriveDisplayInput {
+ type_path: Path,
+ trait_: Path,
+ ignore_structs: Punctuated<Ident, Token![,]>,
+ ignore_fns: Punctuated<Ident, Token![,]>,
+}
+
+impl Parse for DeriveDisplayInput {
+ fn parse(input: ParseStream) -> syn::Result<Self> {
+ let type_path = input.parse::<Path>()?;
+ input.parse::<Token![,]>()?;
+ let trait_ = input.parse::<Path>()?;
+ input.parse::<Token![,]>()?;
+ let ignore_structs_buffer;
+ bracketed!(ignore_structs_buffer in input);
+ let ignore_structs = ignore_structs_buffer.parse_terminated(Ident::parse)?;
+ input.parse::<Token![,]>()?;
+ let ignore_fns_buffer;
+ bracketed!(ignore_fns_buffer in input);
+ let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse)?;
+ Ok(Self {
+ type_path,
+ trait_,
+ ignore_structs,
+ ignore_fns,
+ })
+ }
+}