diff options
author | Andrzej Janik <[email protected]> | 2024-11-02 15:57:57 +0100 |
---|---|---|
committer | GitHub <[email protected]> | 2024-11-02 15:57:57 +0100 |
commit | b4cb3ade63af94ccb709f2c0858253b13125fcc6 (patch) | |
tree | e0532fa9bd888b7dc526264cff870e44379b02e7 /ptx_parser/src/lib.rs | |
parent | 3870a96592c6a93d3a68391f6cbaecd9c7a2bc97 (diff) | |
download | ZLUDA-b4cb3ade63af94ccb709f2c0858253b13125fcc6.tar.gz ZLUDA-b4cb3ade63af94ccb709f2c0858253b13125fcc6.zip |
Recover from and report unknown instructions and directives (#295)
Diffstat (limited to 'ptx_parser/src/lib.rs')
-rw-r--r-- | ptx_parser/src/lib.rs | 362 |
1 files changed, 291 insertions, 71 deletions
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index b49503b..1ea2d71 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -9,7 +9,7 @@ use winnow::ascii::dec_uint; use winnow::combinator::*; use winnow::error::{ErrMode, ErrorKind}; use winnow::stream::Accumulate; -use winnow::token::any; +use winnow::token::{any, take_till}; use winnow::{ error::{ContextError, ParserError}, stream::{Offset, Stream, StreamIsPartial}, @@ -86,14 +86,16 @@ impl VectorPrefix { } struct PtxParserState<'a, 'input> { - errors: &'a mut Vec<PtxError>, + text: &'input str, + errors: &'a mut Vec<PtxError<'input>>, function_declarations: FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, } impl<'a, 'input> PtxParserState<'a, 'input> { - fn new(errors: &'a mut Vec<PtxError>) -> Self { + fn new(text: &'input str, errors: &'a mut Vec<PtxError<'input>>) -> Self { Self { + text, errors, function_declarations: FxHashMap::default(), } @@ -127,10 +129,11 @@ impl<'a, 'input> Debug for PtxParserState<'a, 'input> { } } -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; +type PtxParser<'a, 'input> = + Stateful<&'a [(Token<'input>, logos::Span)], PtxParserState<'a, 'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { if let Token::Ident(text) = t { Some(text) } else if let Some(text) = t.opcode_text() { @@ -143,7 +146,7 @@ fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> } fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { if let Token::DotIdent(text) = t { Some(text) } else { @@ -154,7 +157,7 @@ fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input } fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { Some(match t { Token::Hex(s) => { if s.ends_with('U') { @@ -178,7 +181,7 @@ fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, } fn take_error<'a, 'input: 'a, O, E>( - mut parser: impl Parser<PtxParser<'a, 'input>, Result<O, (O, PtxError)>, E>, + mut parser: impl Parser<PtxParser<'a, 'input>, Result<O, (O, PtxError<'input>)>, E>, ) -> impl Parser<PtxParser<'a, 'input>, O, E> { move |input: &mut PtxParser<'a, 'input>| { Ok(match parser.parse_next(input)? { @@ -218,7 +221,7 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast:: } fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> { - take_error(any.verify_map(|t| match t { + take_error(any.verify_map(|(t, _)| match t { Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f32::from_bits(x)), Err(err) => Err((0.0, PtxError::from(err))), @@ -229,7 +232,7 @@ fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> { } fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> { - take_error(any.verify_map(|t| match t { + take_error(any.verify_map(|(t, _)| match t { Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { Ok(x) => Ok(f64::from_bits(x)), Err(err) => Err((0.0, PtxError::from(err))), @@ -282,10 +285,9 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<as } pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> { - let lexer = Token::lexer(text); - let input = lexer.collect::<Result<Vec<_>, _>>().ok()?; + let input = lex_with_span(text).ok()?; let mut errors = Vec::new(); - let state = PtxParserState::new(&mut errors); + let state = PtxParserState::new(text, &mut errors); let parser = PtxParser { state, input: &input[..], @@ -310,7 +312,7 @@ pub fn parse_module_checked<'input>( None => break, }; match maybe_token { - Ok(token) => tokens.push(token), + Ok(token) => tokens.push((token, lexer.span())), Err(mut err) => { err.0 = lexer.span(); errors.push(PtxError::from(err)) @@ -321,7 +323,7 @@ pub fn parse_module_checked<'input>( return Err(errors); } let parse_result = { - let state = PtxParserState::new(&mut errors); + let state = PtxParserState::new(text, &mut errors); let parser = PtxParser { state, input: &tokens[..], @@ -340,6 +342,17 @@ pub fn parse_module_checked<'input>( } } +fn lex_with_span<'input>( + text: &'input str, +) -> Result<Vec<(Token<'input>, logos::Span)>, TokenError> { + let lexer = Token::lexer(text); + let mut result = Vec::new(); + for (token, span) in lexer.spanned() { + result.push((token?, span)); + } + Ok(result) +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> { ( version, @@ -385,13 +398,29 @@ fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> { fn directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> { - alt(( - function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), - file.map(|_| None), - section.map(|_| None), - (module_variable, Token::Semicolon) - .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), - )) + with_recovery( + alt(( + // When adding a new variant here remember to add its first token into recovery parser down below + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )), + take_till(1.., |(token, _)| match token { + // visibility + Token::DotExtern | Token::DotVisible | Token::DotWeak + // methods + | Token::DotFunc | Token::DotEntry + // module variables + | Token::DotGlobal | Token::DotConst | Token::DotShared + // other sections + | Token::DotFile | Token::DotSection => true, + _ => false, + }), + PtxError::UnrecognizedDirective, + ) + .map(Option::flatten) .parse_next(stream) } @@ -487,9 +516,9 @@ fn linking_directives<'a, 'input>( repeat( 0.., dispatch! { any; - Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), - Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), - Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + (Token::DotExtern, _) => empty.value(ast::LinkingDirective::EXTERN), + (Token::DotVisible, _) => empty.value(ast::LinkingDirective::VISIBLE), + (Token::DotWeak, _) => empty.value(ast::LinkingDirective::WEAK), _ => fail }, ) @@ -501,10 +530,10 @@ fn tuning_directive<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<ast::TuningDirective> { dispatch! {any; - Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), - Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), - Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), - Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + (Token::DotMaxnreg, _) => u32.map(ast::TuningDirective::MaxNReg), + (Token::DotMaxntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + (Token::DotReqntid, _) => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + (Token::DotMinnctapersm, _) => u32.map(ast::TuningDirective::MinNCtaPerSm), _ => fail } .parse_next(stream) @@ -514,10 +543,10 @@ fn method_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<ast::MethodDeclaration<'input, &'input str>> { dispatch! {any; - Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + (Token::DotEntry, _) => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None }), - Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + (Token::DotFunc, _) => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); let name = ast::MethodName::Func(name); ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } @@ -557,8 +586,8 @@ fn kernel_input<'a, 'input>( fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> { dispatch! { any; - Token::DotParam => method_parameter(StateSpace::Param), - Token::DotReg => method_parameter(StateSpace::Reg), + (Token::DotParam, _) => method_parameter(StateSpace::Param), + (Token::DotReg, _) => method_parameter(StateSpace::Reg), _ => fail } .parse_next(stream) @@ -606,8 +635,8 @@ fn function_body<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> { dispatch! {any; - Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), - Token::Semicolon => empty.map(|_| None), + (Token::LBrace, _) => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + (Token::Semicolon, _) => empty.map(|_| None), _ => fail } .parse_next(stream) @@ -616,22 +645,122 @@ fn function_body<'a, 'input>( fn statement<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> { - alt(( - label.map(Some), - debug_directive.map(|_| None), - terminated( - method_space - .flat_map(|space| multi_variable(false, space)) - .map(|var| Some(Statement::Variable(var))), - Token::Semicolon, + with_recovery( + alt(( + label.map(Some), + debug_directive.map(|_| None), + terminated( + method_space + .flat_map(|space| multi_variable(false, space)) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )), + take_till_inclusive( + |(t, _)| *t == Token::RBrace, + |(t, _)| match t { + Token::Semicolon | Token::Colon => true, + _ => false, + }, ), - predicated_instruction.map(Some), - pragma.map(|_| None), - block_statement.map(Some), - )) + PtxError::UnrecognizedStatement, + ) + .map(Option::flatten) .parse_next(stream) } +fn take_till_inclusive<I: Stream, E: ParserError<I>>( + backtrack_token: impl Fn(&I::Token) -> bool, + end_token: impl Fn(&I::Token) -> bool, +) -> impl Parser<I, <I as Stream>::Slice, E> { + fn get_offset<I: Stream>( + input: &mut I, + backtrack_token: &impl Fn(&I::Token) -> bool, + end_token: &impl Fn(&I::Token) -> bool, + should_backtrack: &mut bool, + ) -> usize { + *should_backtrack = false; + let mut hit = false; + for (offset, token) in input.iter_offsets() { + if hit { + return offset; + } else { + if backtrack_token(&token) { + *should_backtrack = true; + return offset; + } + if end_token(&token) { + hit = true; + } + } + } + input.eof_offset() + } + move |stream: &mut I| { + let mut should_backtrack = false; + let offset = get_offset(stream, &backtrack_token, &end_token, &mut should_backtrack); + let result = stream.next_slice(offset); + if should_backtrack { + Err(ErrMode::from_error_kind( + stream, + winnow::error::ErrorKind::Token, + )) + } else { + Ok(result) + } + } +} + +/* +pub fn take_till_or_backtrack_eof<Set, Input, Error>( + set: Set, +) -> impl Parser<Input, <Input as Stream>::Slice, Error> +where + Input: StreamIsPartial + Stream, + Set: winnow::stream::ContainsToken<<Input as Stream>::Token>, + Error: ParserError<Input>, +{ + move |stream: &mut Input| { + if stream.eof_offset() == 0 { + return ; + } + take_till(0.., set) + } +} + */ + +fn with_recovery<'a, 'input: 'a, T>( + mut parser: impl Parser<PtxParser<'a, 'input>, T, ContextError>, + mut recovery: impl Parser<PtxParser<'a, 'input>, &'a [(Token<'input>, logos::Span)], ContextError>, + mut error: impl FnMut(Option<&'input str>) -> PtxError<'input>, +) -> impl Parser<PtxParser<'a, 'input>, Option<T>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let input_start = stream.input.first().map(|(_, s)| s).cloned(); + let stream_start = stream.checkpoint(); + match parser.parse_next(stream) { + Ok(value) => Ok(Some(value)), + Err(ErrMode::Backtrack(_)) => { + stream.reset(&stream_start); + let tokens = recovery.parse_next(stream)?; + let range = match input_start { + Some(start) => { + Some(&stream.state.text[start.start..tokens.last().unwrap().1.end]) + } + // We could handle `(Some(start), None)``, but this whole error recovery is to + // recover from unknown instructions, so we don't care about early end of stream + _ => None, + }; + stream.state.errors.push(error(range)); + Ok(None) + } + Err(err) => Err(err), + } + } +} + fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { (Token::DotPragma, Token::String, Token::Semicolon) .void() @@ -746,7 +875,7 @@ fn array_initializer<'a, 'input: 'a>( } delimited( Token::LBrace, - separated( + separated::<_, (), (), _, _, _, _>( 0..=array_dimensions[0] as usize, single_value_append(&mut result, type_), Token::Comma, @@ -935,7 +1064,7 @@ fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Opti } fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> { - any.verify_map(|t| { + any.verify_map(|(t, _)| { Some(match t { Token::DotS8 => ScalarType::S8, Token::DotS16 => ScalarType::S16, @@ -1001,8 +1130,8 @@ fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<() ident_literal("function_name"), ident, dispatch! { any; - Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), - Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + (Token::Comma, _) => (ident_literal("inlined_at"), u32, u32, u32).void(), + (Token::Plus, _) => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), _ => fail }, )), @@ -1033,13 +1162,14 @@ fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>( fn ident_literal< 'a, 'input, - I: Stream<Token = Token<'input>> + StreamIsPartial, + X, + I: Stream<Token = (Token<'input>, X)> + StreamIsPartial, E: ParserError<I>, >( s: &'input str, ) -> impl Parser<I, (), E> + 'input { move |stream: &mut I| { - any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + any.verify(|(t, _)| matches!(t, Token::Ident(text) if *text == s)) .void() .parse_next(stream) } @@ -1086,8 +1216,8 @@ impl<Ident> ast::ParsedOperand<Ident> { let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; // TODO: parse .v8 literals dispatch! {any; - Token::RBrace => empty.map(|_| vec![r1, r2]), - Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + (Token::RBrace, _) => empty.map(|_| vec![r1, r2]), + (Token::Comma, _) => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), _ => fail } .parse_next(stream) @@ -1102,7 +1232,7 @@ impl<Ident> ast::ParsedOperand<Ident> { } #[derive(Debug, thiserror::Error)] -pub enum PtxError { +pub enum PtxError<'input> { #[error("{source}")] ParseInt { #[from] @@ -1146,10 +1276,10 @@ pub enum PtxError { ArrayInitalizer, #[error("")] NonExternPointer, - #[error("{start}:{end}")] - UnrecognizedStatement { start: usize, end: usize }, - #[error("{start}:{end}")] - UnrecognizedDirective { start: usize, end: usize }, + #[error("{0:?}")] + UnrecognizedStatement(Option<&'input str>), + #[error("{0:?}")] + UnrecognizedDirective(Option<&'input str>), } #[derive(Debug)] @@ -1244,11 +1374,11 @@ impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { } } -impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E> - for Token<'input> +impl<'input, X, I: Stream<Token = (Self, X)> + StreamIsPartial, E: ParserError<I>> + Parser<I, (Self, X), E> for Token<'input> { - fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> { - any.verify(|t| t == self).parse_next(input) + fn parse_next(&mut self, input: &mut I) -> PResult<(Self, X), E> { + any.verify(|(t, _)| t == self).parse_next(input) } } @@ -1257,7 +1387,7 @@ fn bra<'a, 'input>( ) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> { preceded( opt(Token::DotUni), - any.verify_map(|t| match t { + any.verify_map(|(t, _)| match t { Token::Ident(ident) => Some(ast::Instruction::Bra { arguments: BraArgs { src: ident }, }), @@ -3224,21 +3354,27 @@ derive_parser!( #[cfg(test)] mod tests { + use crate::parse_module_checked; + use crate::PtxError; + use super::target; use super::PtxParserState; use super::Token; use logos::Logos; + use logos::Span; use winnow::prelude::*; #[test] fn sm_11() { - let tokens = Token::lexer(".target sm_11") + let text = ".target sm_11"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::<Result<Vec<_>, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); assert_eq!(errors.len(), 0); @@ -3246,13 +3382,15 @@ mod tests { #[test] fn sm_90a() { - let tokens = Token::lexer(".target sm_90a") + let text = ".target sm_90a"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::<Result<Vec<_>, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); assert_eq!(errors.len(), 0); @@ -3260,15 +3398,97 @@ mod tests { #[test] fn sm_90ab() { - let tokens = Token::lexer(".target sm_90ab") + let text = ".target sm_90ab"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) .collect::<Result<Vec<_>, _>>() .unwrap(); let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(&mut errors), + state: PtxParserState::new(text, &mut errors), }; assert!(target.parse(stream).is_err()); assert_eq!(errors.len(), 0); } + + #[test] + fn report_unknown_intruction() { + let text = " + .version 6.5 + .target sm_30 + .address_size 64 + + .visible .entry add( + .param .u64 input, + .param .u64 output + ) + { + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + unknown_op1.asdf foobar; + add.u64 temp2, temp, 1; + unknown_op2 temp2, temp; + st.u64 [out_addr], temp2; + ret; + }"; + let errors = parse_module_checked(text).err().unwrap(); + assert_eq!(errors.len(), 2); + assert!(matches!( + errors[0], + PtxError::UnrecognizedStatement(Some("unknown_op1.asdf foobar;")) + )); + assert!(matches!( + errors[1], + PtxError::UnrecognizedStatement(Some("unknown_op2 temp2, temp;")) + )); + } + + #[test] + fn report_unknown_directive() { + let text = " + .version 6.5 + .target sm_30 + .address_size 64 + + .broken_directive_fail; 34; { + + .visible .entry add( + .param .u64 input, + .param .u64 output + ) + { + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; + } + + section foobar }"; + let errors = parse_module_checked(text).err().unwrap(); + assert_eq!(errors.len(), 2); + assert!(matches!( + errors[0], + PtxError::UnrecognizedDirective(Some(".broken_directive_fail; 34; {")) + )); + assert!(matches!( + errors[1], + PtxError::UnrecognizedDirective(Some("section foobar }")) + )); + } } |