diff options
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/ntid.spvtxt | 40 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector4.ptx | 22 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector4.spvtxt | 99 | ||||
-rw-r--r-- | ptx/src/translate.rs | 162 |
5 files changed, 243 insertions, 81 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 226043f..d5bc8dd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
+test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]);
diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index 7b5a630..e5f343c 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -7,24 +7,27 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" + %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize - OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + OpEntryPoint Kernel %1 "ntid" + OpExecutionMode %1 ContractionOff + OpDecorate %24 LinkageAttributes "get_local_size" Import %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %v3ulong = OpTypeVector %ulong 3 -%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong -%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 + %35 = OpTypeFunction %ulong %uint + %36 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %33 + %uint_0 = OpConstant %uint 0 + %24 = OpFunction %ulong None %35 + %26 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %36 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong - %26 = OpLabel + %29 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -38,13 +41,12 @@ %12 = OpLoad %ulong %3 Aligned 8 OpStore %5 %12 %14 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %24 Aligned 4 + %27 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %27 Aligned 4 OpStore %6 %13 - %38 = OpLoad %v3ulong %gl_WorkGroupSize - %23 = OpCompositeExtract %ulong %38 0 - %39 = OpBitcast %ulong %23 - %16 = OpUConvert %uint %39 + %23 = OpFunctionCall %ulong %24 %uint_0 + %40 = OpBitcast %ulong %23 + %16 = OpUConvert %uint %40 %15 = OpCopyObject %uint %16 OpStore %7 %15 %18 = OpLoad %uint %6 @@ -53,7 +55,7 @@ OpStore %6 %17 %20 = OpLoad %ulong %5 %21 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %20 - OpStore %25 %21 Aligned 4 + %28 = OpConvertUToPtr %_ptr_Generic_uint %20 + OpStore %28 %21 Aligned 4 OpReturn - OpFunctionEnd + OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/vector4.ptx b/ptx/src/test/spirv_run/vector4.ptx new file mode 100644 index 0000000..d010b70 --- /dev/null +++ b/ptx/src/test/spirv_run/vector4.ptx @@ -0,0 +1,22 @@ +.version 6.5
+.target sm_60
+.address_size 64
+
+.visible .entry vector4(
+ .param .u64 input_p,
+ .param .u64 output_p
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .v4 .u32 temp;
+ .reg .u32 temp_scalar;
+
+ ld.param.u64 in_addr, [input_p];
+ ld.param.u64 out_addr, [output_p];
+
+ ld.v4.u32 temp, [in_addr];
+ mov.b32 temp_scalar, temp.w;
+ st.u32 [out_addr], temp_scalar;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/vector4.spvtxt b/ptx/src/test/spirv_run/vector4.spvtxt new file mode 100644 index 0000000..8253bf9 --- /dev/null +++ b/ptx/src/test/spirv_run/vector4.spvtxt @@ -0,0 +1,99 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %51 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %25 "vector" + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %55 = OpTypeFunction %v2uint %v2uint +%_ptr_Function_v2uint = OpTypePointer Function %v2uint +%_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %ulong = OpTypeInt 64 0 + %67 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint + %1 = OpFunction %v2uint None %55 + %7 = OpFunctionParameter %v2uint + %24 = OpLabel + %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function + %4 = OpVariable %_ptr_Function_v2uint Function + %5 = OpVariable %_ptr_Function_uint Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %3 %7 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 + OpStore %5 %8 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 + OpStore %6 %10 + %13 = OpLoad %uint %5 + %14 = OpLoad %uint %6 + %12 = OpIAdd %uint %13 %14 + OpStore %6 %12 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %65 %19 + %22 = OpLoad %v2uint %4 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 + OpFunctionEnd + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function + %33 = OpVariable %_ptr_Function_ulong Function + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 Aligned 8 + OpStore %28 %36 + %37 = OpLoad %ulong %27 Aligned 8 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 Aligned 8 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 5fea075..6c2c594 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
- emit_builtins(&mut builder, &mut map, &id_defs);
+ //emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives(
@@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>( &mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements =
+ fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>( }
fn fix_special_registers(
+ ptx_impl_imports: &mut HashMap<String, Directive>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1276,7 +1278,6 @@ fn fix_special_registers( for s in typed_statements {
match s {
Statement::LoadVar(
- mut
details
@
LoadVarDetails {
@@ -1285,48 +1286,53 @@ fn fix_special_registers( },
) => {
let index = details.member_index.unwrap().0;
- if index == 3 {
- result.push(Statement::Constant(ConstantDefinition {
- dst: details.arg.dst,
- typ: ast::ScalarType::U32,
- value: ast::ImmediateValue::U64(0),
- }));
- } else {
- let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
- {
- Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
- None => None,
- };
- let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
- Some(sreg_and_type) => sreg_and_type,
- None => {
- result.push(Statement::LoadVar(details));
- continue;
- }
- };
- let temp_id = numeric_id_defs
- .register_intermediate(Some((details.typ.clone(), details.state_space)));
- let real_dst = details.arg.dst;
- details.arg.dst = temp_id;
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- src: sreg_src,
- dst: temp_id,
- },
- state_space: ast::StateSpace::Sreg,
- typ: ast::Type::Scalar(scalar_typ),
- member_index: Some((index, Some(vector_width))),
- }));
- result.push(Statement::Conversion(ImplicitConversion {
- src: temp_id,
- dst: real_dst,
- from_type: ast::Type::Scalar(scalar_typ),
- from_space: ast::StateSpace::Sreg,
- to_type: ast::Type::Scalar(ast::ScalarType::U32),
- to_space: ast::StateSpace::Sreg,
- kind: ConversionKind::Default,
- }));
- }
+ let sreg = numeric_id_defs
+ .special_registers
+ .get(details.arg.src)
+ .ok_or_else(|| error_unreachable())?;
+ let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
+ let index_constant = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::Constant(ConstantDefinition {
+ dst: index_constant,
+ typ: ast::ScalarType::U32,
+ value: ast::ImmediateValue::U64(index as u64),
+ }));
+ let fn_result = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ocl_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments =
+ vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
+ let input_arguments = vec![(
+ TypedOperand::Reg(index_constant),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = register_external_fn_call(
+ numeric_id_defs,
+ ptx_impl_imports,
+ ocl_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ result.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ return_arguments,
+ name: fn_call,
+ input_arguments,
+ }));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: fn_result,
+ dst: details.arg.dst,
+ from_type: ast::Type::Scalar(ocl_type),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Reg,
+ kind: ConversionKind::Default,
+ }));
}
s => result.push(s),
}
@@ -1721,8 +1727,8 @@ fn instruction_to_fn_call( id_defs,
ptx_impl_imports,
fn_name,
- return_arguments,
- input_arguments,
+ return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
)?;
Ok(Statement::Call(ResolvedCall {
uniform: false,
@@ -1732,12 +1738,12 @@ fn instruction_to_fn_call( }))
}
-fn register_external_fn_call(
+fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
- return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
- input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
@@ -1770,19 +1776,18 @@ fn register_external_fn_call( }
}
-fn fn_arguments_to_variables(
+fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
- args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<spirv::Word>> {
- args.iter()
- .map(|(_, typ, space)| ast::Variable {
- align: None,
- v_type: typ.clone(),
- state_space: *space,
- name: id_defs.register_intermediate(None),
- array_init: Vec::new(),
- })
- .collect::<Vec<_>>()
+ args.map(|(typ, space)| ast::Variable {
+ align: None,
+ v_type: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
}
fn arguments_to_resolved_arguments(
@@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>( Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Constant(_) => return Err(error_unreachable()),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
}
}
Ok(result)
@@ -4686,6 +4691,19 @@ impl PtxSpecialRegister { }
}
+ fn get_scalar_type(self) -> ast::ScalarType {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Tid64
+ | PtxSpecialRegister::Ntid64
+ | PtxSpecialRegister::Ctaid64
+ | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
+ }
+ }
+
fn get_builtin(self) -> spirv::BuiltIn {
match self {
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
@@ -4701,6 +4719,23 @@ impl PtxSpecialRegister { }
}
+ fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
+ match self {
+ PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
+ ("_Z12get_local_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
+ ("_Z14get_local_sizej", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
+ ("_Z12get_group_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
+ ("_Z14get_num_groupsj", ast::ScalarType::U64)
+ }
+ }
+ }
+
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
match self {
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
@@ -4743,6 +4778,8 @@ impl SpecialRegistersMap { }
fn interface(&self) -> Vec<spirv::Word> {
+ return Vec::new();
+ /*
self.reg_to_id
.iter()
.filter_map(|(sreg, id)| {
@@ -4753,6 +4790,7 @@ impl SpecialRegistersMap { }
})
.collect::<Vec<_>>()
+ */
}
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
|