aboutsummaryrefslogtreecommitdiffhomepage
path: root/cuda_base/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cuda_base/src/lib.rs')
-rw-r--r--cuda_base/src/lib.rs598
1 files changed, 157 insertions, 441 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 3f6f779..833d372 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -1,110 +1,25 @@
extern crate proc_macro;
-use std::collections::hash_map;
-use std::iter;
-
use proc_macro::TokenStream;
use proc_macro2::Span;
-use quote::{format_ident, quote, ToTokens};
-use rustc_hash::{FxHashMap, FxHashSet};
+use quote::{quote, ToTokens};
+use rustc_hash::FxHashMap;
+use std::iter;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::{
- bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident,
- Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature,
- Token, Type, TypeArray, TypePath, TypePtr,
+ bracketed, parse_macro_input, File, ForeignItem, ForeignItemFn, Ident, Item, Path, Signature,
+ Token,
};
const CUDA_RS: &'static str = include_str! {"cuda.rs"};
-// This macro copies cuda.rs as-is with some changes:
-// * All function declarations are filtered out
-// * CUdeviceptr_v2 is redefined from `unsigned long long` to `*void`
-// * `extern "C"` gets replaced by `extern "system"`
-// * CUuuid_st is redefined to use uchar instead of char
-#[proc_macro]
-pub fn cuda_type_declarations(_: TokenStream) -> TokenStream {
- let mut cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
- cuda_module.items = cuda_module
- .items
- .into_iter()
- .filter_map(|item| match item {
- Item::ForeignMod(_) => None,
- Item::Struct(mut struct_) => {
- if "CUdeviceptr_v2" == struct_.ident.to_string() {
- match &mut struct_.fields {
- Fields::Unnamed(ref mut fields) => {
- fields.unnamed[0].ty =
- absolute_path_to_mut_ptr(&["std", "os", "raw", "c_void"])
- }
- _ => unreachable!(),
- }
- } else if "CUuuid_st" == struct_.ident.to_string() {
- match &mut struct_.fields {
- Fields::Named(ref mut fields) => match fields.named[0].ty {
- Type::Array(TypeArray { ref mut elem, .. }) => {
- *elem = Box::new(Type::Path(TypePath {
- qself: None,
- path: segments_to_path(&["std", "os", "raw", "c_uchar"]),
- }))
- }
- _ => unreachable!(),
- },
- _ => panic!(),
- }
- }
- Some(Item::Struct(struct_))
- }
- i => Some(i),
- })
- .collect::<Vec<_>>();
- syn::visit_mut::visit_file_mut(&mut FixAbi, &mut cuda_module);
- cuda_module.into_token_stream().into()
-}
-
-fn segments_to_path(path: &[&'static str]) -> Path {
- let mut segments = Punctuated::new();
- for ident in path {
- let ident = PathSegment {
- ident: Ident::new(ident, Span::call_site()),
- arguments: PathArguments::None,
- };
- segments.push(ident);
- }
- Path {
- leading_colon: Some(Token![::](Span::call_site())),
- segments,
- }
-}
-
-fn absolute_path_to_mut_ptr(path: &[&'static str]) -> Type {
- Type::Ptr(TypePtr {
- star_token: Token![*](Span::call_site()),
- const_token: None,
- mutability: Some(Token![mut](Span::call_site())),
- elem: Box::new(Type::Path(TypePath {
- qself: None,
- path: segments_to_path(path),
- })),
- })
-}
-
-struct FixAbi;
-
-impl VisitMut for FixAbi {
- fn visit_abi_mut(&mut self, i: &mut Abi) {
- if let Some(ref mut name) = i.name {
- *name = LitStr::new("system", Span::call_site());
- }
- }
-}
-
// This macro accepts following arguments:
-// * `type_path`: path to the module with type definitions (in the module tree)
// * `normal_macro`: ident for a normal macro
-// * `override_macro`: ident for an override macro
-// * `override_fns`: list of override functions
+// * zero or more:
+// * `override_macro`: ident for an override macro
+// * `override_fns`: list of override functions
// Then macro goes through every function in rust.rs, and for every fn `foo`:
// * if `foo` is contained in `override_fns` then pass it into `override_macro`
// * if `foo` is not contained in `override_fns` pass it to `normal_macro`
@@ -117,390 +32,191 @@ impl VisitMut for FixAbi {
#[proc_macro]
pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as FnDeclInput);
- let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
- let override_fns = input
- .override_fns
- .iter()
- .map(ToString::to_string)
- .collect::<FxHashSet<_>>();
- let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module
- .items
- .into_iter()
- .filter_map(|item| match item {
- Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() {
- ForeignItem::Fn(ForeignItemFn {
- sig:
- Signature {
- ident,
- inputs,
- output,
- ..
- },
- ..
- }) => {
- let use_normal_macro = !override_fns.contains(&ident.to_string());
- let inputs = inputs
- .into_iter()
- .map(|fn_arg| match fn_arg {
- FnArg::Typed(mut pat_type) => {
- pat_type.ty =
- prepend_cuda_path_to_type(&input.type_path, pat_type.ty);
- FnArg::Typed(pat_type)
- }
- _ => unreachable!(),
- })
- .collect::<Punctuated<_, Token![,]>>();
- let output = match output {
- ReturnType::Type(_, type_) => type_,
- ReturnType::Default => unreachable!(),
- };
- let type_path = input.type_path.clone();
- Some((
- quote! {
- "system" fn #ident(#inputs) -> #type_path :: #output
- },
- use_normal_macro,
- ))
- }
- _ => unreachable!(),
- },
- _ => None,
- })
- .partition(|(_, use_normal_macro)| *use_normal_macro);
- let mut result = proc_macro2::TokenStream::new();
- if !normal_macro_args.is_empty() {
- let punctuated_normal_macro_args = to_punctuated::<Token![;]>(normal_macro_args);
- let macro_ = &input.normal_macro;
- result.extend(iter::once(quote! {
- #macro_ ! (#punctuated_normal_macro_args);
- }));
- }
- if !override_macro_args.is_empty() {
- let punctuated_override_macro_args = to_punctuated::<Token![;]>(override_macro_args);
- let macro_ = &input.override_macro;
- result.extend(iter::once(quote! {
- #macro_ ! (#punctuated_override_macro_args);
- }));
- }
- result.into()
-}
-
-fn to_punctuated<P: ToTokens + Default>(
- elms: Vec<(proc_macro2::TokenStream, bool)>,
-) -> proc_macro2::TokenStream {
- let mut collection = Punctuated::<proc_macro2::TokenStream, P>::new();
- collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream));
- collection.into_token_stream()
-}
-
-fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
- match *type_ {
- Type::Path(mut type_path) => {
- type_path.path = prepend_cuda_path_to_path(base_path, type_path.path);
- Box::new(Type::Path(type_path))
- }
- Type::Ptr(mut type_ptr) => {
- type_ptr.elem = prepend_cuda_path_to_type(base_path, type_ptr.elem);
- Box::new(Type::Ptr(type_ptr))
+ let mut choose_macro = ChooseMacro::new(input);
+ let mut cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
+ syn::visit_mut::visit_file_mut(&mut FixFnSignatures, &mut cuda_module);
+ let extern_ = if let Item::ForeignMod(extern_) = cuda_module.items.pop().unwrap() {
+ extern_
+ } else {
+ unreachable!()
+ };
+ let abi = extern_.abi.name;
+ for mut item in extern_.items {
+ if let ForeignItem::Fn(ForeignItemFn {
+ sig: Signature { ref ident, .. },
+ ref mut attrs,
+ ..
+ }) = item
+ {
+ *attrs = Vec::new();
+ choose_macro.add(ident, quote! { #abi #item });
+ } else {
+ unreachable!()
}
- _ => unreachable!(),
- }
-}
-
-fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
- if path.leading_colon.is_some() {
- return path;
}
- if path.segments.len() == 1 {
- let ident = path.segments[0].ident.to_string();
- if ident.starts_with("CU")
- || ident.starts_with("cu")
- || ident.starts_with("GL")
- || ident == "HGPUNV"
- {
- let mut base_path = base_path.clone();
- base_path.segments.extend(path.segments);
- return base_path;
+ let mut result = proc_macro2::TokenStream::new();
+ for (path, items) in
+ iter::once(choose_macro.default).chain(choose_macro.override_sets.into_iter())
+ {
+ if items.is_empty() {
+ continue;
+ }
+ quote! {
+ #path ! { #(#items)* }
}
+ .to_tokens(&mut result);
}
- path
+ result.into()
}
-
struct FnDeclInput {
- type_path: Path,
normal_macro: Path,
- override_macro: Path,
- override_fns: Punctuated<Ident, Token![,]>,
+ overrides: Punctuated<OverrideMacro, Token![,]>,
}
impl Parse for FnDeclInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
- let type_path = input.parse::<Path>()?;
- input.parse::<Token![,]>()?;
let normal_macro = input.parse::<Path>()?;
- input.parse::<Token![,]>()?;
- let override_macro = input.parse::<Path>()?;
- input.parse::<Token![,]>()?;
- let override_fns_content;
- bracketed!(override_fns_content in input);
- let override_fns = override_fns_content.parse_terminated(Ident::parse)?;
+ let overrides = if input.is_empty() {
+ Punctuated::new()
+ } else {
+ input.parse::<Token![,]>()?;
+ input.parse_terminated(OverrideMacro::parse, Token![,])?
+ };
Ok(Self {
- type_path,
normal_macro,
- override_macro,
- override_fns,
+ overrides,
})
}
}
+struct OverrideMacro {
+ macro_: Path,
+ functions: Punctuated<Ident, Token![,]>,
+}
+
+impl Parse for OverrideMacro {
+ fn parse(input: ParseStream) -> syn::Result<Self> {
+ let macro_ = input.parse::<Path>()?;
+ input.parse::<Token![<=]>()?;
+ let functions_content;
+ bracketed!(functions_content in input);
+ let functions = functions_content.parse_terminated(Ident::parse, Token![,])?;
+ Ok(Self { macro_, functions })
+ }
+}
-// This trait accepts following parameters:
-// * `type_path`: path to the module with type definitions (in the module tree)
-// * `trait_`: name of the trait to be derived
-// * `ignore_types`: bracketed list of types to ignore
-// * `ignore_fns`: bracketed list of fns to ignore
-#[proc_macro]
-pub fn cuda_derive_display_trait(tokens: TokenStream) -> TokenStream {
- let input = parse_macro_input!(tokens as DeriveDisplayInput);
- let cuda_module = syn::parse_str::<File>(CUDA_RS).unwrap();
- let mut derive_state = DeriveDisplayState::new(input);
- cuda_module
- .items
- .into_iter()
- .filter_map(|i| cuda_derive_display_trait_for_item(&mut derive_state, i))
- .collect::<proc_macro2::TokenStream>()
- .into()
+struct ChooseMacro {
+ default: (Path, Vec<proc_macro2::TokenStream>),
+ override_lookup: FxHashMap<Ident, Path>,
+ override_sets: FxHashMap<Path, Vec<proc_macro2::TokenStream>>,
}
-fn cuda_derive_display_trait_for_item(
- state: &mut DeriveDisplayState,
- item: Item,
-) -> Option<proc_macro2::TokenStream> {
- let path_prefix = &state.type_path;
- let path_prefix_iter = iter::repeat(&path_prefix);
- let trait_ = &state.trait_;
- let trait_iter = iter::repeat(&state.trait_);
- match item {
- Item::Const(_) => None,
- Item::ForeignMod(ItemForeignMod { mut items, .. }) => match items.pop().unwrap() {
- ForeignItem::Fn(ForeignItemFn {
- sig: Signature { ident, inputs, .. },
- ..
- }) => {
- if state.ignore_fns.contains(&ident) {
- return None;
- }
- let inputs = inputs
- .into_iter()
- .map(|fn_arg| match fn_arg {
- FnArg::Typed(mut pat_type) => {
- pat_type.ty = prepend_cuda_path_to_type(path_prefix, pat_type.ty);
- FnArg::Typed(pat_type)
- }
- _ => unreachable!(),
- })
- .collect::<Vec<_>>();
- let inputs_iter = inputs.iter();
- let mut arg_name_iter = inputs.iter().map(|fn_arg| match fn_arg {
- FnArg::Typed(PatType { pat, .. }) => pat,
- _ => unreachable!(),
- });
- let fn_name = format_ident!("write_{}", ident);
- let original_fn_name = ident.to_string();
- Some(match arg_name_iter.next() {
- Some(first_arg_name) => quote! {
- pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> {
- writer.write_all(concat!("(", stringify!(#first_arg_name), ": ").as_bytes())?;
- let mut arg_idx = 0usize;
- CudaDisplay::write(&#first_arg_name, #original_fn_name, arg_idx, writer)?;
- #(
- writer.write_all(b", ")?;
- writer.write_all(concat!(stringify!(#arg_name_iter), ": ").as_bytes())?;
- CudaDisplay::write(&#arg_name_iter, #original_fn_name, arg_idx, writer)?;
- arg_idx += 1;
- )*
- writer.write_all(b")")
- }
- },
- None => quote! {
- pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- writer.write_all(b"()")
- }
- },
- })
+impl ChooseMacro {
+ fn new(input: FnDeclInput) -> Self {
+ let mut override_lookup = FxHashMap::default();
+ let mut override_sets = FxHashMap::default();
+ for OverrideMacro { macro_, functions } in input.overrides {
+ for ident in functions {
+ override_lookup.insert(ident, macro_.clone());
+ override_sets.insert(macro_.clone(), Vec::new());
}
- _ => unreachable!(),
- },
- Item::Impl(mut item_impl) => {
- let enum_ = match *(item_impl.self_ty) {
- Type::Path(mut path) => path.path.segments.pop().unwrap().into_value().ident,
- _ => unreachable!(),
- };
- let variant_ = match item_impl.items.pop().unwrap() {
- syn::ImplItem::Const(item_const) => item_const.ident,
- _ => unreachable!(),
- };
- state.record_enum_variant(enum_, variant_);
- None
}
- Item::Struct(item_struct) => {
- let item_struct_name = item_struct.ident.to_string();
- if state.ignore_types.contains(&item_struct.ident) {
- return None;
- }
- if item_struct_name.ends_with("_enum") {
- let enum_ = &item_struct.ident;
- let enum_iter = iter::repeat(&item_struct.ident);
- let variants = state.enums.get(&item_struct.ident).unwrap().iter();
- Some(quote! {
- impl #trait_ for #path_prefix :: #enum_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- match self {
- #(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)*
- _ => write!(writer, "{}", self.0)
- }
- }
- }
- })
- } else {
- let struct_ = &item_struct.ident;
- let (first_field, rest_of_fields) = match item_struct.fields {
- Fields::Named(fields) => {
- let mut all_idents = fields.named.into_iter().filter_map(|f| {
- let f_ident = f.ident.unwrap();
- let name = f_ident.to_string();
- if name.starts_with("reserved") || name == "_unused" {
- None
- } else {
- Some(f_ident)
- }
- });
- let first = match all_idents.next() {
- Some(f) => f,
- None => return None,
- };
- (first, all_idents)
- }
- _ => return None,
- };
- Some(quote! {
- impl #trait_ for #path_prefix :: #struct_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
- #trait_::write(&self.#first_field, "", 0, writer)?;
- #(
- writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
- #trait_iter::write(&self.#rest_of_fields, "", 0, writer)?;
- )*
- writer.write_all(b" }")
- }
- }
- })
- }
+ Self {
+ default: (input.normal_macro, Vec::new()),
+ override_lookup,
+ override_sets,
}
- Item::Type(item_type) => {
- if state.ignore_types.contains(&item_type.ident) {
- return None;
- };
- match *(item_type.ty) {
- Type::Ptr(_) => {
- let type_ = item_type.ident;
- Some(quote! {
- impl #trait_ for #path_prefix :: #type_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- write!(writer, "{:p}", *self)
- }
- }
- })
- }
- Type::Path(type_path) => {
- if type_path.path.leading_colon.is_some() {
- let option_seg = type_path.path.segments.last().unwrap();
- if option_seg.ident == "Option" {
- match &option_seg.arguments {
- PathArguments::AngleBracketed(generic) => match generic.args[0] {
- syn::GenericArgument::Type(Type::BareFn(_)) => {
- let type_ = &item_type.ident;
- return Some(quote! {
- impl #trait_ for #path_prefix :: #type_ {
- fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
- write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
- }
- }
- });
- }
- _ => unreachable!(),
- },
- _ => unreachable!(),
- }
- }
- }
- None
- }
- _ => unreachable!(),
+ }
+
+ fn add(&mut self, ident: &Ident, tokens: proc_macro2::TokenStream) {
+ match self.override_lookup.get(ident) {
+ Some(override_macro) => {
+ self.override_sets
+ .get_mut(override_macro)
+ .unwrap()
+ .push(tokens);
}
+ None => self.default.1.push(tokens),
}
- Item::Union(_) => None,
- Item::Use(_) => None,
- _ => unreachable!(),
}
}
-struct DeriveDisplayState {
- type_path: Path,
- trait_: Path,
- ignore_types: FxHashSet<Ident>,
- ignore_fns: FxHashSet<Ident>,
- enums: FxHashMap<Ident, Vec<Ident>>,
-}
+// For some reason prettyplease will append trailing comma *only*
+// if there are two or more arguments
+struct FixFnSignatures;
-impl DeriveDisplayState {
- fn new(input: DeriveDisplayInput) -> Self {
- DeriveDisplayState {
- type_path: input.type_path,
- trait_: input.trait_,
- ignore_types: input.ignore_types.into_iter().collect(),
- ignore_fns: input.ignore_fns.into_iter().collect(),
- enums: Default::default(),
- }
+impl VisitMut for FixFnSignatures {
+ fn visit_signature_mut(&mut self, s: &mut syn::Signature) {
+ s.inputs.pop_punct();
}
+}
- fn record_enum_variant(&mut self, enum_: Ident, variant: Ident) {
- match self.enums.entry(enum_) {
- hash_map::Entry::Occupied(mut entry) => {
- entry.get_mut().push(variant);
- }
- hash_map::Entry::Vacant(entry) => {
- entry.insert(vec![variant]);
- }
- }
+const MODULES: &[&str] = &[
+ "context", "device", "driver", "function", "link", "memory", "module", "pointer",
+];
+
+#[proc_macro]
+pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
+ let mut path = parse_macro_input!(tokens as syn::Path);
+ let fn_ = path
+ .segments
+ .pop()
+ .unwrap()
+ .into_tuple()
+ .0
+ .ident
+ .to_string();
+ let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
+ let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
+ let fn_path = join(segments, !already_has_module);
+ quote! {
+ #path #fn_path
}
+ .into()
}
-struct DeriveDisplayInput {
- type_path: Path,
- trait_: Path,
- ignore_types: Punctuated<Ident, Token![,]>,
- ignore_fns: Punctuated<Ident, Token![,]>,
+fn split(fn_: &str) -> Vec<String> {
+ let mut result = Vec::new();
+ for c in fn_.chars() {
+ if c.is_ascii_uppercase() {
+ result.push(c.to_ascii_lowercase().to_string());
+ } else {
+ result.last_mut().unwrap().push(c);
+ }
+ }
+ result
}
-impl Parse for DeriveDisplayInput {
- fn parse(input: ParseStream) -> syn::Result<Self> {
- let type_path = input.parse::<Path>()?;
- input.parse::<Token![,]>()?;
- let trait_ = input.parse::<Path>()?;
- input.parse::<Token![,]>()?;
- let ignore_types_buffer;
- bracketed!(ignore_types_buffer in input);
- let ignore_types = ignore_types_buffer.parse_terminated(Ident::parse)?;
- input.parse::<Token![,]>()?;
- let ignore_fns_buffer;
- bracketed!(ignore_fns_buffer in input);
- let ignore_fns = ignore_fns_buffer.parse_terminated(Ident::parse)?;
- Ok(Self {
- type_path,
- trait_,
- ignore_types,
- ignore_fns,
+fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
+ fn full_form(segment: &str) -> Option<&[&str]> {
+ Some(match segment {
+ "ctx" => &["context"],
+ "func" => &["function"],
+ "mem" => &["memory"],
+ "memcpy" => &["memory", "copy"],
+ _ => return None,
})
}
+ let mut normalized: Vec<&str> = Vec::new();
+ for segment in fn_.iter() {
+ match full_form(segment) {
+ Some(segments) => normalized.extend(segments.into_iter()),
+ None => normalized.push(&*segment),
+ }
+ }
+ if !find_module {
+ return [Ident::new(&normalized.join("_"), Span::call_site())]
+ .into_iter()
+ .collect();
+ }
+ if !MODULES.contains(&normalized[0]) {
+ let mut globalized = vec!["driver"];
+ globalized.extend(normalized);
+ normalized = globalized;
+ }
+ let (module, path) = normalized.split_first().unwrap();
+ let path = path.join("_");
+ [module, &&*path]
+ .into_iter()
+ .map(|s| Ident::new(s, Span::call_site()))
+ .collect()
}