aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-02 02:03:13 +0200
committerAndrzej Janik <[email protected]>2020-09-02 02:03:13 +0200
commit87cc72494ec492d60011933b76d74d8a82d9393b (patch)
tree2b45c6fceb1550baae45eaf3462a86eef0aaae47 /ptx
parent2e4cadc2ab061c61bacd43fab9a375b5492a1897 (diff)
downloadZLUDA-87cc72494ec492d60011933b76d74d8a82d9393b.tar.gz
ZLUDA-87cc72494ec492d60011933b76d74d8a82d9393b.zip
Parse Linux vectorAdd debug PTX kernel
Diffstat (limited to 'ptx')
-rw-r--r--ptx/Cargo.toml4
-rw-r--r--ptx/src/ast.rs31
-rw-r--r--ptx/src/lib.rs7
-rw-r--r--ptx/src/ptx.lalrpop94
-rw-r--r--ptx/src/test/mod.rs8
-rw-r--r--ptx/src/translate.rs21
6 files changed, 145 insertions, 20 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 9842e27..42d60cb 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -7,7 +7,7 @@ edition = "2018"
[lib]
[dependencies]
-lalrpop-util = "0.18.1"
+lalrpop-util = "0.19"
regex = "1"
rspirv = "0.6"
spirv_headers = "1.4"
@@ -16,7 +16,7 @@ bit-vec = "0.6"
half ="1.6"
[build-dependencies.lalrpop]
-version = "0.18.1"
+version = "0.19"
features = ["lexer"]
[dev-dependencies]
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index ed58d42..5de1db6 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -11,6 +11,7 @@ quick_error! {
}
SyntaxError {}
NonF32Ftz {}
+ WrongArrayType {}
}
}
@@ -50,11 +51,16 @@ pub struct Module<'a> {
pub functions: Vec<Function<'a>>,
}
+pub enum FunctionReturn<'a> {
+ Func(Vec<Argument<'a>>),
+ Kernel,
+}
+
pub struct Function<'a> {
- pub kernel: bool,
+ pub func_directive: FunctionReturn<'a>,
pub name: &'a str,
pub args: Vec<Argument<'a>>,
- pub body: Vec<Statement<ParsedArgParams<'a>>>,
+ pub body: Option<Vec<Statement<ParsedArgParams<'a>>>>,
}
#[derive(Default)]
@@ -68,6 +74,7 @@ pub struct Argument<'a> {
pub enum Type {
Scalar(ScalarType),
ExtendedScalar(ExtendedScalarType),
+ Array(ScalarType, u32),
}
impl From<FloatType> for Type {
@@ -173,10 +180,12 @@ pub enum Statement<P: ArgParams> {
Label(P::ID),
Variable(Variable<P>),
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
+ Block(Vec<Statement<P>>),
}
pub struct Variable<P: ArgParams> {
pub space: StateSpace,
+ pub align: Option<u32>,
pub v_type: Type,
pub name: P::ID,
pub count: Option<u32>,
@@ -190,6 +199,7 @@ pub enum StateSpace {
Global,
Local,
Shared,
+ Param,
}
pub struct PredAt<ID> {
@@ -211,6 +221,23 @@ pub enum Instruction<P: ArgParams> {
Shl(ShlType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
+ Call(CallData, ArgCall<P>),
+ Abs(AbsDetails, Arg2<P>),
+}
+
+pub struct CallData {
+ pub uniform: bool,
+}
+
+pub struct AbsDetails {
+ pub flush_to_zero: bool,
+ pub typ: ScalarType
+}
+
+pub struct ArgCall<P: ArgParams> {
+ pub ret_params: Vec<P::ID>,
+ pub func: P::ID,
+ pub param_list: Vec<P::ID>,
}
pub trait ArgParams {
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 6912d92..03d6d58 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -27,8 +27,11 @@ pub mod ast;
mod test;
mod translate;
-pub use ast::Module;
-pub use translate::to_spirv;
+pub use lalrpop_util::ParseError as ParseError;
+pub use lalrpop_util::lexer::Token as Token;
+pub use crate::ptx::ModuleParser as ModuleParser;
+pub use translate::to_spirv as to_spirv;
+pub use rspirv::dr::Error as SpirvError;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 66e831e..7438e97 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -24,6 +24,7 @@ match {
"|",
".acquire",
".address_size",
+ ".align",
".and",
".b16",
".b32",
@@ -108,8 +109,10 @@ match {
".xor",
} else {
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
+ "abs",
"add",
"bra",
+ "call",
"cvt",
"cvta",
"debug",
@@ -135,8 +138,10 @@ match {
}
ExtendedID : &'input str = {
+ "abs",
"add",
"bra",
+ "call",
"cvt",
"cvta",
"debug",
@@ -197,9 +202,9 @@ AddressSize = {
Function: ast::Function<'input> = {
LinkingDirective*
- <kernel:IsKernel>
+ <func_directive:FunctionReturn>
<name:ExtendedID>
- "(" <args:Comma<FunctionInput>> ")"
+ <args:Arguments>
<body:FunctionBody> => ast::Function{<>}
};
@@ -209,11 +214,15 @@ LinkingDirective = {
".weak"
};
-IsKernel: bool = {
- ".entry" => true,
- ".func" => false
+FunctionReturn: ast::FunctionReturn<'input> = {
+ ".entry" => ast::FunctionReturn::Kernel,
+ ".func" <args:Arguments?> => ast::FunctionReturn::Func(args.unwrap_or_else(|| Vec::new()))
};
+Arguments: Vec<ast::Argument<'input>> = {
+ "(" <args:Comma<FunctionInput>> ")" => args
+}
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
FunctionInput: ast::Argument<'input> = {
".param" <_type:ScalarType> <name:ExtendedID> => {
@@ -226,8 +235,9 @@ FunctionInput: ast::Argument<'input> = {
}
};
-pub(crate) FunctionBody: Vec<ast::Statement<ast::ParsedArgParams<'input>>> = {
- "{" <s:Statement*> "}" => { without_none(s) }
+pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
+ "{" <s:Statement*> "}" => { Some(without_none(s)) },
+ ";" => { None }
};
StateSpaceSpecifier: ast::StateSpace = {
@@ -236,7 +246,8 @@ StateSpaceSpecifier: ast::StateSpace = {
".const" => ast::StateSpace::Const,
".global" => ast::StateSpace::Global,
".local" => ast::StateSpace::Local,
- ".shared" => ast::StateSpace::Shared
+ ".shared" => ast::StateSpace::Shared,
+ ".param" => ast::StateSpace::Param, // used to prepare function call
};
@@ -276,7 +287,8 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
- <p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i))
+ <p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
+ "{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
};
DebugDirective: () = {
@@ -292,10 +304,32 @@ Label: &'input str = {
<id:ExtendedID> ":" => id
};
+Align: u32 = {
+ ".align" <a:Num> => {
+ let align = a.parse::<u32>();
+ align.unwrap_with(errors)
+ }
+};
+
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
- <s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
+ <s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => {
let (name, count) = v;
- ast::Variable { space: s, v_type: t, name: name, count: count }
+ let t = match (t, arr) {
+ (ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
+ (t, Some(_)) => {
+ errors.push(ast::PtxError::WrongArrayType);
+ t
+ },
+ (t, None) => t,
+ };
+ ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
+ }
+};
+
+ArraySpecifier: u32 = {
+ "[" <n:Num> "]" => {
+ let size = n.parse::<u32>();
+ size.unwrap_with(errors)
}
};
@@ -326,6 +360,8 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstSt,
InstRet,
InstCvta,
+ InstCall,
+ InstAbs,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -819,6 +855,36 @@ CvtaSize: ast::CvtaSize = {
".u64" => ast::CvtaSize::U64,
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
+InstCall: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "call" <u:".uni"?> <a:ArgCall> => ast::Instruction::Call(ast::CallData { uniform: u.is_some() }, a)
+};
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
+InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "abs" <t:SignedIntType> <a:Arg2> => {
+ ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: t }, a)
+ },
+ "abs" <f:".ftz"?> ".f32" <a:Arg2> => {
+ ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F32 }, a)
+ },
+ "abs" ".f64" <a:Arg2> => {
+ ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: ast::ScalarType::F64 }, a)
+ },
+ "abs" <f:".ftz"?> ".f16" <a:Arg2> => {
+ ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F16 }, a)
+ },
+ "abs" <f:".ftz"?> ".f16x2" <a:Arg2> => {
+ todo!()
+ },
+};
+
+SignedIntType: ast::ScalarType = {
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+};
+
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => {
@@ -873,6 +939,12 @@ Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
};
+ArgCall: ast::ArgCall<ast::ParsedArgParams<'input>> = {
+ "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{<>},
+ <func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{ret_params: Vec::new(), func, param_list},
+ <func:ExtendedID> => ast::ArgCall{ret_params: Vec::new(), func, param_list: Vec::new()},
+};
+
OptionalDst: &'input str = {
"|" <dst2:ExtendedID> => dst2
}
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index f66992b..3252b50 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -25,3 +25,11 @@ fn operands_ptx() {
let vector_add = include_str!("operands.ptx");
parse_and_assert(vector_add);
}
+
+#[test]
+#[allow(non_snake_case)]
+fn _Z9vectorAddPKfS0_Pfi_ptx() {
+ let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
+ parse_and_assert(vector_add);
+}
+
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c0cdf01..b4d01eb 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -16,6 +16,7 @@ impl SpirvType {
let key = match t {
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
+ ast::Type::Array(_, _) => todo!(),
};
SpirvType::Pointer(key, sc)
}
@@ -26,6 +27,7 @@ impl From<ast::Type> for SpirvType {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
+ ast::Type::Array(_, _) => todo!(),
}
}
}
@@ -195,10 +197,13 @@ fn emit_function<'a>(
let func_type = get_function_type(builder, map, &f.args);
let func_id =
builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
- if f.kernel {
- builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
+ match f.func_directive {
+ ast::FunctionReturn::Kernel => {
+ builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[])
+ }
+ _ => todo!(),
}
- let (mut func_body, unique_ids) = to_ssa(&f.args, f.body);
+ let (mut func_body, unique_ids) = to_ssa(&f.args, f.body.unwrap_or_else(|| todo!()));
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
func_body = apply_id_offset(func_body, id_offset);
@@ -266,6 +271,7 @@ fn normalize_predicates(
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
+ ast::Statement::Block(_) => todo!(),
ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred {
@@ -652,6 +658,8 @@ fn emit_function_body_ops(
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
Statement::Instruction(inst) => match inst {
+ ast::Instruction::Abs(_, _) => todo!(),
+ ast::Instruction::Call(_,_) => todo!(),
// SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?;
@@ -1076,6 +1084,7 @@ fn expand_map_variables<'a>(
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
match s {
+ ast::Statement::Block(_) => todo!(),
ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))),
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
@@ -1086,6 +1095,7 @@ fn expand_map_variables<'a>(
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
result.push(ast::Statement::Variable(ast::Variable {
space: var.space,
+ align: var.align,
v_type: var.v_type,
name: new_id,
count: None,
@@ -1096,6 +1106,7 @@ fn expand_map_variables<'a>(
let new_id = id_defs.add_def(var.name, Some(var.v_type));
result.push(ast::Statement::Variable(ast::Variable {
space: var.space,
+ align: var.align,
v_type: var.v_type,
name: new_id,
count: None,
@@ -1307,6 +1318,8 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
+ ast::Instruction::Abs(_, _) => todo!(),
+ ast::Instruction::Call(_, _) => todo!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type))))
@@ -1432,6 +1445,8 @@ impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
+ ast::Instruction::Abs(_, _) => todo!(),
+ ast::Instruction::Call(_, _) => todo!(),
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)