aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-14 22:41:46 +0200
committerAndrzej Janik <[email protected]>2021-09-14 22:41:46 +0200
commit2cd0fcb65066cd4cddef66900593dc883743bc68 (patch)
tree80223e5b421f5c79dd15da8510d6a80b54d0ba17 /ptx
parent986fa49097ef31fcd5eedcc05a624eb57d582ba4 (diff)
downloadZLUDA-2cd0fcb65066cd4cddef66900593dc883743bc68.tar.gz
ZLUDA-2cd0fcb65066cd4cddef66900593dc883743bc68.zip
Parse and test const buffers
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ptx.lalrpop35
-rw-r--r--ptx/src/test/spirv_run/const.ptx31
-rw-r--r--ptx/src/test/spirv_run/const.spvtxt47
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
4 files changed, 96 insertions, 18 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index b20a30a..abefdf8 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -607,41 +607,33 @@ SharedVariable: ast::Variable<&'input str> = {
}
ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = {
- <linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => {
+ <linking:LinkingDirectives> <state_space:VariableStateSpace> <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
- let state_space = ast::StateSpace::Global;
(linking, ast::Variable { align, v_type, state_space, name, array_init })
},
- <linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => {
- let (align, v_type, name, array_init) = def;
- let state_space = ast::StateSpace::Shared;
- (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() })
- },
- <linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ <linking:LinkingDirectives> <space:VariableStateSpace> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- if space == ".global" {
- (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
- } else {
- (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
- }
+ (ast::Type::Array(t, dimensions), space, init)
}
ast::ArrayOrPointer::Pointer => {
if !linking.contains(ast::LinkingDirective::EXTERN) {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
- if space == ".global" {
- (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new())
- } else {
- (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new())
- }
+ (ast::Type::Array(t, Vec::new()), space, Vec::new())
}
};
Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init }))
}
}
+VariableStateSpace: ast::StateSpace = {
+ ".const" => ast::StateSpace::Const,
+ ".global" => ast::StateSpace::Global,
+ ".shared" => ast::StateSpace::Shared,
+};
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
".param" <var:VariableScalar<LdStScalarType>> => {
@@ -2096,4 +2088,11 @@ CommaNonEmpty<T>: Vec<T> = {
Or<T1, T2>: T1 = {
T1,
T2
+}
+
+#[inline]
+Or3<T1, T2, T3>: T1 = {
+ T1,
+ T2,
+ T3
} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/const.ptx b/ptx/src/test/spirv_run/const.ptx
new file mode 100644
index 0000000..c22ac2b
--- /dev/null
+++ b/ptx/src/test/spirv_run/const.ptx
@@ -0,0 +1,31 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.const .align 8 .b16 constparams[4] = { 10, 20, 30, 40 };
+
+.visible .entry const(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b16 temp1;
+ .reg .b16 temp2;
+ .reg .b16 temp3;
+ .reg .b16 temp4;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.const.b16 temp1, constparams[0];
+ ld.const.b16 temp2, constparams[1];
+ ld.const.b16 temp3, constparams[2];
+ ld.const.b16 temp4, constparams[3];
+ st.u16 [out_addr], temp1;
+ st.u16 [out_addr+2], temp2;
+ st.u16 [out_addr+4], temp3;
+ st.u16 [out_addr+6], temp4;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/const.spvtxt b/ptx/src/test/spirv_run/const.spvtxt
new file mode 100644
index 0000000..9a7f254
--- /dev/null
+++ b/ptx/src/test/spirv_run/const.spvtxt
@@ -0,0 +1,47 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %21 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "clz"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %24 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %24
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %19 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %9 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %9
+ %10 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_Generic_uint %12
+ %11 = OpLoad %uint %17 Aligned 4
+ OpStore %6 %11
+ %14 = OpLoad %uint %6
+ %13 = OpExtInst %uint %21 clz %14
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %uint %6
+ %18 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %18 %16 Aligned 4
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 51f1930..13cf0f1 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -203,6 +203,7 @@ test_ptx!(
);
test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
+test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
struct DisplayError<T: Debug> {
err: T,