use proc_macro2::Span; use proc_macro2::TokenStream; use quote::quote; use quote::ToTokens; use rustc_hash::FxHashMap; use std::fmt::Write; use syn::bracketed; use syn::parse::Peek; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::LitInt; use syn::Type; use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; pub struct ParseDefinitions { pub token_type: ItemEnum, pub additional_enums: FxHashMap, pub definitions: Vec, } impl Parse for ParseDefinitions { fn parse(input: syn::parse::ParseStream) -> syn::Result { let token_type = input.parse::()?; let mut additional_enums = FxHashMap::default(); while input.peek(Token![#]) { let enum_ = input.parse::()?; additional_enums.insert(enum_.ident.clone(), enum_); } let mut definitions = Vec::new(); while !input.is_empty() { definitions.push(input.parse::()?); } Ok(Self { token_type, additional_enums, definitions, }) } } pub struct OpcodeDefinition(pub Patterns, pub Vec); impl Parse for OpcodeDefinition { fn parse(input: syn::parse::ParseStream) -> syn::Result { let patterns = input.parse::()?; let mut rules = Vec::new(); while Rule::peek(input) { rules.push(input.parse::()?); input.parse::()?; } Ok(Self(patterns, rules)) } } pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); impl Parse for Patterns { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut result = Vec::new(); loop { if !OpcodeDecl::peek(input) { break; } let decl = input.parse::()?; let code_block = input.parse::()?; result.push((decl, code_block)) } Ok(Self(result)) } } pub struct OpcodeDecl(pub Instruction, pub Arguments); impl OpcodeDecl { fn peek(input: syn::parse::ParseStream) -> bool { Instruction::peek(input) && !input.peek2(Token![=]) } } impl Parse for OpcodeDecl { fn parse(input: syn::parse::ParseStream) -> syn::Result { Ok(Self( input.parse::()?, input.parse::()?, )) } } pub struct CodeBlock { pub special: bool, pub code: proc_macro2::Group, } impl Parse for CodeBlock { fn parse(input: syn::parse::ParseStream) -> syn::Result { let lookahead = input.lookahead1(); let (special, code) = if lookahead.peek(Token![<]) { input.parse::()?; input.parse::()?; //input.parse::]>()?; (true, input.parse::()?) } else if lookahead.peek(Token![=]) { input.parse::()?; input.parse::]>()?; (false, input.parse::()?) } else { return Err(lookahead.error()); }; Ok(Self { special, code }) } } pub struct Rule { pub modifier: Option, pub type_: Option, pub alternatives: Vec, } impl Rule { fn peek(input: syn::parse::ParseStream) -> bool { DotModifier::peek(input) || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>])) } fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { let mut result = Vec::new(); Self::parse_with_alternative(input, &mut result)?; loop { if !input.peek(Token![,]) { break; } input.parse::()?; Self::parse_with_alternative(input, &mut result)?; } Ok(result) } fn parse_with_alternative( input: &syn::parse::ParseBuffer, result: &mut Vec, ) -> Result<(), syn::Error> { input.parse::()?; let part1 = input.parse::()?; if input.peek(token::Brace) { result.push(DotModifier { part1: part1.clone(), part2: None, }); let suffix_content; braced!(suffix_content in input); let suffixes = Punctuated::::parse_separated_nonempty( &suffix_content, )?; for part2 in suffixes { result.push(DotModifier { part1: part1.clone(), part2: Some(part2), }); } } else if IdentOrTypeSuffix::peek(input) { let part2 = Some(IdentOrTypeSuffix::parse(input)?); result.push(DotModifier { part1, part2 }); } else { result.push(DotModifier { part1, part2: None }); } Ok(()) } } #[derive(PartialEq, Eq, Hash, Clone)] struct IdentOrTypeSuffix(IdentLike); impl IdentOrTypeSuffix { fn span(&self) -> Span { self.0.span() } fn peek(input: syn::parse::ParseStream) -> bool { input.peek(Token![::]) } } impl ToTokens for IdentOrTypeSuffix { fn to_tokens(&self, tokens: &mut TokenStream) { let ident = &self.0; quote! { :: #ident }.to_tokens(tokens) } } impl Parse for IdentOrTypeSuffix { fn parse(input: syn::parse::ParseStream) -> syn::Result { input.parse::()?; Ok(Self(input.parse::()?)) } } impl Parse for Rule { fn parse(input: syn::parse::ParseStream) -> syn::Result { let (modifier, type_) = if DotModifier::peek(input) { let modifier = Some(input.parse::()?); if input.peek(Token![:]) { input.parse::()?; (modifier, Some(input.parse::()?)) } else { (modifier, None) } } else { (None, Some(input.parse::()?)) }; input.parse::()?; let content; braced!(content in input); let alternatives = Self::parse_alternatives(&content)?; Ok(Self { modifier, type_, alternatives, }) } } pub struct Instruction { pub name: Ident, pub modifiers: Vec, } impl Instruction { fn peek(input: syn::parse::ParseStream) -> bool { input.peek(Ident) } } impl Parse for Instruction { fn parse(input: syn::parse::ParseStream) -> syn::Result { let instruction = input.parse::()?; let mut modifiers = Vec::new(); loop { if !MaybeDotModifier::peek(input) { break; } modifiers.push(MaybeDotModifier::parse(input)?); } Ok(Self { name: instruction, modifiers, }) } } pub struct MaybeDotModifier { pub optional: bool, pub modifier: DotModifier, } impl MaybeDotModifier { fn peek(input: syn::parse::ParseStream) -> bool { input.peek(token::Brace) || DotModifier::peek(input) } } impl Parse for MaybeDotModifier { fn parse(input: syn::parse::ParseStream) -> syn::Result { Ok(if input.peek(token::Brace) { let content; braced!(content in input); let modifier = DotModifier::parse(&content)?; Self { modifier, optional: true, } } else { let modifier = DotModifier::parse(input)?; Self { modifier, optional: false, } }) } } #[derive(PartialEq, Eq, Hash, Clone)] pub struct DotModifier { part1: IdentLike, part2: Option, } impl std::fmt::Display for DotModifier { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, ".")?; self.part1.fmt(f)?; if let Some(ref part2) = self.part2 { write!(f, "::")?; part2.0.fmt(f)?; } Ok(()) } } impl std::fmt::Debug for DotModifier { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self, f) } } impl DotModifier { pub fn span(&self) -> Span { let part1 = self.part1.span(); if let Some(ref part2) = self.part2 { part1.join(part2.span()).unwrap_or(part1) } else { part1 } } pub fn ident(&self) -> Ident { let mut result = String::new(); write!(&mut result, "{}", self.part1).unwrap(); if let Some(ref part2) = self.part2 { write!(&mut result, "_{}", part2.0).unwrap(); } else { match self.part1 { IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'), IdentLike::Ident(_) | IdentLike::Integer(_) => {} } } Ident::new(&result.to_ascii_lowercase(), self.span()) } pub fn variant_capitalized(&self) -> Ident { self.capitalized_impl(String::new()) } pub fn dot_capitalized(&self) -> Ident { self.capitalized_impl("Dot".to_string()) } fn capitalized_impl(&self, prefix: String) -> Ident { let mut temp = String::new(); write!(&mut temp, "{}", &self.part1).unwrap(); if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { write!(&mut temp, "_{}", part2).unwrap(); } let mut result = prefix; let mut capitalize = true; for c in temp.chars() { if c == '_' { capitalize = true; continue; } // Special hack to emit `BF16`` instead of `Bf16`` let c = if capitalize || c == 'f' && result.ends_with('B') { capitalize = false; c.to_ascii_uppercase() } else { c }; result.push(c); } Ident::new(&result, self.span()) } pub fn tokens(&self) -> TokenStream { let part1 = &self.part1; let part2 = &self.part2; match self.part2 { None => quote! { . #part1 }, Some(_) => quote! { . #part1 #part2 }, } } fn peek(input: syn::parse::ParseStream) -> bool { input.peek(Token![.]) } } impl Parse for DotModifier { fn parse(input: syn::parse::ParseStream) -> syn::Result { input.parse::()?; let part1 = input.parse::()?; if IdentOrTypeSuffix::peek(input) { let part2 = Some(IdentOrTypeSuffix::parse(input)?); Ok(Self { part1, part2 }) } else { Ok(Self { part1, part2: None }) } } } #[derive(PartialEq, Eq, Hash, Clone)] enum IdentLike { Type(Token![type]), Const(Token![const]), Ident(Ident), Integer(LitInt), } impl IdentLike { fn span(&self) -> Span { match self { IdentLike::Type(c) => c.span(), IdentLike::Const(t) => t.span(), IdentLike::Ident(i) => i.span(), IdentLike::Integer(l) => l.span(), } } } impl std::fmt::Display for IdentLike { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { IdentLike::Type(_) => f.write_str("type"), IdentLike::Const(_) => f.write_str("const"), IdentLike::Ident(ident) => write!(f, "{}", ident), IdentLike::Integer(integer) => write!(f, "{}", integer), } } } impl ToTokens for IdentLike { fn to_tokens(&self, tokens: &mut TokenStream) { match self { IdentLike::Type(_) => quote! { type }.to_tokens(tokens), IdentLike::Const(_) => quote! { const }.to_tokens(tokens), IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens), } } } impl Parse for IdentLike { fn parse(input: syn::parse::ParseStream) -> syn::Result { let lookahead = input.lookahead1(); Ok(if lookahead.peek(Token![const]) { IdentLike::Const(input.parse::()?) } else if lookahead.peek(Token![type]) { IdentLike::Type(input.parse::()?) } else if lookahead.peek(Ident) { IdentLike::Ident(input.parse::()?) } else if lookahead.peek(LitInt) { IdentLike::Integer(input.parse::()?) } else { return Err(lookahead.error()); }) } } // Arguments decalaration can loook like this: // a{, b} // That's why we don't parse Arguments as Punctuated #[derive(PartialEq, Eq)] pub struct Arguments(pub Vec); impl Parse for Arguments { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut result = Vec::new(); loop { if input.peek(Token![,]) { input.parse::()?; } let mut optional = false; let mut can_be_negated = false; let mut pre_pipe = false; let ident; let lookahead = input.lookahead1(); if lookahead.peek(token::Brace) { let content; braced!(content in input); let lookahead = content.lookahead1(); if lookahead.peek(Token![!]) { content.parse::()?; can_be_negated = true; ident = input.parse::()?; } else if lookahead.peek(Token![,]) { optional = true; content.parse::()?; ident = content.parse::()?; } else { return Err(lookahead.error()); } } else if lookahead.peek(token::Bracket) { let bracketed; bracketed!(bracketed in input); if bracketed.peek(Token![|]) { optional = true; bracketed.parse::()?; pre_pipe = true; ident = bracketed.parse::()?; } else { let mut sub_args = Self::parse(&bracketed)?; sub_args.0.first_mut().unwrap().pre_bracket = true; sub_args.0.last_mut().unwrap().post_bracket = true; if peek_brace_token(input, Token![.]) { let optional_suffix; braced!(optional_suffix in input); optional_suffix.parse::()?; let unified_ident = optional_suffix.parse::()?; if unified_ident.to_string() != "unified" { return Err(syn::Error::new( unified_ident.span(), format!("Exptected `unified`, got `{}`", unified_ident), )); } for a in sub_args.0.iter_mut() { a.unified = true; } } result.extend(sub_args.0); continue; } } else if lookahead.peek(Ident) { ident = input.parse::()?; } else if lookahead.peek(Token![|]) { input.parse::()?; pre_pipe = true; ident = input.parse::()?; } else { break; } result.push(Argument { optional, pre_pipe, can_be_negated, pre_bracket: false, ident, post_bracket: false, unified: false, }); } Ok(Self(result)) } } // This is effectively input.peek(token::Brace) && input.peek2(Token![.]) // input.peek2 is supposed to skip over next token, but it skips over whole // braced token group. Not sure if it's a bug fn peek_brace_token(input: syn::parse::ParseStream, _t: T) -> bool { use syn::token::Token; let cursor = input.cursor(); cursor .group(proc_macro2::Delimiter::Brace) .map_or(false, |(content, ..)| T::Token::peek(content)) } #[derive(PartialEq, Eq)] pub struct Argument { pub optional: bool, pub pre_bracket: bool, pub pre_pipe: bool, pub can_be_negated: bool, pub ident: Ident, pub post_bracket: bool, pub unified: bool, } #[cfg(test)] mod tests { use super::{Arguments, DotModifier, MaybeDotModifier}; use quote::{quote, ToTokens}; #[test] fn parse_modifier_complex() { let input = quote! { .level::eviction_priority }; let modifier = syn::parse2::(input).unwrap(); assert_eq!( ". level :: eviction_priority", modifier.tokens().to_string() ); } #[test] fn parse_modifier_optional() { let input = quote! { { .level::eviction_priority } }; let maybe_modifider = syn::parse2::(input).unwrap(); assert_eq!( ". level :: eviction_priority", maybe_modifider.modifier.tokens().to_string() ); assert!(maybe_modifider.optional); } #[test] fn parse_type_token() { let input = quote! { . type }; let maybe_modifier = syn::parse2::(input).unwrap(); assert_eq!(". type", maybe_modifier.modifier.tokens().to_string()); assert!(!maybe_modifier.optional); } #[test] fn arguments_memory() { let input = quote! { [a], b }; let arguments = syn::parse2::(input).unwrap(); let a = &arguments.0[0]; assert!(!a.optional); assert_eq!("a", a.ident.to_string()); assert!(a.pre_bracket); assert!(!a.pre_pipe); assert!(a.post_bracket); assert!(!a.can_be_negated); let b = &arguments.0[1]; assert!(!b.optional); assert_eq!("b", b.ident.to_string()); assert!(!b.pre_bracket); assert!(!b.pre_pipe); assert!(!b.post_bracket); assert!(!b.can_be_negated); } #[test] fn arguments_optional() { let input = quote! { b{, cache_policy} }; let arguments = syn::parse2::(input).unwrap(); let b = &arguments.0[0]; assert!(!b.optional); assert_eq!("b", b.ident.to_string()); assert!(!b.pre_bracket); assert!(!b.pre_pipe); assert!(!b.post_bracket); assert!(!b.can_be_negated); let cache_policy = &arguments.0[1]; assert!(cache_policy.optional); assert_eq!("cache_policy", cache_policy.ident.to_string()); assert!(!cache_policy.pre_bracket); assert!(!cache_policy.pre_pipe); assert!(!cache_policy.post_bracket); assert!(!cache_policy.can_be_negated); } #[test] fn arguments_optional_pred() { let input = quote! { p[|q], a }; let arguments = syn::parse2::(input).unwrap(); assert_eq!(arguments.0.len(), 3); let p = &arguments.0[0]; assert!(!p.optional); assert_eq!("p", p.ident.to_string()); assert!(!p.pre_bracket); assert!(!p.pre_pipe); assert!(!p.post_bracket); assert!(!p.can_be_negated); let q = &arguments.0[1]; assert!(q.optional); assert_eq!("q", q.ident.to_string()); assert!(!q.pre_bracket); assert!(q.pre_pipe); assert!(!q.post_bracket); assert!(!q.can_be_negated); let a = &arguments.0[2]; assert!(!a.optional); assert_eq!("a", a.ident.to_string()); assert!(!a.pre_bracket); assert!(!a.pre_pipe); assert!(!a.post_bracket); assert!(!a.can_be_negated); } #[test] fn arguments_optional_with_negate() { let input = quote! { b, {!}c }; let arguments = syn::parse2::(input).unwrap(); assert_eq!(arguments.0.len(), 2); let b = &arguments.0[0]; assert!(!b.optional); assert_eq!("b", b.ident.to_string()); assert!(!b.pre_bracket); assert!(!b.pre_pipe); assert!(!b.post_bracket); assert!(!b.can_be_negated); let c = &arguments.0[1]; assert!(!c.optional); assert_eq!("c", c.ident.to_string()); assert!(!c.pre_bracket); assert!(!c.pre_pipe); assert!(!c.post_bracket); assert!(c.can_be_negated); } #[test] fn arguments_tex() { let input = quote! { d[|p], [a{, b}, c], dpdx, dpdy {, e} }; let arguments = syn::parse2::(input).unwrap(); assert_eq!(arguments.0.len(), 8); { let d = &arguments.0[0]; assert!(!d.optional); assert_eq!("d", d.ident.to_string()); assert!(!d.pre_bracket); assert!(!d.pre_pipe); assert!(!d.post_bracket); assert!(!d.can_be_negated); } { let p = &arguments.0[1]; assert!(p.optional); assert_eq!("p", p.ident.to_string()); assert!(!p.pre_bracket); assert!(p.pre_pipe); assert!(!p.post_bracket); assert!(!p.can_be_negated); } { let a = &arguments.0[2]; assert!(!a.optional); assert_eq!("a", a.ident.to_string()); assert!(a.pre_bracket); assert!(!a.pre_pipe); assert!(!a.post_bracket); assert!(!a.can_be_negated); } { let b = &arguments.0[3]; assert!(b.optional); assert_eq!("b", b.ident.to_string()); assert!(!b.pre_bracket); assert!(!b.pre_pipe); assert!(!b.post_bracket); assert!(!b.can_be_negated); } { let c = &arguments.0[4]; assert!(!c.optional); assert_eq!("c", c.ident.to_string()); assert!(!c.pre_bracket); assert!(!c.pre_pipe); assert!(c.post_bracket); assert!(!c.can_be_negated); } { let dpdx = &arguments.0[5]; assert!(!dpdx.optional); assert_eq!("dpdx", dpdx.ident.to_string()); assert!(!dpdx.pre_bracket); assert!(!dpdx.pre_pipe); assert!(!dpdx.post_bracket); assert!(!dpdx.can_be_negated); } { let dpdy = &arguments.0[6]; assert!(!dpdy.optional); assert_eq!("dpdy", dpdy.ident.to_string()); assert!(!dpdy.pre_bracket); assert!(!dpdy.pre_pipe); assert!(!dpdy.post_bracket); assert!(!dpdy.can_be_negated); } { let e = &arguments.0[7]; assert!(e.optional); assert_eq!("e", e.ident.to_string()); assert!(!e.pre_bracket); assert!(!e.pre_pipe); assert!(!e.post_bracket); assert!(!e.can_be_negated); } } #[test] fn rule_multi() { let input = quote! { .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } }; let rule = syn::parse2::(input).unwrap(); assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string()); assert_eq!( "StateSpace", rule.type_.unwrap().to_token_stream().to_string() ); let alts = rule .alternatives .iter() .map(|m| m.tokens().to_string()) .collect::>(); assert_eq!( vec![ ". global", ". local", ". param", ". param :: func", ". shared", ". shared :: cta", ". shared :: cluster" ], alts ); } #[test] fn rule_multi2() { let input = quote! { .cop: StCacheOperator = { .wb, .cg, .cs, .wt } }; let rule = syn::parse2::(input).unwrap(); assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string()); assert_eq!( "StCacheOperator", rule.type_.unwrap().to_token_stream().to_string() ); let alts = rule .alternatives .iter() .map(|m| m.tokens().to_string()) .collect::>(); assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); } #[test] fn args_unified() { let input = quote! { d, [a]{.unified}{, cache_policy} }; let args = syn::parse2::(input).unwrap(); let a = &args.0[1]; assert!(!a.optional); assert_eq!("a", a.ident.to_string()); assert!(a.pre_bracket); assert!(!a.pre_pipe); assert!(a.post_bracket); assert!(!a.can_be_negated); assert!(a.unified); } #[test] fn special_block() { let input = quote! { bra <= { bra(stream) } }; syn::parse2::(input).unwrap(); } }