diff options
author | Andrzej Janik <[email protected]> | 2020-04-12 20:50:34 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-04-12 20:50:34 +0200 |
commit | bbe993392b803f01effc0f86da861d348741d1eb (patch) | |
tree | df9019efcf4645ab0a38cab7e6d5eb59d97a4e6c /ptx | |
parent | b8129aab20c4768ffc3a304ea957c3ec278471dc (diff) | |
download | ZLUDA-bbe993392b803f01effc0f86da861d348741d1eb.tar.gz ZLUDA-bbe993392b803f01effc0f86da861d348741d1eb.zip |
Add better error handling during ast construction
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/Cargo.toml | 2 | ||||
-rw-r--r-- | ptx/doc/NOTES.md | 18 | ||||
-rw-r--r-- | ptx/src/ast.rs | 158 | ||||
-rw-r--r-- | ptx/src/lib.rs | 12 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 69 | ||||
-rw-r--r-- | ptx/src/spirv.rs | 9 | ||||
-rw-r--r-- | ptx/src/test/mod.rs | 14 | ||||
-rw-r--r-- | ptx/src/translate.rs | 107 |
8 files changed, 330 insertions, 59 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 64c3547..7c8a701 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -10,6 +10,8 @@ edition = "2018" lalrpop-util = "0.18.1" regex = "1" rspirv = "0.6" +spirv_headers = "1.4" +quick-error = "1.2" [build-dependencies.lalrpop] version = "0.18.1" diff --git a/ptx/doc/NOTES.md b/ptx/doc/NOTES.md deleted file mode 100644 index b4d2ad3..0000000 --- a/ptx/doc/NOTES.md +++ /dev/null @@ -1,18 +0,0 @@ -I'm convinced nobody actually uses parser generators in Rust: -* pomelo can't generate lexer (understandable, as it is a port of lemon and lemon can't do this either) -* pest can't do parse actions, you have to convert your parse tree to ast manually -* lalrpop can't do comments - * and the day I wrote the line above it can - * reports parsing errors as byte offsets - * if you want to skip parsing one of the alternatives functional design gets quite awkward -* antlr4rust is untried and requires java to build -* no library supports island grammars - -What to emit? -* SPIR-V - * Better library support, easier to emit - * Can by optimized by IGC - * Can't do some things (not sure what exactly yet) - * But we can work around things with inline VISA -* VISA - * Quicker compilation
\ No newline at end of file diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 645e3a1..3bb142d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,28 +1,168 @@ +use std::convert::From; +use std::error::Error; +use std::mem; +use std::num::ParseIntError; + +quick_error! { + #[derive(Debug)] + pub enum PtxError { + Parse (err: ParseIntError) { + display("{}", err) + cause(err) + from() + } + } +} + +pub struct WithErrors<T, E> { + pub value: T, + pub errors: Vec<E>, +} + +impl<T, E> WithErrors<T, E> { + pub fn new(t: T) -> Self { + WithErrors { + value: t, + errors: Vec::new(), + } + } + + pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> WithErrors<U, E> { + WithErrors { + value: f(self.value), + errors: self.errors, + } + } + + pub fn map2<X, Y, F: FnOnce(X, Y) -> T>( + x: WithErrors<X, E>, + y: WithErrors<Y, E>, + f: F, + ) -> Self { + let mut errors = x.errors; + let mut errors_other = y.errors; + if errors.len() < errors_other.len() { + mem::swap(&mut errors, &mut errors_other); + } + errors.extend(errors_other); + WithErrors { + value: f(x.value, y.value), + errors: errors, + } + } +} + +impl<T:Default, E: Error> WithErrors<T, E> { + pub fn from_results<X: Default, Y: Default, F: FnOnce(X, Y) -> T>( + x: Result<X, E>, + y: Result<Y, E>, + f: F, + ) -> Self { + match (x, y) { + (Ok(x), Ok(y)) => WithErrors { + value: f(x, y), + errors: Vec::new(), + }, + (Err(e), Ok(y)) => WithErrors { + value: f(X::default(), y), + errors: vec![e], + }, + (Ok(x), Err(e)) => WithErrors { + value: f(x, Y::default()), + errors: vec![e], + }, + (Err(e1), Err(e2)) => WithErrors { + value: T::default(), + errors: vec![e1, e2], + }, + } + } +} + +impl<T, E: Error> WithErrors<Vec<T>, E> { + pub fn from_vec(v: Vec<WithErrors<T, E>>) -> Self { + let mut values = Vec::with_capacity(v.len()); + let mut errors = Vec::new(); + for we in v.into_iter() { + values.push(we.value); + errors.extend(we.errors); + } + WithErrors { + value: values, + errors: errors, + } + } +} + +pub trait WithErrorsExt<From, To, E> { + fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E>; +} + +impl<From, To: Default, E> WithErrorsExt<From, To, E> for Result<From, E> { + fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E> { + self.map_or_else( + |e| WithErrors { + value: To::default(), + errors: vec![e], + }, + |t| WithErrors { + value: f(t), + errors: Vec::new(), + }, + ) + } +} + pub struct Module<'a> { pub version: (u8, u8), - pub functions: Vec<Function<'a>> + pub functions: Vec<Function<'a>>, } pub struct Function<'a> { pub kernel: bool, pub name: &'a str, - pub args: Vec<Argument>, + pub args: Vec<Argument<'a>>, pub body: Vec<Statement<'a>>, } -pub struct Argument { +#[derive(Default)] +pub struct Argument<'a> { + pub name: &'a str, + pub a_type: ScalarType, + pub length: u32, +} + +pub enum ScalarType { + B8, + B16, + B32, + B64, + U8, + U16, + U32, + U64, + S8, + S16, + S32, + S64, + F16, + F32, + F64, +} +impl Default for ScalarType { + fn default() -> Self { + ScalarType::B8 + } } pub enum Statement<'a> { Label(&'a str), Variable(Variable), - Instruction(Instruction) + Instruction(Instruction), } -pub struct Variable { - -} +pub struct Variable {} pub enum Instruction { Ld, @@ -35,5 +175,5 @@ pub enum Instruction { Cvt, Shl, At, - Ret -}
\ No newline at end of file + Ret, +} diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index a29270f..716a25c 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -1,11 +1,19 @@ #[macro_use] +extern crate quick_error; +#[macro_use] extern crate lalrpop_util; +extern crate rspirv; +extern crate spirv_headers as spirv; lalrpop_mod!(ptx); mod test; -mod spirv; +mod translate; pub mod ast; pub use ast::Module as Module; -pub use spirv::translate as to_spirv; +pub use translate::to_spirv as to_spirv; + +pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> { + x.into_iter().filter_map(|x| x).collect() +}
\ No newline at end of file diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 96288b9..b646d68 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1,6 +1,6 @@ -use std::str::FromStr; use crate::ast; -use std::convert::identity; +use crate::ast::{WithErrors, WithErrorsExt}; +use crate::without_none; grammar; @@ -16,19 +16,23 @@ match { _ } -pub Module: Option<ast::Module<'input>> = { - <v:Version> Target <f:Directive*> => v.map(|v| ast::Module { version: v, functions: f.into_iter().filter_map(identity).collect::<Vec<_>>() }) +pub Module: WithErrors<ast::Module<'input>, ast::PtxError> = { + <v:Version> Target <f:Directive*> => { + let funcs = WithErrors::from_vec(without_none(f)); + WithErrors::map2(v, funcs, + |v, funcs| ast::Module { version: v, functions: funcs } + ) + } }; -Version: Option<(u8, u8)> = { +Version: WithErrors<(u8, u8), ast::PtxError> = { ".version" <v:VersionNumber> => { let dot = v.find('.').unwrap(); - let major = v[..dot].parse::<u8>(); - major.ok().and_then(|major| { - v[dot+1..].parse::<u8>().ok().map(|minor| { - (major, minor) - }) - }) + let major = v[..dot].parse::<u8>().map_err(Into::into); + let minor = v[dot+1..].parse::<u8>().map_err(Into::into); + WithErrors::from_results(major, minor, + |major, minor| (major, minor) + ) } } @@ -45,7 +49,7 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option<ast::Function<'input>> = { +Directive: Option<WithErrors<ast::Function<'input>, ast::PtxError>> = { AddressSize => None, <f:Function> => Some(f), File => None, @@ -56,8 +60,11 @@ AddressSize = { ".address_size" Num }; -Function: ast::Function<'input> = { - LinkingDirective* <kernel:IsKernel> <name:ID> "(" <args:Comma<FunctionInput>> ")" <body:FunctionBody> => ast::Function {<>} +Function: WithErrors<ast::Function<'input>, ast::PtxError> = { + LinkingDirective* <k:IsKernel> <n:ID> "(" <args:Comma<FunctionInput>> ")" <b:FunctionBody> => { + WithErrors::from_vec(args) + .map(|args| ast::Function{kernel: k, name: n, args: args, body: b}) + } }; LinkingDirective = { @@ -71,12 +78,21 @@ IsKernel: bool = { ".func" => false }; -FunctionInput: ast::Argument = { - ".param" Type ID => ast::Argument {} +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space +FunctionInput: WithErrors<ast::Argument<'input>, ast::PtxError> = { + ".param" <_type:ScalarType> <name:ID> => { + WithErrors::new(ast::Argument {a_type: _type, name: name, length: 1 }) + }, + ".param" <a_type:ScalarType> <name:ID> "[" <length:Num> "]" => { + let length = length.parse::<u32>().map_err(Into::into); + length.with_errors( + |l| ast::Argument { a_type: a_type, name: name, length: l } + ) + } }; FunctionBody: Vec<ast::Statement<'input>> = { - "{" <s:Statement*> "}" => { s.into_iter().filter_map(identity).collect() } + "{" <s:Statement*> "}" => { without_none(s) } }; StateSpaceSpecifier = { @@ -88,6 +104,25 @@ StateSpaceSpecifier = { ".shared" }; +ScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, +}; + + Type = { BaseType, ".pred", diff --git a/ptx/src/spirv.rs b/ptx/src/spirv.rs deleted file mode 100644 index a43692d..0000000 --- a/ptx/src/spirv.rs +++ /dev/null @@ -1,9 +0,0 @@ -use super::ast;
-
-pub struct TranslateError {
-
-}
-
-pub fn translate(ast: ast::Module) -> Result<Vec<u32>, TranslateError> {
- Ok(vec!())
-}
\ No newline at end of file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 0198900..9a07271 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,15 +1,21 @@ use super::ptx; +fn parse_and_assert(s: &str) { + assert!( + ptx::ModuleParser::new() + .parse(s) + .unwrap() + .errors + .len() == 0); +} #[test] fn empty() { - assert!(ptx::ModuleParser::new().parse( - ".version 6.5 .target sm_30, debug") - .unwrap() == ()); + parse_and_assert(".version 6.5 .target sm_30, debug"); } #[test] fn vector_add() { let vector_add = include_str!("vectorAdd_kernel64.ptx"); - assert!(ptx::ModuleParser::new().parse(vector_add).unwrap() == ()); + parse_and_assert(vector_add); }
\ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs new file mode 100644 index 0000000..6039c55 --- /dev/null +++ b/ptx/src/translate.rs @@ -0,0 +1,107 @@ +use crate::ast;
+use rspirv::dr;
+use std::collections::HashMap;
+
+pub struct TranslationError {
+
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum SpirvType {
+ Base(BaseType),
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum BaseType {
+ Int8,
+ Int16,
+ Int32,
+ Int64,
+ Uint8,
+ Uint16,
+ Uint32,
+ Uint64,
+ Float16,
+ Float32,
+ Float64,
+}
+
+struct TypeWordMap {
+ void: spirv::Word,
+ fn_void: spirv::Word,
+ complex: HashMap<SpirvType, spirv::Word>
+}
+
+impl TypeWordMap {
+ fn new(b: &mut dr::Builder) -> TypeWordMap {
+ let void = b.type_void();
+ TypeWordMap {
+ void: void,
+ fn_void: b.type_function(void, vec![]),
+ complex: HashMap::<SpirvType, spirv::Word>::new()
+ }
+ }
+
+ fn void(&self) -> spirv::Word { self.void }
+ fn fn_void(&self) -> spirv::Word { self.fn_void }
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+ *self.complex.entry(t).or_insert_with(|| {
+ match t {
+ SpirvType::Base(BaseType::Int8) => b.type_int(8, 1),
+ SpirvType::Base(BaseType::Int16) => b.type_int(16, 1),
+ SpirvType::Base(BaseType::Int32) => b.type_int(32, 1),
+ SpirvType::Base(BaseType::Int64) => b.type_int(64, 1),
+ SpirvType::Base(BaseType::Uint8) => b.type_int(8, 0),
+ SpirvType::Base(BaseType::Uint16) => b.type_int(16, 0),
+ SpirvType::Base(BaseType::Uint32) => b.type_int(32, 0),
+ SpirvType::Base(BaseType::Uint64) => b.type_int(64, 0),
+ SpirvType::Base(BaseType::Float16) => b.type_float(16),
+ SpirvType::Base(BaseType::Float32) => b.type_float(32),
+ SpirvType::Base(BaseType::Float64) => b.type_float(64),
+ }
+ })
+ }
+}
+
+pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
+ let mut builder = dr::Builder::new();
+ // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
+ builder.set_version(1, 0);
+ emit_capabilities(&mut builder);
+ emit_extensions(&mut builder);
+ emit_extended_instruction_sets(&mut builder);
+ emit_memory_model(&mut builder);
+ let mut map = TypeWordMap::new(&mut builder);
+ for f in ast.functions {
+ emit_function(&mut builder, &mut map, &f);
+ }
+ Ok(vec!())
+}
+
+fn emit_capabilities(builder: &mut dr::Builder) {
+ builder.capability(spirv::Capability::Linkage);
+ builder.capability(spirv::Capability::Addresses);
+ builder.capability(spirv::Capability::Kernel);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Int8);
+}
+
+fn emit_extensions(_: &mut dr::Builder) {
+
+}
+
+fn emit_extended_instruction_sets(builder: &mut dr::Builder) {
+ builder.ext_inst_import("OpenCL.std");
+}
+
+fn emit_memory_model(builder: &mut dr::Builder) {
+ builder.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL);
+}
+
+fn emit_function(builder: &mut dr::Builder, map: &TypeWordMap, f: &ast::Function) {
+ let func_id = builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, map.fn_void());
+
+ builder.ret();
+ builder.end_function();
+}
\ No newline at end of file |