aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-05-07 00:37:10 +0200
committerAndrzej Janik <[email protected]>2020-05-07 00:37:10 +0200
commitfa075abc226b5dcfd50355ebe690c192823d7a5e (patch)
treeeefabcbaf603f21a1c2c9b16c88ff7c8c2ef9153
parent3b433456a1428a423f7f5ec8aaa3e926eb9eea99 (diff)
downloadZLUDA-fa075abc226b5dcfd50355ebe690c192823d7a5e.tar.gz
ZLUDA-fa075abc226b5dcfd50355ebe690c192823d7a5e.zip
Translate instruction ld
-rw-r--r--ptx/src/ast.rs52
-rw-r--r--ptx/src/ptx.lalrpop86
-rw-r--r--ptx/src/translate.rs73
3 files changed, 152 insertions, 59 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 82580aa..190c21a 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -187,7 +187,53 @@ pub enum MovOperand<ID> {
Vec(String, String),
}
-pub struct LdData {}
+pub enum VectorPrefix {
+ V2,
+ V4
+}
+
+pub struct LdData {
+ pub qualifier: LdQualifier,
+ pub state_space: LdStateSpace,
+ pub caching: LdCacheOperator,
+ pub vector: Option<VectorPrefix>,
+ pub typ: ScalarType
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdQualifier {
+ Weak,
+ Volatile,
+ Relaxed(LdScope),
+ Acquire(LdScope),
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdScope {
+ Cta,
+ Gpu,
+ Sys
+}
+
+#[derive(PartialEq, Eq)]
+pub enum LdStateSpace {
+ Generic,
+ Const,
+ Global,
+ Local,
+ Param,
+ Shared,
+}
+
+
+#[derive(PartialEq, Eq)]
+pub enum LdCacheOperator {
+ Cached,
+ L2Only,
+ Streaming,
+ LastUse,
+ Uncached
+}
pub struct MovData {}
@@ -201,7 +247,9 @@ pub struct SetpBoolData {}
pub struct NotData {}
-pub struct BraData {}
+pub struct BraData {
+ pub uniform: bool
+}
pub struct CvtData {}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 83a0fe2..ded2386 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -106,6 +106,16 @@ Type: ast::Type = {
};
ScalarType: ast::ScalarType = {
+ ".f16" => ast::ScalarType::F16,
+ MemoryType
+};
+
+ExtendedScalarType: ast::ExtendedScalarType = {
+ ".f16x2" => ast::ExtendedScalarType::F16x2,
+ ".pred" => ast::ExtendedScalarType::Pred,
+};
+
+MemoryType: ast::ScalarType = {
".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
@@ -118,23 +128,10 @@ ScalarType: ast::ScalarType = {
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
- ".f16" => ast::ScalarType::F16,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
};
-ExtendedScalarType: ast::ExtendedScalarType = {
- ".f16x2" => ast::ExtendedScalarType::F16x2,
- ".pred" => ast::ExtendedScalarType::Pred,
-};
-
-BaseType = {
- ".b8", ".b16", ".b32", ".b64",
- ".u8", ".u16", ".u32", ".u64",
- ".s8", ".s16", ".s32", ".s64",
- ".f32", ".f64"
-};
-
Statement: Option<ast::Statement<&'input str>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
@@ -191,36 +188,47 @@ Instruction: ast::Instruction<&'input str> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<&'input str> = {
- "ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType <dst:ID> "," "[" <src:Operand> "]" => {
- ast::Instruction::Ld(ast::LdData{}, ast::Arg2{dst:dst, src:src})
+ "ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
+ ast::Instruction::Ld(
+ ast::LdData {
+ qualifier: q.unwrap_or(ast::LdQualifier::Weak),
+ state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
+ caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
+ vector: v,
+ typ: t
+ },
+ ast::Arg2 { dst:dst, src:src }
+ )
}
};
-LdQualifier: () = {
- ".weak",
- ".volatile",
- ".relaxed" LdScope,
- ".acquire" LdScope,
+LdQualifier: ast::LdQualifier = {
+ ".weak" => ast::LdQualifier::Weak,
+ ".volatile" => ast::LdQualifier::Volatile,
+ ".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s),
+ ".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s),
};
-LdScope = {
- ".cta", ".gpu", ".sys"
+LdScope: ast::LdScope = {
+ ".cta" => ast::LdScope::Cta,
+ ".gpu" => ast::LdScope::Gpu,
+ ".sys" => ast::LdScope::Sys
};
-LdStateSpace = {
- ".const",
- ".global",
- ".local",
- ".param",
- ".shared",
+LdStateSpace: ast::LdStateSpace = {
+ ".const" => ast::LdStateSpace::Const,
+ ".global" => ast::LdStateSpace::Global,
+ ".local" => ast::LdStateSpace::Local,
+ ".param" => ast::LdStateSpace::Param,
+ ".shared" => ast::LdStateSpace::Shared,
};
-LdCacheOperator = {
- ".ca",
- ".cg",
- ".cs",
- ".lu",
- ".cv",
+LdCacheOperator: ast::LdCacheOperator = {
+ ".ca" => ast::LdCacheOperator::Cached,
+ ".cg" => ast::LdCacheOperator::L2Only,
+ ".cs" => ast::LdCacheOperator::Streaming,
+ ".lu" => ast::LdCacheOperator::LastUse,
+ ".cv" => ast::LdCacheOperator::Uncached,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
@@ -332,7 +340,7 @@ PredAt: ast::PredAt<&'input str> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
InstBra: ast::Instruction<&'input str> = {
- "bra" ".uni"? <a:Arg1> => ast::Instruction::Bra(ast::BraData{}, a)
+ "bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
@@ -372,7 +380,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
InstSt: ast::Instruction<&'input str> = {
- "st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" <dst:ID> "]" "," <src:Operand> => {
+ "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => {
ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src})
}
};
@@ -454,9 +462,9 @@ OptionalDst: &'input str = {
"|" <dst2:ID> => dst2
}
-Vector = {
- ".v2",
- ".v4"
+VectorPrefix: ast::VectorPrefix = {
+ ".v2" => ast::VectorPrefix::V2,
+ ".v4" => ast::VectorPrefix::V4
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 52de35d..f5c5107 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -8,6 +8,7 @@ use std::fmt;
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
Base(ast::ScalarType),
+ Pointer(ast::ScalarType, spirv::StorageClass),
}
struct TypeWordMap {
@@ -33,29 +34,41 @@ impl TypeWordMap {
self.fn_void
}
- fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
- *self.complex.entry(t).or_insert_with(|| match t {
- SpirvType::Base(ast::ScalarType::B8) | SpirvType::Base(ast::ScalarType::U8) => {
+ fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
+ *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 => {
b.type_int(8, 0)
}
- SpirvType::Base(ast::ScalarType::B16) | SpirvType::Base(ast::ScalarType::U16) => {
+ ast::ScalarType::B16 | ast::ScalarType::U16 => {
b.type_int(16, 0)
}
- SpirvType::Base(ast::ScalarType::B32) | SpirvType::Base(ast::ScalarType::U32) => {
+ ast::ScalarType::B32 | ast::ScalarType::U32 => {
b.type_int(32, 0)
}
- SpirvType::Base(ast::ScalarType::B64) | SpirvType::Base(ast::ScalarType::U64) => {
+ ast::ScalarType::B64 | ast::ScalarType::U64 => {
b.type_int(64, 0)
}
- SpirvType::Base(ast::ScalarType::S8) => b.type_int(8, 1),
- SpirvType::Base(ast::ScalarType::S16) => b.type_int(16, 1),
- SpirvType::Base(ast::ScalarType::S32) => b.type_int(32, 1),
- SpirvType::Base(ast::ScalarType::S64) => b.type_int(64, 1),
- SpirvType::Base(ast::ScalarType::F16) => b.type_float(16),
- SpirvType::Base(ast::ScalarType::F32) => b.type_float(32),
- SpirvType::Base(ast::ScalarType::F64) => b.type_float(64),
+ ast::ScalarType::S8 => b.type_int(8, 1),
+ ast::ScalarType::S16 => b.type_int(16, 1),
+ ast::ScalarType::S32 => b.type_int(32, 1),
+ ast::ScalarType::S64 => b.type_int(64, 1),
+ ast::ScalarType::F16 => b.type_float(16),
+ ast::ScalarType::F32 => b.type_float(32),
+ ast::ScalarType::F64 => b.type_float(64),
})
}
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
+ match t {
+ SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
+ SpirvType::Pointer(scalar, storage) => {
+ let base = self.get_or_add_scalar(b, scalar);
+ *self.complex.entry(t).or_insert_with(|| {
+ b.type_pointer(None, storage, base)
+ })
+ }
+ }
+ }
}
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
@@ -123,7 +136,7 @@ fn emit_function<'a>(
);
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
- emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
+ emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
builder.end_function()?;
builder.ret()?;
builder.end_function()?;
@@ -178,6 +191,7 @@ fn collect_label_ids<'a>(
fn emit_function_body_ops(
builder: &mut dr::Builder,
id_offset: spirv::Word,
+ map: &mut TypeWordMap,
func: &[Statement],
cfg: &[BasicBlock],
) -> Result<(), dr::Error> {
@@ -193,12 +207,35 @@ fn emit_function_body_ops(
};
builder.begin_block(header_id)?;
for s in body {
- /*
match s {
- Statement::Instruction(pred, inst) => (),
+ // If a block starts with a label it has already been emitted;
+ // all other labels in the block are unused
Statement::Label(_) => (),
+ Statement::Conditional(bra) => {
+ builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+ }
+ Statement::Instruction(inst) => match inst {
+ // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+ ast::Instruction::Bra(_, arg) => {
+ builder.branch(arg.src)?;
+ }
+ ast::Instruction::Ld(data, arg) => {
+ if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+ todo!()
+ }
+ let storage_class = match data.state_space {
+ ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+ ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+ _ => todo!(),
+ };
+ let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
+ let pointer_type =
+ map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
+ builder.load(result_type, None, pointer_type, None, [])?;
+ }
+ _ => todo!(),
+ },
}
- */
}
}
Ok(())
@@ -1273,7 +1310,7 @@ mod tests {
let func = vec![
Statement::Label(12),
Statement::Instruction(ast::Instruction::Bra(
- ast::BraData {},
+ ast::BraData { uniform: false },
ast::Arg1 { src: 12 },
)),
];