diff options
author | Andrzej Janik <[email protected]> | 2020-09-27 13:14:19 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-09-27 13:14:19 +0200 |
commit | e0190fcbe19e9554ccc2fb0d72685569823224ef (patch) | |
tree | c396a59b3080c0bdfbf308742e4f53caf48b5030 /ptx | |
parent | 42bcd999eb2caec0046aa76d12ec7e73919495fc (diff) | |
download | ZLUDA-e0190fcbe19e9554ccc2fb0d72685569823224ef.tar.gz ZLUDA-e0190fcbe19e9554ccc2fb0d72685569823224ef.zip |
Add missing support for Milestone 1
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/ast.rs | 18 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 18 | ||||
-rw-r--r-- | ptx/src/test/mod.rs | 7 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mad_s32.ptx | 28 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mad_s32.spvtxt | 77 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 10 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mul_wide.ptx | 24 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mul_wide.spvtxt | 64 | ||||
-rw-r--r-- | ptx/src/test/vectorAdd_11.ptx | 55 | ||||
-rw-r--r-- | ptx/src/translate.rs | 210 |
10 files changed, 488 insertions, 23 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 77afee6..acefdc1 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -320,7 +320,7 @@ pub enum Instruction<P: ArgParams> { MovVector(MovVectorDetails, Arg2Vec<P>), Mul(MulDetails, Arg3<P>), Add(AddDetails, Arg3<P>), - Setp(SetpData, Arg4<P>), + Setp(SetpData, Arg4Setp<P>), SetpBool(SetpBoolData, Arg5<P>), Not(NotType, Arg2<P>), Bra(BraData, Arg1<P>), @@ -331,9 +331,13 @@ pub enum Instruction<P: ArgParams> { Ret(RetData), Call(CallInst<P>), Abs(AbsDetails, Arg2<P>), + Mad(MulDetails, Arg4<P>), } #[derive(Copy, Clone)] +pub struct MadFloatDesc {} + +#[derive(Copy, Clone)] pub struct MovVectorDetails { pub typ: MovVectorType, pub length: u8, @@ -398,6 +402,13 @@ pub struct Arg3<P: ArgParams> { } pub struct Arg4<P: ArgParams> { + pub dst: P::ID, + pub src1: P::Operand, + pub src2: P::Operand, + pub src3: P::Operand, +} + +pub struct Arg4Setp<P: ArgParams> { pub dst1: P::ID, pub dst2: Option<P::ID>, pub src1: P::Operand, @@ -503,7 +514,7 @@ sub_scalar_type!(MovVectorType { pub struct MovDetails { pub typ: MovType, - pub src_is_address: bool + pub src_is_address: bool, } sub_type! 
{ @@ -518,17 +529,20 @@ pub enum MulDetails { Float(MulFloatDesc), } +#[derive(Copy, Clone)] pub struct MulIntDesc { pub typ: IntType, pub control: MulIntControl, } +#[derive(Copy, Clone)] pub enum MulIntControl { Low, High, Wide, } +#[derive(Copy, Clone)] pub struct MulFloatDesc { pub typ: FloatType, pub rounding: Option<RoundingMode>, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 208e076..50a6aeb 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -122,6 +122,7 @@ match { "cvta", "debug", "ld", + "mad", "map_f64_to_f32", "mov", "mul", @@ -149,6 +150,7 @@ ExtendedID : &'input str = { "cvta", "debug", "ld", + "mad", "map_f64_to_f32", "mov", "mul", @@ -442,6 +444,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstCvta, InstCall, InstAbs, + InstMad }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -649,7 +652,7 @@ InstAddMode: ast::AddDetails = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp // TODO: support f16 setp InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = { - "setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a), + "setp" <d:SetpMode> <a:Arg4Setp> => ast::Instruction::Setp(d, a), "setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a), }; @@ -995,6 +998,13 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = { }, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad +InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = { + "mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a), + "mad" ".hi" ".sat" ".s32" => todo!() +}; + SignedIntType: ast::ScalarType = { ".s16" => ast::ScalarType::S16, ".s32" => ast::ScalarType::S32, @@ -1056,7 +1066,11 @@ Arg3: 
ast::Arg3<ast::ParsedArgParams<'input>> = { }; Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = { - <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>} + <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} +}; + +Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = { + <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4Setp{<>} }; // TODO: pass src3 negation somewhere diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index d251884..0339141 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -40,3 +40,10 @@ fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), TranslateError> { let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx"); compile_and_assert(vector_add) } + +#[test] +#[allow(non_snake_case)] +fn vectorAdd_11_ptx() -> Result<(), TranslateError> { + let vector_add = include_str!("vectorAdd_11.ptx"); + compile_and_assert(vector_add) +} diff --git a/ptx/src/test/spirv_run/mad_s32.ptx b/ptx/src/test/spirv_run/mad_s32.ptx new file mode 100644 index 0000000..a864266 --- /dev/null +++ b/ptx/src/test/spirv_run/mad_s32.ptx @@ -0,0 +1,28 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mad_s32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 dst; + .reg .s32 src1; + .reg .s32 src2; + .reg .s32 src3; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 src1, [in_addr]; + ld.s32 src2, [in_addr+4]; + ld.s32 src3, [in_addr+8]; + mad.lo.s32 dst, src1, src2, src3; + st.s32 [out_addr], dst; + st.s32 [out_addr+4], dst; + st.s32 [out_addr+8], dst; + ret; +} diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt new file mode 100644 index 0000000..3a7153d --- /dev/null +++ b/ptx/src/test/spirv_run/mad_s32.spvtxt @@ -0,0 +1,77 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + 
OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + OpCapability Float64 + %48 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mad_s32" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %51 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %ulong_8 = OpConstant %ulong 8 + %ulong_4_0 = OpConstant %ulong 4 + %ulong_8_0 = OpConstant %ulong 8 + %1 = OpFunction %void None %51 + %10 = OpFunctionParameter %ulong + %11 = OpFunctionParameter %ulong + %46 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + %9 = OpVariable %_ptr_Function_uint Function + OpStore %2 %10 + OpStore %3 %11 + %13 = OpLoad %ulong %2 + %12 = OpCopyObject %ulong %13 + OpStore %4 %12 + %15 = OpLoad %ulong %3 + %14 = OpCopyObject %ulong %15 + OpStore %5 %14 + %17 = OpLoad %ulong %4 + %40 = OpConvertUToPtr %_ptr_Generic_uint %17 + %16 = OpLoad %uint %40 + OpStore %7 %16 + %19 = OpLoad %ulong %4 + %33 = OpIAdd %ulong %19 %ulong_4 + %41 = OpConvertUToPtr %_ptr_Generic_uint %33 + %18 = OpLoad %uint %41 + OpStore %8 %18 + %21 = OpLoad %ulong %4 + %35 = OpIAdd %ulong %21 %ulong_8 + %42 = OpConvertUToPtr %_ptr_Generic_uint %35 + %20 = OpLoad %uint %42 + OpStore %9 %20 + %23 = OpLoad %uint %7 + %24 = OpLoad %uint %8 + %25 = OpLoad %uint %9 + %56 = OpIMul %uint %23 %24 + %22 = OpIAdd %uint %25 %56 + OpStore %6 %22 + %26 = OpLoad %ulong %5 + %27 = OpLoad %uint %6 + %43 = OpConvertUToPtr %_ptr_Generic_uint %26 + OpStore %43 %27 + %28 = OpLoad %ulong %5 + %29 = OpLoad %uint %6 
+ %37 = OpIAdd %ulong %28 %ulong_4_0 + %44 = OpConvertUToPtr %_ptr_Generic_uint %37 + OpStore %44 %29 + %30 = OpLoad %ulong %5 + %31 = OpLoad %uint %6 + %39 = OpIAdd %ulong %30 %ulong_8_0 + %45 = OpConvertUToPtr %_ptr_Generic_uint %39 + OpStore %45 %31 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 78c3375..27dc063 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -8,7 +8,6 @@ use spirv_headers::Word; use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
};
-use std::{collections::hash_map::Entry, cmp};
use std::error;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
@@ -17,6 +16,7 @@ use std::hash::Hash; use std::mem;
use std::slice;
use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str};
+use std::{cmp, collections::hash_map::Entry};
macro_rules! test_ptx {
($fn_name:ident, $input:expr, $output:expr) => {
@@ -65,6 +65,8 @@ test_ptx!(mov_address, [0xDEADu64], [0u64]); test_ptx!(b64tof64, [111u64], [111u64]);
test_ptx!(implicit_param, [34u32], [34u32]);
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
+test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
+test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
struct DisplayError<T: Debug> {
err: T,
@@ -93,7 +95,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>( let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
assert!(errors.len() == 0);
- let spirv = translate::to_spirv(ast)?;
+ let (spirv, _) = translate::to_spirv(ast)?;
let name = CString::new(name)?;
let result =
run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?;
@@ -127,7 +129,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?;
- let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
+ let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?;
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
@@ -157,7 +159,7 @@ fn test_spvtxt_assert<'a>( let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
- let ptx_mod = translate::to_spirv_module(ast)?;
+ let (ptx_mod, _) = translate::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
diff --git a/ptx/src/test/spirv_run/mul_wide.ptx b/ptx/src/test/spirv_run/mul_wide.ptx new file mode 100644 index 0000000..2d6f8a5 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_wide.ptx @@ -0,0 +1,24 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry mul_wide(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 inp1;
+ .reg .s32 inp2;
+ .reg .s64 result;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.s32 inp1, [in_addr];
+ ld.global.s32 inp2, [in_addr+4];
+ mul.wide.s32 result, inp1, inp2;
+ st.u64 [out_addr], result;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt new file mode 100644 index 0000000..274612c --- /dev/null +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -0,0 +1,64 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + OpCapability Float64 + %32 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_wide" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %35 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %ulong_4 = OpConstant %ulong 4 + %_struct_40 = OpTypeStruct %uint %uint + %v2uint = OpTypeVector %uint 2 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %35 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %30 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %9 + OpStore %3 %10 + %12 = OpLoad %ulong %2 + %11 = OpCopyObject %ulong %12 + OpStore %4 %11 + %14 = OpLoad %ulong %3 + %13 = OpCopyObject %ulong %14 + OpStore %5 %13 + %16 = OpLoad %ulong %4 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %15 = OpLoad %uint %26 + OpStore %6 %15 + %18 = OpLoad %ulong %4 + %25 = OpIAdd %ulong %18 %ulong_4 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %25 + %17 = OpLoad %uint %27 + OpStore %7 %17 + %20 = OpLoad %uint %6 + %21 = OpLoad %uint %7 + %41 = OpSMulExtended %_struct_40 %20 %21 + %42 = OpCompositeExtract %uint %41 0 + %43 = OpCompositeExtract 
%uint %41 1 + %45 = OpCompositeConstruct %v2uint %42 %43 + %19 = OpBitcast %ulong %45 + OpStore %8 %19 + %22 = OpLoad %ulong %5 + %23 = OpLoad %ulong %8 + %28 = OpCopyObject %ulong %23 + %29 = OpConvertUToPtr %_ptr_Generic_ulong %22 + OpStore %29 %28 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/vectorAdd_11.ptx b/ptx/src/test/vectorAdd_11.ptx new file mode 100644 index 0000000..ba0381e --- /dev/null +++ b/ptx/src/test/vectorAdd_11.ptx @@ -0,0 +1,55 @@ + + + + + + + + +.version 7.0 +.target sm_80 +.address_size 64 + + + +.visible .entry _Z9vectorAddPKfS0_Pfi( +.param .u64 _Z9vectorAddPKfS0_Pfi_param_0, +.param .u64 _Z9vectorAddPKfS0_Pfi_param_1, +.param .u64 _Z9vectorAddPKfS0_Pfi_param_2, +.param .u32 _Z9vectorAddPKfS0_Pfi_param_3 +) +{ +.reg .pred %p<2>; +.reg .f32 %f<4>; +.reg .b32 %r<6>; +.reg .b64 %rd<11>; + + +ld.param.u64 %rd1, [_Z9vectorAddPKfS0_Pfi_param_0]; +ld.param.u64 %rd2, [_Z9vectorAddPKfS0_Pfi_param_1]; +ld.param.u64 %rd3, [_Z9vectorAddPKfS0_Pfi_param_2]; +ld.param.u32 %r2, [_Z9vectorAddPKfS0_Pfi_param_3]; +mov.u32 %r3, %ntid.x; +mov.u32 %r4, %ctaid.x; +mov.u32 %r5, %tid.x; +mad.lo.s32 %r1, %r4, %r3, %r5; +setp.ge.s32 %p1, %r1, %r2; +@%p1 bra BB0_2; + +cvta.to.global.u64 %rd4, %rd1; +mul.wide.s32 %rd5, %r1, 4; +add.s64 %rd6, %rd4, %rd5; +cvta.to.global.u64 %rd7, %rd2; +add.s64 %rd8, %rd7, %rd5; +ld.global.f32 %f1, [%rd8]; +ld.global.f32 %f2, [%rd6]; +add.f32 %f3, %f2, %f1; +cvta.to.global.u64 %rd9, %rd3; +add.s64 %rd10, %rd9, %rd5; +st.global.f32 [%rd10], %f3; + +BB0_2: +ret; +} + + diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 5b03f0b..a1d4b6a 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -28,6 +28,7 @@ enum SpirvType { Array(SpirvScalarKey, u32),
Pointer(Box<SpirvType>, spirv::StorageClass),
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
+ Struct(Vec<SpirvScalarKey>),
}
impl SpirvType {
@@ -174,6 +175,16 @@ impl TypeWordMap { .entry(t)
.or_insert_with(|| b.type_function(out_t, in_t))
}
+ SpirvType::Struct(ref underlying) => {
+ let underlying_ids = underlying
+ .iter()
+ .map(|t| self.get_or_add_spirv_scalar(b, *t))
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| b.type_struct(underlying_ids))
+ }
}
}
@@ -201,7 +212,9 @@ impl TypeWordMap { }
}
-pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, TranslateError> {
+pub fn to_spirv_module<'a>(
+ ast: ast::Module<'a>,
+) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
let ssa_functions = ast
.functions
@@ -218,17 +231,24 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
+ let mut args_len = HashMap::new();
for f in ssa_functions {
let f_body = match f.body {
Some(f) => f,
None => continue,
};
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
- emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
+ emit_function_header(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ f.func_directive,
+ &mut args_len,
+ )?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?;
}
- Ok(builder.module())
+ Ok((builder.module(), args_len))
}
fn emit_builtins(
@@ -263,7 +283,12 @@ fn emit_function_header<'a>( map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>,
func_directive: ast::MethodDecl<ExpandedArgParams>,
+ all_args_lens: &mut HashMap<String, Vec<usize>>,
) -> Result<(), TranslateError> {
+ if let ast::MethodDecl::Kernel(name, args) = &func_directive {
+ let args_lens = args.iter().map(|param| param.v_type.width()).collect();
+ all_args_lens.insert(name.to_string(), args_lens);
+ }
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
@@ -297,9 +322,11 @@ fn emit_function_header<'a>( Ok(())
}
-pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, TranslateError> {
- let module = to_spirv_module(ast)?;
- Ok(module.assemble())
+pub fn to_spirv<'a>(
+ ast: ast::Module<'a>,
+) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
+ let (module, all_args_lens) = to_spirv_module(ast)?;
+ Ok((module.assemble(), all_args_lens))
}
fn emit_capabilities(builder: &mut dr::Builder) {
@@ -905,7 +932,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> ArgumentSemantics::PhysicalPointer => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(typ);
+ let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
@@ -1314,8 +1341,8 @@ fn emit_function_body_ops( let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
let const_true = builder.constant_true(type_pred);
let const_false = builder.constant_false(type_pred);
- builder.select(result_type, result_id, operand, const_false, const_true)
- },
+ builder.select(result_type, result_id, operand, const_false, const_true)
+ }
_ => builder.not(result_type, result_id, operand),
}?;
}
@@ -1359,6 +1386,12 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(*dst), *src)?;
}
},
+ ast::Instruction::Mad(mad, arg) => match mad {
+ ast::MulDetails::Int(ref desc) => {
+ emit_mad_int(builder, map, opencl, desc, arg)?
+ }
+ ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
+ },
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1385,6 +1418,47 @@ fn emit_function_body_ops( Ok(())
}
+/// Emits the SPIR-V for an integer `mad` instruction, storing the result in `arg.dst`.
+///
+/// The lowering depends on `desc.control`:
+/// * `MulIntControl::Low`  - an integer multiply of `arg.src1` and `arg.src2`,
+///   followed by an integer add of `arg.src3` to that product.
+/// * `MulIntControl::High` - one extended instruction from the set identified by
+///   `opencl`: `s_mad_hi` when `desc.typ` is signed, `u_mad_hi` otherwise.
+/// * `MulIntControl::Wide` - not implemented; reaching it panics through `todo!()`.
+///
+/// Any error returned by the builder is propagated to the caller.
+fn emit_mad_int(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::MulIntDesc,
+    arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+    // The result type is looked up (or registered) before dispatching on the mode.
+    let scalar_type = ast::ScalarType::from(desc.typ);
+    let result_type = map.get_or_add(builder, SpirvType::from(scalar_type));
+    match desc.control {
+        ast::MulIntControl::Wide => todo!(),
+        ast::MulIntControl::Low => {
+            let product = builder.i_mul(result_type, None, arg.src1, arg.src2)?;
+            builder.i_add(result_type, Some(arg.dst), arg.src3, product)?;
+        }
+        ast::MulIntControl::High => {
+            let signed = desc.typ.is_signed();
+            let mad_hi = if signed {
+                spirv::CLOp::s_mad_hi
+            } else {
+                spirv::CLOp::u_mad_hi
+            };
+            let operands = [arg.src1, arg.src2, arg.src3];
+            builder.ext_inst(
+                result_type,
+                Some(arg.dst),
+                opencl,
+                mad_hi as spirv::Word,
+                operands,
+            )?;
+        }
+    }
+    Ok(())
+}
+/// Placeholder for emitting a floating-point `mad`; the body is still `todo!()`, so calling it panics.
+fn emit_mad_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::MulFloatDesc,
+ arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ todo!() // not implemented yet; none of the parameters are read
+}
+
fn emit_add_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1529,7 +1603,7 @@ fn emit_setp( builder: &mut dr::Builder,
map: &mut TypeWordMap,
setp: &ast::SetpData,
- arg: &ast::Arg4<ExpandedArgParams>,
+ arg: &ast::Arg4Setp<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if setp.flush_to_zero {
todo!()
@@ -1607,6 +1681,7 @@ fn emit_mul_int( desc: &ast::MulIntDesc,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
+ let instruction_type = ast::ScalarType::from(desc.typ);
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
match desc.control {
ast::MulIntControl::Low => {
@@ -1626,11 +1701,53 @@ fn emit_mul_int( [arg.src1, arg.src2],
)?;
}
- ast::MulIntControl::Wide => todo!(),
+ ast::MulIntControl::Wide => {
+ let mul_ext_type = SpirvType::Struct(vec![
+ SpirvScalarKey::from(instruction_type),
+ SpirvScalarKey::from(instruction_type),
+ ]);
+ let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
+ let mul = if desc.typ.is_signed() {
+ builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
+ } else {
+ builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
+ };
+ let instr_width = instruction_type.width();
+ let instr_kind = instruction_type.kind();
+ let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
+ let dst_type_id = map.get_or_add_scalar(builder, dst_type);
+ struct2_bitcast_to_wide(
+ builder,
+ map,
+ SpirvScalarKey::from(instruction_type),
+ inst_type,
+ arg.dst,
+ dst_type_id,
+ mul,
+ )?;
+ }
}
Ok(())
}
+/// Reinterprets a two-member struct value as a single wider value.
+///
+/// SPIR-V structs can't be bitcast directly, so the conversion is routed through
+/// a two-component vector: members 0 and 1 of `src` are extracted (each with type
+/// id `instruction_type`), assembled into a vector whose element type is described
+/// by `base_type_key`, and that vector is bitcast to `dst_type_id`, producing `dst`.
+///
+/// Any error returned by the builder is propagated to the caller.
+fn struct2_bitcast_to_wide(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    base_type_key: SpirvScalarKey,
+    instruction_type: spirv::Word,
+    dst: spirv::Word,
+    dst_type_id: spirv::Word,
+    src: spirv::Word,
+) -> Result<(), dr::Error> {
+    // Both extractions happen before the vector type is requested, so ids are
+    // handed out in the same order as the statements below.
+    let extract_member = |b: &mut dr::Builder, index: u32| {
+        b.composite_extract(instruction_type, None, src, [index])
+    };
+    let member0 = extract_member(builder, 0)?;
+    let member1 = extract_member(builder, 1)?;
+    let pair_type = SpirvType::Vector(base_type_key, 2);
+    let pair_type_id = map.get_or_add(builder, pair_type);
+    let packed = builder.composite_construct(pair_type_id, None, [member0, member1])?;
+    builder
+        .bitcast(dst_type_id, Some(dst), packed)
+        .map(|_| ())
+}
+
fn emit_abs(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1844,8 +1961,8 @@ impl PtxSpecialRegister { fn get_builtin(self) -> spirv::BuiltIn {
match self {
- PtxSpecialRegister::Tid => spirv::BuiltIn::GlobalInvocationId,
- PtxSpecialRegister::Ntid => spirv::BuiltIn::GlobalSize,
+ PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId,
+ PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize,
PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId,
PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups,
}
@@ -2492,6 +2609,10 @@ impl<T: ArgParamsEx> ast::Instruction<T> { let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
ast::Instruction::Cvta(d, a.map(visitor, false, inst_type)?)
}
+ ast::Instruction::Mad(d, a) => {
+ let inst_type = d.get_type();
+ ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
+ }
})
}
}
@@ -2641,7 +2762,8 @@ impl ast::Instruction<ExpandedArgParams> { | ast::Instruction::St(_, _)
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
- | ast::Instruction::Call(_) => None,
+ | ast::Instruction::Call(_)
+ | ast::Instruction::Mad(_, _) => None,
}
}
}
@@ -2741,6 +2863,17 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> { }
}
+impl ast::VariableParamType {
+    /// Returns the total width of this parameter type.
+    ///
+    /// A scalar yields the `width()` of its `ast::ScalarType`; an array yields
+    /// that element width multiplied by the array length.
+    // NOTE(review): the unit is whatever `ast::ScalarType::width` reports
+    // (presumably bytes) - confirm against its definition.
+    fn width(self) -> usize {
+        let (element, count) = match self {
+            ast::VariableParamType::Scalar(t) => (ast::ScalarType::from(t), 1usize),
+            ast::VariableParamType::Array(t, len) => (ast::ScalarType::from(t), len as usize),
+        };
+        (element.width() as usize) * count
+    }
+}
+
impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
@@ -3042,6 +3175,53 @@ impl<T: ArgParamsEx> ast::Arg4<T> { visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg4Setp<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::Type,
+ ) -> Result<ast::Arg4Setp<U>, TranslateError> {
let dst1 = visitor.variable(
ArgumentDescriptor {
op: self.dst1,
@@ -3079,7 +3259,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> { },
t,
)?;
- Ok(ast::Arg4 {
+ Ok(ast::Arg4Setp {
dst1,
dst2,
src1,
|