aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-08-05 23:50:20 +0200
committerAndrzej Janik <[email protected]>2020-08-05 23:50:20 +0200
commitd47cd1e133995a08af15edd23c476ebf6d5cabf8 (patch)
treeeb76fa80dbd3e05161655dcbb52dce48dce166d0 /ptx
parent7b407d1c44535c2aef2c1ca0eb1fbbb58a1513d2 (diff)
downloadZLUDA-d47cd1e133995a08af15edd23c476ebf6d5cabf8.tar.gz
ZLUDA-d47cd1e133995a08af15edd23c476ebf6d5cabf8.zip
Add support for cvta and global ld/st
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs23
-rw-r--r--ptx/src/ptx.lalrpop36
-rw-r--r--ptx/src/test/spirv_run/cvta.ptx23
-rw-r--r--ptx/src/test/spirv_run/cvta.spvtxt42
-rw-r--r--ptx/src/test/spirv_run/mod.rs3
-rw-r--r--ptx/src/translate.rs43
6 files changed, 158 insertions, 12 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index a2c6d66..ed58d42 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -207,6 +207,7 @@ pub enum Instruction<P: ArgParams> {
Not(NotType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>),
+ Cvta(CvtaDetails, Arg2<P>),
Shl(ShlType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
@@ -308,7 +309,7 @@ pub enum LdScope {
Sys,
}
-#[derive(Copy, Clone, PartialEq, Eq)]
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum LdStateSpace {
Generic,
Const,
@@ -511,6 +512,26 @@ impl CvtDetails {
}
}
+pub struct CvtaDetails {
+ pub to: CvtaStateSpace,
+ pub from: CvtaStateSpace,
+ pub size: CvtaSize,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum CvtaStateSpace {
+ Generic,
+ Const,
+ Global,
+ Local,
+ Shared,
+}
+
+pub enum CvtaSize {
+ U32,
+ U64,
+}
+
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum ShlType {
B16,
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 5f97e6c..66e831e 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -90,6 +90,7 @@ match {
".sreg",
".sys",
".target",
+ ".to",
".u16",
".u32",
".u64",
@@ -110,6 +111,7 @@ match {
"add",
"bra",
"cvt",
+ "cvta",
"debug",
"ld",
"map_f64_to_f32",
@@ -136,6 +138,7 @@ ExtendedID : &'input str = {
"add",
"bra",
"cvt",
+ "cvta",
"debug",
"ld",
"map_f64_to_f32",
@@ -322,6 +325,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstShl,
InstSt,
InstRet,
+ InstCvta,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -783,6 +787,38 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
};
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta
+InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "cvta" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ ast::Instruction::Cvta(ast::CvtaDetails {
+ to: to,
+ from: ast::CvtaStateSpace::Generic,
+ size: s
+ },
+ a)
+ },
+ "cvta" ".to" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ ast::Instruction::Cvta(ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Generic,
+ from: from,
+ size: s
+ },
+ a)
+ }
+}
+
+CvtaStateSpace: ast::CvtaStateSpace = {
+ ".const" => ast::CvtaStateSpace::Const,
+ ".global" => ast::CvtaStateSpace::Global,
+ ".local" => ast::CvtaStateSpace::Local,
+ ".shared" => ast::CvtaStateSpace::Shared,
+}
+
+CvtaSize: ast::CvtaSize = {
+ ".u32" => ast::CvtaSize::U32,
+ ".u64" => ast::CvtaSize::U64,
+}
+
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => {
diff --git a/ptx/src/test/spirv_run/cvta.ptx b/ptx/src/test/spirv_run/cvta.ptx
new file mode 100644
index 0000000..c24c959
--- /dev/null
+++ b/ptx/src/test/spirv_run/cvta.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry cvta(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr, in_addr;
+ cvta.to.global.u64 out_addr, out_addr;
+
+ ld.global.f32 temp, [in_addr];
+ st.global.f32 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt
new file mode 100644
index 0000000..1aa7425
--- /dev/null
+++ b/ptx/src/test/spirv_run/cvta.spvtxt
@@ -0,0 +1,42 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %1 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %5 "cvta"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %4 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
+ %5 = OpFunction %void None %4
+ %6 = OpFunctionParameter %ulong
+ %7 = OpFunctionParameter %ulong
+ %21 = OpLabel
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ %10 = OpVariable %_ptr_Function_float Function
+ OpStore %8 %6
+ OpStore %9 %7
+ %12 = OpLoad %ulong %8
+ %11 = OpCopyObject %ulong %12
+ OpStore %8 %11
+ %14 = OpLoad %ulong %9
+ %13 = OpCopyObject %ulong %14
+ OpStore %9 %13
+ %16 = OpLoad %ulong %8
+ %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16
+ %15 = OpLoad %float %19
+ OpStore %10 %15
+ %17 = OpLoad %ulong %9
+ %18 = OpLoad %float %10
+ %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17
+ OpStore %20 %18
+ OpReturn
+ OpFunctionEnd
+ \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index e1e5c32..c159280 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -48,7 +48,8 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
test_ptx!(bra, [10u64], [11u64]);
test_ptx!(not, [0u64], [u64::max_value()]);
test_ptx!(shl, [11u64], [44u64]);
-test_ptx!(cvt_sat_s_u, [0i32], [0i32]);
+test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
+test_ptx!(cvta, [3.0f32], [3.0f32]);
struct DisplayError<T: Display + Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 511ef72..ebce1dd 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -671,7 +671,7 @@ fn emit_function_body_ops(
}
let result_type = map.get_or_add_scalar(builder, data.typ);
match data.state_space {
- ast::LdStateSpace::Generic => {
+ ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param => {
@@ -683,7 +683,8 @@ fn emit_function_body_ops(
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
- || data.state_space != ast::StStateSpace::Generic
+ || (data.state_space != ast::StStateSpace::Generic
+ && data.state_space != ast::StStateSpace::Global)
{
todo!()
}
@@ -729,6 +730,13 @@ fn emit_function_body_ops(
ast::Instruction::Cvt(dets, arg) => {
emit_cvt(builder, map, opencl, dets, arg)?;
}
+ ast::Instruction::Cvta(_, arg) => {
+ // This would be only meaningful if const/slm/global pointers
+ // had a different format than generic pointers, but they don't pretty much by ptx definition
+ // Honestly, I have no idea why this instruction exists and is emitted by the compiler
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ }
ast::Instruction::SetpBool(_, _) => todo!(),
},
Statement::LoadVar(arg, typ) => {
@@ -997,13 +1005,10 @@ fn emit_implicit_conversion(
_ => todo!(),
};
match cv.kind {
- ConversionKind::Ptr => {
+ ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- SpirvScalarKey::from(to_type),
- spirv_headers::StorageClass::Generic,
- ),
+ SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@@ -1365,6 +1370,10 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
+ ast::Instruction::Cvta(d, a) => {
+ let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
+ ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type)))
+ }
}
}
}
@@ -1443,6 +1452,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
+ | ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_) => None,
@@ -1498,7 +1508,7 @@ enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- Ptr,
+ Ptr(ast::LdStateSpace),
}
impl ImplicitConversion {
@@ -1944,6 +1954,19 @@ impl ast::IntType {
}
}
+impl ast::LdStateSpace {
+ fn to_spirv(self) -> spirv::StorageClass {
+ match self {
+ ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
+ ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::LdStateSpace::Local => spirv::StorageClass::Function,
+ ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::LdStateSpace::Param => unreachable!(),
+ }
+ }
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -1980,7 +2003,7 @@ fn insert_implicit_conversions_ld_src(
src,
should_convert_ld_param_src,
),
- ast::LdStateSpace::Generic => {
+ ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
ScalarKind::Byte,
@@ -1998,7 +2021,7 @@ fn insert_implicit_conversions_ld_src(
new_src,
new_src_type,
instr_type,
- ConversionKind::Ptr,
+ ConversionKind::Ptr(state_space),
)
}
_ => todo!(),