aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-04-12 20:50:34 +0200
committerAndrzej Janik <[email protected]>2020-04-12 20:50:34 +0200
commitbbe993392b803f01effc0f86da861d348741d1eb (patch)
treedf9019efcf4645ab0a38cab7e6d5eb59d97a4e6c /ptx
parentb8129aab20c4768ffc3a304ea957c3ec278471dc (diff)
downloadZLUDA-bbe993392b803f01effc0f86da861d348741d1eb.tar.gz
ZLUDA-bbe993392b803f01effc0f86da861d348741d1eb.zip
Add better error handling during ast construction
Diffstat (limited to 'ptx')
-rw-r--r--ptx/Cargo.toml2
-rw-r--r--ptx/doc/NOTES.md18
-rw-r--r--ptx/src/ast.rs158
-rw-r--r--ptx/src/lib.rs12
-rw-r--r--ptx/src/ptx.lalrpop69
-rw-r--r--ptx/src/spirv.rs9
-rw-r--r--ptx/src/test/mod.rs14
-rw-r--r--ptx/src/translate.rs107
8 files changed, 330 insertions, 59 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 64c3547..7c8a701 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -10,6 +10,8 @@ edition = "2018"
lalrpop-util = "0.18.1"
regex = "1"
rspirv = "0.6"
+spirv_headers = "1.4"
+quick-error = "1.2"
[build-dependencies.lalrpop]
version = "0.18.1"
diff --git a/ptx/doc/NOTES.md b/ptx/doc/NOTES.md
deleted file mode 100644
index b4d2ad3..0000000
--- a/ptx/doc/NOTES.md
+++ /dev/null
@@ -1,18 +0,0 @@
-I'm convinced nobody actually uses parser generators in Rust:
-* pomelo can't generate lexer (understandable, as it is a port of lemon and lemon can't do this either)
-* pest can't do parse actions, you have to convert your parse tree to ast manually
-* lalrpop can't do comments
- * and the day I wrote the line above it can
- * reports parsing errors as byte offsets
- * if you want to skip parsing one of the alternatives functional design gets quite awkward
-* antlr4rust is untried and requires java to build
-* no library supports island grammars
-
-What to emit?
-* SPIR-V
- * Better library support, easier to emit
- * Can by optimized by IGC
- * Can't do some things (not sure what exactly yet)
- * But we can work around things with inline VISA
-* VISA
- * Quicker compilation \ No newline at end of file
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 645e3a1..3bb142d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,28 +1,168 @@
+use std::convert::From;
+use std::error::Error;
+use std::mem;
+use std::num::ParseIntError;
+
+quick_error! {
+ #[derive(Debug)]
+ pub enum PtxError {
+ Parse (err: ParseIntError) {
+ display("{}", err)
+ cause(err)
+ from()
+ }
+ }
+}
+
+pub struct WithErrors<T, E> {
+ pub value: T,
+ pub errors: Vec<E>,
+}
+
+impl<T, E> WithErrors<T, E> {
+ pub fn new(t: T) -> Self {
+ WithErrors {
+ value: t,
+ errors: Vec::new(),
+ }
+ }
+
+ pub fn map<F: FnOnce(T) -> U, U>(self, f: F) -> WithErrors<U, E> {
+ WithErrors {
+ value: f(self.value),
+ errors: self.errors,
+ }
+ }
+
+ pub fn map2<X, Y, F: FnOnce(X, Y) -> T>(
+ x: WithErrors<X, E>,
+ y: WithErrors<Y, E>,
+ f: F,
+ ) -> Self {
+ let mut errors = x.errors;
+ let mut errors_other = y.errors;
+ if errors.len() < errors_other.len() {
+ mem::swap(&mut errors, &mut errors_other);
+ }
+ errors.extend(errors_other);
+ WithErrors {
+ value: f(x.value, y.value),
+ errors: errors,
+ }
+ }
+}
+
+impl<T:Default, E: Error> WithErrors<T, E> {
+ pub fn from_results<X: Default, Y: Default, F: FnOnce(X, Y) -> T>(
+ x: Result<X, E>,
+ y: Result<Y, E>,
+ f: F,
+ ) -> Self {
+ match (x, y) {
+ (Ok(x), Ok(y)) => WithErrors {
+ value: f(x, y),
+ errors: Vec::new(),
+ },
+ (Err(e), Ok(y)) => WithErrors {
+ value: f(X::default(), y),
+ errors: vec![e],
+ },
+ (Ok(x), Err(e)) => WithErrors {
+ value: f(x, Y::default()),
+ errors: vec![e],
+ },
+ (Err(e1), Err(e2)) => WithErrors {
+ value: T::default(),
+ errors: vec![e1, e2],
+ },
+ }
+ }
+}
+
+impl<T, E: Error> WithErrors<Vec<T>, E> {
+ pub fn from_vec(v: Vec<WithErrors<T, E>>) -> Self {
+ let mut values = Vec::with_capacity(v.len());
+ let mut errors = Vec::new();
+ for we in v.into_iter() {
+ values.push(we.value);
+ errors.extend(we.errors);
+ }
+ WithErrors {
+ value: values,
+ errors: errors,
+ }
+ }
+}
+
+pub trait WithErrorsExt<From, To, E> {
+ fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E>;
+}
+
+impl<From, To: Default, E> WithErrorsExt<From, To, E> for Result<From, E> {
+ fn with_errors<F: FnOnce(From) -> To>(self, f: F) -> WithErrors<To, E> {
+ self.map_or_else(
+ |e| WithErrors {
+ value: To::default(),
+ errors: vec![e],
+ },
+ |t| WithErrors {
+ value: f(t),
+ errors: Vec::new(),
+ },
+ )
+ }
+}
+
pub struct Module<'a> {
pub version: (u8, u8),
- pub functions: Vec<Function<'a>>
+ pub functions: Vec<Function<'a>>,
}
pub struct Function<'a> {
pub kernel: bool,
pub name: &'a str,
- pub args: Vec<Argument>,
+ pub args: Vec<Argument<'a>>,
pub body: Vec<Statement<'a>>,
}
-pub struct Argument {
+#[derive(Default)]
+pub struct Argument<'a> {
+ pub name: &'a str,
+ pub a_type: ScalarType,
+ pub length: u32,
+}
+
+pub enum ScalarType {
+ B8,
+ B16,
+ B32,
+ B64,
+ U8,
+ U16,
+ U32,
+ U64,
+ S8,
+ S16,
+ S32,
+ S64,
+ F16,
+ F32,
+ F64,
+}
+impl Default for ScalarType {
+ fn default() -> Self {
+ ScalarType::B8
+ }
}
pub enum Statement<'a> {
Label(&'a str),
Variable(Variable),
- Instruction(Instruction)
+ Instruction(Instruction),
}
-pub struct Variable {
-
-}
+pub struct Variable {}
pub enum Instruction {
Ld,
@@ -35,5 +175,5 @@ pub enum Instruction {
Cvt,
Shl,
At,
- Ret
-} \ No newline at end of file
+ Ret,
+}
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index a29270f..716a25c 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -1,11 +1,19 @@
#[macro_use]
+extern crate quick_error;
+#[macro_use]
extern crate lalrpop_util;
+extern crate rspirv;
+extern crate spirv_headers as spirv;
lalrpop_mod!(ptx);
mod test;
-mod spirv;
+mod translate;
pub mod ast;
pub use ast::Module as Module;
-pub use spirv::translate as to_spirv;
+pub use translate::to_spirv as to_spirv;
+
+pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
+ x.into_iter().filter_map(|x| x).collect()
+} \ No newline at end of file
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 96288b9..b646d68 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1,6 +1,6 @@
-use std::str::FromStr;
use crate::ast;
-use std::convert::identity;
+use crate::ast::{WithErrors, WithErrorsExt};
+use crate::without_none;
grammar;
@@ -16,19 +16,23 @@ match {
_
}
-pub Module: Option<ast::Module<'input>> = {
- <v:Version> Target <f:Directive*> => v.map(|v| ast::Module { version: v, functions: f.into_iter().filter_map(identity).collect::<Vec<_>>() })
+pub Module: WithErrors<ast::Module<'input>, ast::PtxError> = {
+ <v:Version> Target <f:Directive*> => {
+ let funcs = WithErrors::from_vec(without_none(f));
+ WithErrors::map2(v, funcs,
+ |v, funcs| ast::Module { version: v, functions: funcs }
+ )
+ }
};
-Version: Option<(u8, u8)> = {
+Version: WithErrors<(u8, u8), ast::PtxError> = {
".version" <v:VersionNumber> => {
let dot = v.find('.').unwrap();
- let major = v[..dot].parse::<u8>();
- major.ok().and_then(|major| {
- v[dot+1..].parse::<u8>().ok().map(|minor| {
- (major, minor)
- })
- })
+ let major = v[..dot].parse::<u8>().map_err(Into::into);
+ let minor = v[dot+1..].parse::<u8>().map_err(Into::into);
+ WithErrors::from_results(major, minor,
+ |major, minor| (major, minor)
+ )
}
}
@@ -45,7 +49,7 @@ TargetSpecifier = {
"map_f64_to_f32"
};
-Directive: Option<ast::Function<'input>> = {
+Directive: Option<WithErrors<ast::Function<'input>, ast::PtxError>> = {
AddressSize => None,
<f:Function> => Some(f),
File => None,
@@ -56,8 +60,11 @@ AddressSize = {
".address_size" Num
};
-Function: ast::Function<'input> = {
- LinkingDirective* <kernel:IsKernel> <name:ID> "(" <args:Comma<FunctionInput>> ")" <body:FunctionBody> => ast::Function {<>}
+Function: WithErrors<ast::Function<'input>, ast::PtxError> = {
+ LinkingDirective* <k:IsKernel> <n:ID> "(" <args:Comma<FunctionInput>> ")" <b:FunctionBody> => {
+ WithErrors::from_vec(args)
+ .map(|args| ast::Function{kernel: k, name: n, args: args, body: b})
+ }
};
LinkingDirective = {
@@ -71,12 +78,21 @@ IsKernel: bool = {
".func" => false
};
-FunctionInput: ast::Argument = {
- ".param" Type ID => ast::Argument {}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
+FunctionInput: WithErrors<ast::Argument<'input>, ast::PtxError> = {
+ ".param" <_type:ScalarType> <name:ID> => {
+ WithErrors::new(ast::Argument {a_type: _type, name: name, length: 1 })
+ },
+ ".param" <a_type:ScalarType> <name:ID> "[" <length:Num> "]" => {
+ let length = length.parse::<u32>().map_err(Into::into);
+ length.with_errors(
+ |l| ast::Argument { a_type: a_type, name: name, length: l }
+ )
+ }
};
FunctionBody: Vec<ast::Statement<'input>> = {
- "{" <s:Statement*> "}" => { s.into_iter().filter_map(identity).collect() }
+ "{" <s:Statement*> "}" => { without_none(s) }
};
StateSpaceSpecifier = {
@@ -88,6 +104,25 @@ StateSpaceSpecifier = {
".shared"
};
+ScalarType: ast::ScalarType = {
+ ".b8" => ast::ScalarType::B8,
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u8" => ast::ScalarType::U8,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s8" => ast::ScalarType::S8,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f16" => ast::ScalarType::F16,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
+};
+
+
Type = {
BaseType,
".pred",
diff --git a/ptx/src/spirv.rs b/ptx/src/spirv.rs
deleted file mode 100644
index a43692d..0000000
--- a/ptx/src/spirv.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-use super::ast;
-
-pub struct TranslateError {
-
-}
-
-pub fn translate(ast: ast::Module) -> Result<Vec<u32>, TranslateError> {
- Ok(vec!())
-} \ No newline at end of file
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index 0198900..9a07271 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -1,15 +1,21 @@
use super::ptx;
+fn parse_and_assert(s: &str) {
+ assert!(
+ ptx::ModuleParser::new()
+ .parse(s)
+ .unwrap()
+ .errors
+ .len() == 0);
+}
#[test]
fn empty() {
- assert!(ptx::ModuleParser::new().parse(
- ".version 6.5 .target sm_30, debug")
- .unwrap() == ());
+ parse_and_assert(".version 6.5 .target sm_30, debug");
}
#[test]
fn vector_add() {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
- assert!(ptx::ModuleParser::new().parse(vector_add).unwrap() == ());
+ parse_and_assert(vector_add);
} \ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
new file mode 100644
index 0000000..6039c55
--- /dev/null
+++ b/ptx/src/translate.rs
@@ -0,0 +1,107 @@
+use crate::ast;
+use rspirv::dr;
+use std::collections::HashMap;
+
+pub struct TranslationError {
+
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum SpirvType {
+ Base(BaseType),
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum BaseType {
+ Int8,
+ Int16,
+ Int32,
+ Int64,
+ Uint8,
+ Uint16,
+ Uint32,
+ Uint64,
+ Float16,
+ Float32,
+ Float64,
+}
+
+struct TypeWordMap {
+ void: spirv::Word,
+ fn_void: spirv::Word,
+ complex: HashMap<SpirvType, spirv::Word>
+}
+
+impl TypeWordMap {
+ fn new(b: &mut dr::Builder) -> TypeWordMap {
+ let void = b.type_void();
+ TypeWordMap {
+ void: void,
+ fn_void: b.type_function(void, vec![]),
+ complex: HashMap::<SpirvType, spirv::Word>::new()
+ }
+ }
+
+ fn void(&self) -> spirv::Word { self.void }
+ fn fn_void(&self) -> spirv::Word { self.fn_void }
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+ *self.complex.entry(t).or_insert_with(|| {
+ match t {
+ SpirvType::Base(BaseType::Int8) => b.type_int(8, 1),
+ SpirvType::Base(BaseType::Int16) => b.type_int(16, 1),
+ SpirvType::Base(BaseType::Int32) => b.type_int(32, 1),
+ SpirvType::Base(BaseType::Int64) => b.type_int(64, 1),
+ SpirvType::Base(BaseType::Uint8) => b.type_int(8, 0),
+ SpirvType::Base(BaseType::Uint16) => b.type_int(16, 0),
+ SpirvType::Base(BaseType::Uint32) => b.type_int(32, 0),
+ SpirvType::Base(BaseType::Uint64) => b.type_int(64, 0),
+ SpirvType::Base(BaseType::Float16) => b.type_float(16),
+ SpirvType::Base(BaseType::Float32) => b.type_float(32),
+ SpirvType::Base(BaseType::Float64) => b.type_float(64),
+ }
+ })
+ }
+}
+
+pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, TranslationError> {
+ let mut builder = dr::Builder::new();
+ // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
+ builder.set_version(1, 0);
+ emit_capabilities(&mut builder);
+ emit_extensions(&mut builder);
+ emit_extended_instruction_sets(&mut builder);
+ emit_memory_model(&mut builder);
+ let mut map = TypeWordMap::new(&mut builder);
+ for f in ast.functions {
+ emit_function(&mut builder, &mut map, &f);
+ }
+ Ok(vec!())
+}
+
+fn emit_capabilities(builder: &mut dr::Builder) {
+ builder.capability(spirv::Capability::Linkage);
+ builder.capability(spirv::Capability::Addresses);
+ builder.capability(spirv::Capability::Kernel);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Int8);
+}
+
+fn emit_extensions(_: &mut dr::Builder) {
+
+}
+
+fn emit_extended_instruction_sets(builder: &mut dr::Builder) {
+ builder.ext_inst_import("OpenCL.std");
+}
+
+fn emit_memory_model(builder: &mut dr::Builder) {
+ builder.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL);
+}
+
+fn emit_function(builder: &mut dr::Builder, map: &TypeWordMap, f: &ast::Function) {
+ let func_id = builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, map.fn_void());
+
+ builder.ret();
+ builder.end_function();
+} \ No newline at end of file