diff options
author | Andrzej Janik <[email protected]> | 2020-05-07 00:37:10 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-05-07 00:37:10 +0200 |
commit | fa075abc226b5dcfd50355ebe690c192823d7a5e (patch) | |
tree | eefabcbaf603f21a1c2c9b16c88ff7c8c2ef9153 | |
parent | 3b433456a1428a423f7f5ec8aaa3e926eb9eea99 (diff) | |
download | ZLUDA-fa075abc226b5dcfd50355ebe690c192823d7a5e.tar.gz ZLUDA-fa075abc226b5dcfd50355ebe690c192823d7a5e.zip |
Translate instruction ld
-rw-r--r-- | ptx/src/ast.rs | 52 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 86 | ||||
-rw-r--r-- | ptx/src/translate.rs | 73 |
3 files changed, 152 insertions, 59 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 82580aa..190c21a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -187,7 +187,53 @@ pub enum MovOperand<ID> { Vec(String, String), } -pub struct LdData {} +pub enum VectorPrefix { + V2, + V4 +} + +pub struct LdData { + pub qualifier: LdQualifier, + pub state_space: LdStateSpace, + pub caching: LdCacheOperator, + pub vector: Option<VectorPrefix>, + pub typ: ScalarType +} + +#[derive(PartialEq, Eq)] +pub enum LdQualifier { + Weak, + Volatile, + Relaxed(LdScope), + Acquire(LdScope), +} + +#[derive(PartialEq, Eq)] +pub enum LdScope { + Cta, + Gpu, + Sys +} + +#[derive(PartialEq, Eq)] +pub enum LdStateSpace { + Generic, + Const, + Global, + Local, + Param, + Shared, +} + + +#[derive(PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached +} pub struct MovData {} @@ -201,7 +247,9 @@ pub struct SetpBoolData {} pub struct NotData {} -pub struct BraData {} +pub struct BraData { + pub uniform: bool +} pub struct CvtData {} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 83a0fe2..ded2386 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -106,6 +106,16 @@ Type: ast::Type = { }; ScalarType: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + MemoryType +}; + +ExtendedScalarType: ast::ExtendedScalarType = { + ".f16x2" => ast::ExtendedScalarType::F16x2, + ".pred" => ast::ExtendedScalarType::Pred, +}; + +MemoryType: ast::ScalarType = { ".b8" => ast::ScalarType::B8, ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, @@ -118,23 +128,10 @@ ScalarType: ast::ScalarType = { ".s16" => ast::ScalarType::S16, ".s32" => ast::ScalarType::S32, ".s64" => ast::ScalarType::S64, - ".f16" => ast::ScalarType::F16, ".f32" => ast::ScalarType::F32, ".f64" => ast::ScalarType::F64, }; -ExtendedScalarType: ast::ExtendedScalarType = { - ".f16x2" => ast::ExtendedScalarType::F16x2, - ".pred" => ast::ExtendedScalarType::Pred, -}; - -BaseType = { - ".b8", ".b16", ".b32", ".b64", - ".u8", ".u16", ".u32", ".u64", - ".s8", ".s16", ".s32", ".s64", - ".f32", ".f64" -}; - Statement: Option<ast::Statement<&'input str>> = { <l:Label> => Some(ast::Statement::Label(l)), DebugDirective => None, @@ -191,36 +188,47 @@ Instruction: ast::Instruction<&'input str> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction<&'input str> = { - "ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType <dst:ID> "," "[" <src:Operand> "]" => { - ast::Instruction::Ld(ast::LdData{}, ast::Arg2{dst:dst, src:src}) + "ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => { + ast::Instruction::Ld( + ast::LdData { + qualifier: q.unwrap_or(ast::LdQualifier::Weak), + state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + caching: cop.unwrap_or(ast::LdCacheOperator::Cached), + vector: v, + typ: t + }, + ast::Arg2 { dst:dst, src:src } + ) } }; -LdQualifier: () = { - ".weak", - ".volatile", - ".relaxed" LdScope, - ".acquire" LdScope, +LdQualifier: ast::LdQualifier = { + ".weak" => ast::LdQualifier::Weak, + ".volatile" => ast::LdQualifier::Volatile, + ".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s), + ".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s), }; -LdScope = { - ".cta", ".gpu", ".sys" +LdScope: ast::LdScope = { + ".cta" => ast::LdScope::Cta, + ".gpu" => ast::LdScope::Gpu, + ".sys" => ast::LdScope::Sys }; -LdStateSpace = { - ".const", - ".global", - ".local", - ".param", - ".shared", +LdStateSpace: ast::LdStateSpace = { + ".const" => ast::LdStateSpace::Const, + ".global" => ast::LdStateSpace::Global, + ".local" => ast::LdStateSpace::Local, + ".param" => ast::LdStateSpace::Param, + ".shared" => ast::LdStateSpace::Shared, }; -LdCacheOperator = { - ".ca", - ".cg", - ".cs", - ".lu", - ".cv", +LdCacheOperator: ast::LdCacheOperator = { + ".ca" => ast::LdCacheOperator::Cached, + ".cg" => ast::LdCacheOperator::L2Only, + ".cs" => ast::LdCacheOperator::Streaming, + ".lu" => ast::LdCacheOperator::LastUse, + ".cv" => ast::LdCacheOperator::Uncached, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov @@ -332,7 +340,7 @@ PredAt: ast::PredAt<&'input str> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra InstBra: ast::Instruction<&'input str> = { - "bra" ".uni"? <a:Arg1> => ast::Instruction::Bra(ast::BraData{}, a) + "bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt @@ -372,7 +380,7 @@ ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st InstSt: ast::Instruction<&'input str> = { - "st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" <dst:ID> "]" "," <src:Operand> => { + "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => { ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src}) } }; @@ -454,9 +462,9 @@ OptionalDst: &'input str = { "|" <dst2:ID> => dst2 } -Vector = { - ".v2", - ".v4" +VectorPrefix: ast::VectorPrefix = { + ".v2" => ast::VectorPrefix::V2, + ".v4" => ast::VectorPrefix::V4 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 52de35d..f5c5107 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -8,6 +8,7 @@ use std::fmt; #[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
Base(ast::ScalarType),
+ Pointer(ast::ScalarType, spirv::StorageClass),
}
struct TypeWordMap {
@@ -33,29 +34,41 @@ impl TypeWordMap { self.fn_void
}
- fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
- *self.complex.entry(t).or_insert_with(|| match t {
- SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
+ fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
+ *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 => {
b.type_int(8, 0)
}
- SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
+ ast::ScalarType::B16 | ast::ScalarType::U16 => {
b.type_int(16, 0)
}
- SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
+ ast::ScalarType::B32 | ast::ScalarType::U32 => {
b.type_int(32, 0)
}
- SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
+ ast::ScalarType::B64 | ast::ScalarType::U64 => {
b.type_int(64, 0)
}
- SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
- SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
- SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
- SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
- SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
- SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
- SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
+ ast::ScalarType::S8 => b.type_int(8, 1),
+ ast::ScalarType::S16 => b.type_int(16, 1),
+ ast::ScalarType::S32 => b.type_int(32, 1),
+ ast::ScalarType::S64 => b.type_int(64, 1),
+ ast::ScalarType::F16 => b.type_float(16),
+ ast::ScalarType::F32 => b.type_float(32),
+ ast::ScalarType::F64 => b.type_float(64),
})
}
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+ match t {
+ SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
+ SpirvType::Pointer(scalar, storage) => {
+ let base = self.get_or_add_scalar(b, scalar);
+ *self.complex.entry(t).or_insert_with(|| {
+ b.type_pointer(None, storage, base)
+ })
+ }
+ }
+ }
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
@@ -123,7 +136,7 @@ fn emit_function<'a>( );
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
- emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
+ emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
builder.end_function()?;
builder.ret()?;
builder.end_function()?;
@@ -178,6 +191,7 @@ fn collect_label_ids<'a>( fn emit_function_body_ops(
builder: &mut dr::Builder,
id_offset: spirv::Word,
+ map: &mut TypeWordMap,
func: &[Statement],
cfg: &[BasicBlock],
) -> Result<(), dr::Error> {
@@ -193,12 +207,35 @@ fn emit_function_body_ops( };
builder.begin_block(header_id)?;
for s in body {
- /*
match s {
- Statement::Instruction(pred, inst) => (),
+ // If block startd with a label it has already been emitted,
+ // all other labels in the block are unused
Statement::Label(_) => (),
+ Statement::Conditional(bra) => {
+ builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+ }
+ Statement::Instruction(inst) => match inst {
+ // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+ ast::Instruction::Bra(_, arg) => {
+ builder.branch(arg.src)?;
+ }
+ ast::Instruction::Ld(data, arg) => {
+ if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+ todo!()
+ }
+ let storage_class = match data.state_space {
+ ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+ ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+ _ => todo!(),
+ };
+ let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
+ let pointer_type =
+ map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
+ builder.load(result_type, None, pointer_type, None, [])?;
+ }
+ _ => todo!(),
+ },
}
- */
}
}
Ok(())
@@ -1273,7 +1310,7 @@ mod tests { let func = vec![
Statement::Label(12),
Statement::Instruction(ast::Instruction::Bra(
- ast::BraData {},
+ ast::BraData { uniform: false },
ast::Arg1 { src: 12 },
)),
];
|