aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/lib.rs')
-rw-r--r--ptx/src/lib.rs112
1 files changed, 112 insertions, 0 deletions
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 4ade4e8..db9fc23 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -28,6 +28,8 @@ pub mod ast;
mod test;
mod translate;
+use std::fmt;
+
pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
@@ -36,6 +38,86 @@ pub use translate::to_spirv_module;
pub use translate::KernelInfo;
pub use translate::TranslateError;
+pub trait ModuleParserExt {
+ fn parse_checked<'input>(
+ txt: &'input str,
+ ) -> Result<ast::Module<'input>, Vec<ParseError<usize, Token<'input>, ast::PtxError>>>;
+
+ // Returned AST might be malformed. Some users, like logger, want to look at
+ // malformed AST to record information - list of kernels or such
+ fn parse_unchecked<'input>(
+ txt: &'input str,
+ ) -> (
+ ast::Module<'input>,
+ Vec<ParseError<usize, Token<'input>, ast::PtxError>>,
+ );
+}
+
+impl ModuleParserExt for ModuleParser {
+ fn parse_checked<'input>(
+ txt: &'input str,
+ ) -> Result<ast::Module<'input>, Vec<ParseError<usize, Token<'input>, ast::PtxError>>> {
+ let mut errors = Vec::new();
+ let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt);
+ match (&*errors, maybe_ast) {
+ (&[], Ok(ast)) => Ok(ast),
+ (_, Err(unrecoverable)) => {
+ errors.push(unrecoverable);
+ Err(errors)
+ }
+ (_, Ok(_)) => Err(errors),
+ }
+ }
+
+ fn parse_unchecked<'input>(
+ txt: &'input str,
+ ) -> (
+ ast::Module<'input>,
+ Vec<ParseError<usize, Token<'input>, ast::PtxError>>,
+ ) {
+ let mut errors = Vec::new();
+ let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt);
+ let ast = match maybe_ast {
+ Ok(ast) => ast,
+ Err(unrecoverable_err) => {
+ errors.push(unrecoverable_err);
+ ast::Module {
+ version: (0, 0),
+ directives: Vec::new(),
+ }
+ }
+ };
+ (ast, errors)
+ }
+}
+
+pub struct DisplayParseError<'a, Loc, Tok, Err>(pub &'a str, pub &'a ParseError<Loc, Tok, Err>);
+
+impl<'a, Loc, Tok, Err> fmt::Display for DisplayParseError<'a, Loc, Tok, Err>
+where
+ Loc: fmt::Display + Into<usize> + Copy,
+ Tok: fmt::Display,
+ Err: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self.1 {
+ ParseError::UnrecognizedToken {
+ token: (start, token, end),
+ ..
+ } => {
+ let full_instruction =
+ unsafe { self.0.get_unchecked((*start).into()..(*end).into()) };
+ write!(
+ f,
+ "`{}` unrecognized token `{}` found at {}:{}",
+ full_instruction, token, start, end
+ )
+ }
+ _ => self.fmt(f),
+ }
+ }
+}
+
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()
}
@@ -53,3 +135,33 @@ pub(crate) fn vector_index<'input>(
}),
}
}
+
+#[cfg(test)]
+mod tests {
+ use crate::{DisplayParseError, ModuleParser, ModuleParserExt};
+
+ #[test]
+ fn error_report_unknown_instructions() {
+ let module = r#"
+ .version 6.5
+ .target sm_30
+ .address_size 64
+
+ .visible .entry add(
+ .param .u64 input,
+ )
+ {
+ .reg .u64 x;
+ does_not_exist.u64 x, x;
+ ret;
+ }"#;
+ let errors = match ModuleParser::parse_checked(module) {
+ Err(e) => e,
+ Ok(_) => panic!(),
+ };
+ assert_eq!(errors.len(), 1);
+ let reporter = DisplayParseError(module, &errors[0]);
+ let build_log_string = format!("{}", reporter);
+ assert!(build_log_string.contains("does_not_exist"));
+ }
+}