diff options
author | Andrzej Janik <[email protected]> | 2020-10-26 01:49:25 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-10-26 01:49:25 +0100 |
commit | 40bdb83e6b80c169e9ab38e332dc3d633e8b0066 (patch) | |
tree | 64cc14dd06d3ed8ae0da18b728657b72487972cd | |
parent | 17b788f2a70fa78be945878b52ef497f5b76b5b1 (diff) | |
download | ZLUDA-40bdb83e6b80c169e9ab38e332dc3d633e8b0066.tar.gz ZLUDA-40bdb83e6b80c169e9ab38e332dc3d633e8b0066.zip |
Support float constants
-rw-r--r-- | ptx/src/ast.rs | 56 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 158 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/constant_f32.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/constant_f32.spvtxt | 57 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/constant_negative.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/constant_negative.spvtxt | 56 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 2 | ||||
-rw-r--r-- | ptx/src/translate.rs | 96 |
8 files changed, 385 insertions, 82 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index b045a83..d858d06 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -622,16 +622,24 @@ pub struct Arg5<P: ArgParams> { } #[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} + +#[derive(Copy, Clone)] pub enum Operand<ID> { Reg(ID), RegOffset(ID, i32), - Imm(u32), + Imm(ImmediateValue), } #[derive(Copy, Clone)] pub enum CallOperand<ID> { Reg(ID), - Imm(u32), + Imm(ImmediateValue), } pub enum IdOrVector<ID> { @@ -642,7 +650,7 @@ pub enum IdOrVector<ID> { pub enum OperandOrVector<ID> { Reg(ID), RegOffset(ID, i32), - Imm(u32), + Imm(ImmediateValue), Vec(Vec<ID>), } @@ -1028,7 +1036,7 @@ pub struct MinMaxFloat { } pub enum NumsOrArrays<'a> { - Nums(Vec<&'a str>), + Nums(Vec<(&'a str, u32)>), Arrays(Vec<NumsOrArrays<'a>>), } @@ -1076,8 +1084,8 @@ impl<'a> NumsOrArrays<'a> { if vec.len() > *dim as usize { return Err(PtxError::ZeroDimensionArray); } - for (idx, val) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, result)?; + for (idx, (val, radix)) in vec.iter().enumerate() { + Self::parse_and_copy_single(t, idx, val, *radix, result)?; } } NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), @@ -1107,42 +1115,43 @@ impl<'a> NumsOrArrays<'a> { t: SizedScalarType, idx: usize, str_val: &str, + radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match t { SizedScalarType::B8 | SizedScalarType::U8 => { - Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?; } SizedScalarType::B16 | SizedScalarType::U16 => { - Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?; } SizedScalarType::B32 | SizedScalarType::U32 => { - Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?; } SizedScalarType::B64 | SizedScalarType::U64 => { - Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?; } SizedScalarType::S8 => { - Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?; } SizedScalarType::S16 => { - Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?; } SizedScalarType::S32 => { - Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?; } SizedScalarType::S64 => { - Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?; } SizedScalarType::F16 => { - Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?; } SizedScalarType::F16x2 => todo!(), SizedScalarType::F32 => { - Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?; } SizedScalarType::F64 => { - Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?; + Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?; } } Ok(()) @@ -1151,6 +1160,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy_single_t<T: Copy + FromStr>( idx: usize, str_val: &str, + _radix: u32, // TODO: use this to properly support hex literals output: &mut [u8], ) -> Result<(), PtxError> where @@ -1200,8 +1210,8 @@ mod tests { #[test] fn array_auto_sizes_0_dimension() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2"]), - NumsOrArrays::Nums(vec!["3", "4"]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), + NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), ]); let mut dimensions = vec![0u32, 2]; assert_eq!( @@ -1214,8 +1224,8 @@ mod tests { #[test] fn array_fails_wrong_structure() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2"]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), + NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); @@ -1224,8 +1234,8 @@ mod tests { #[test] fn array_fails_too_long_component() { let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec!["1", "2", "3"]), - NumsOrArrays::Nums(vec!["4", "5"]), + NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), + NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 163a233..d445baa 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -15,12 +15,16 @@ match { r"\s+" => { }, r"//[^\n\r]*[\n\r]*" => { }, r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { }, - r"-?[?:0x]?[0-9]+" => Num, + r"0[fF][0-9a-zA-Z]{8}" => F32NumToken, + r"0[dD][0-9a-zA-Z]{16}" => F64NumToken, + r"0[xX][0-9a-zA-Z]+U?" => HexNumToken, + r"[0-9]+U?" => DecimalNumToken, r#""[^"]*""# => String, r"[0-9]+\.[0-9]+" => VersionNumber, "!", "(", ")", "+", + "-", ",", ".", ":", @@ -181,6 +185,74 @@ ExtendedID : &'input str = { ID } +NumToken: (&'input str, u32, bool) = { + <s:HexNumToken> => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + }, + <s:DecimalNumToken> => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } +} + +F32Num: f32 = { + <s:F32NumToken> =>? { + match u32::from_str_radix(&s[2..], 16) { + Ok(x) => Ok(unsafe { std::mem::transmute::<_, f32>(x) }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + + } +} + +F64Num: f64 = { + <s:F64NumToken> =>? { + match u64::from_str_radix(&s[2..], 16) { + Ok(x) => Ok(unsafe { std::mem::transmute::<_, f64>(x) }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +U8Num: u8 = { + <x:NumToken> =>? { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +U32Num: u32 = { + <x:NumToken> =>? { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + +// TODO: handle negative number properly +S32Num: i32 = { + <sign:"-"?> <x:NumToken> =>? { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) }) + } + } +} + pub Module: ast::Module<'input> = { <v:Version> Target <d:Directive*> => { ast::Module { version: v, directives: without_none(d) } @@ -218,7 +290,7 @@ Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = { }; AddressSize = { - ".address_size" Num + ".address_size" U8Num }; Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { @@ -328,7 +400,7 @@ DebugDirective: () = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-loc DebugLocation = { - ".loc" Num Num Num + ".loc" U32Num U32Num U32Num }; Label: &'input str = { @@ -336,10 +408,7 @@ Label: &'input str = { }; Align: u32 = { - ".align" <a:Num> => { - let align = a.parse::<u32>(); - align.unwrap_with(errors) - } + ".align" <x:U32Num> => x }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names @@ -348,10 +417,7 @@ MultiVariable: ast::MultiVariable<&'input str> = { } VariableParam: u32 = { - "<" <n:Num> ">" => { - let size = n.parse::<u32>(); - size.unwrap_with(errors) - } + "<" <n:U32Num> ">" => n } Variable: ast::Variable<ast::VariableType, &'input str> = { @@ -1239,28 +1305,50 @@ ArithFloat: ast::ArithFloat = { Operand: ast::Operand<&'input str> = { <r:ExtendedID> => ast::Operand::Reg(r), - <r:ExtendedID> "+" <o:Num> => { - let offset = o.parse::<i32>(); - let offset = offset.unwrap_with(errors); - ast::Operand::RegOffset(r, offset) - }, - // TODO: start parsing whole constants sub-language: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants - <o:Num> => { - let offset = o.parse::<u32>(); - let offset = offset.unwrap_with(errors); - ast::Operand::Imm(offset) - } + <r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset), + <x:ImmediateValue> => ast::Operand::Imm(x) }; CallOperand: ast::CallOperand<&'input str> = { <r:ExtendedID> => ast::CallOperand::Reg(r), - <o:Num> => { - let offset = o.parse::<u32>(); - let offset = offset.unwrap_with(errors); - ast::CallOperand::Imm(offset) + <x:ImmediateValue> => ast::CallOperand::Imm(x) +}; + +// TODO: start parsing whole constants sub-language: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants +ImmediateValue: ast::ImmediateValue = { + // TODO: treat negation correctly + <neg:"-"?> <x:NumToken> =>? { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) }) + } + } + } + } + }, + <f:F32Num> => { + ast::ImmediateValue::F32(f) + }, + <f:F64Num> => { + ast::ImmediateValue::F64(f) } -}; +} Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = { <src:ExtendedID> => ast::Arg1{<>} @@ -1332,7 +1420,7 @@ VectorPrefix: u8 = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file File = { - ".file" Num String ("," Num "," Num)? + ".file" U32Num String ("," U32Num "," U32Num)? }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-section @@ -1341,11 +1429,11 @@ Section = { }; SectionDwarfLines: () = { - BitType Comma<Num>, + BitType Comma<U32Num>, ".b32" SectionLabel, ".b64" SectionLabel, - ".b32" SectionLabel "+" Num, - ".b64" SectionLabel "+" Num, + ".b32" SectionLabel "+" U32Num, + ".b64" SectionLabel "+" U32Num, }; SectionLabel = { @@ -1409,9 +1497,7 @@ ArrayEmptyDimension = { } ArrayDimension: u32 = { - "[" <n:Num> "]" =>? { - str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) - } + "[" <n:U32Num> "]" => n, } ArrayInitializer: ast::NumsOrArrays<'input> = { @@ -1424,7 +1510,7 @@ NumsOrArraysBracket: ast::NumsOrArrays<'input> = { NumsOrArrays: ast::NumsOrArrays<'input> = { <n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n), - <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n), + <n:CommaNonEmpty<NumToken>> => ast::NumsOrArrays::Nums(n.into_iter().map(|(x,radix,_)| (x, radix)).collect()), } Comma<T>: Vec<T> = { diff --git a/ptx/src/test/spirv_run/constant_f32.ptx b/ptx/src/test/spirv_run/constant_f32.ptx new file mode 100644 index 0000000..8894658 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_f32.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry constant_f32(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp, [in_addr];
+ mul.f32 temp, temp, 0f3f000000; // 0.5
+ st.f32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt new file mode 100644 index 0000000..905bec4 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_f32.spvtxt @@ -0,0 +1,57 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 32 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%24 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "constant_f32" +OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +%25 = OpTypeVoid +%26 = OpTypeInt 64 0 +%27 = OpTypeFunction %25 %26 %26 +%28 = OpTypePointer Function %26 +%29 = OpTypeFloat 32 +%30 = OpTypePointer Function %29 +%31 = OpTypePointer Generic %29 +%19 = OpConstant %29 0.5 +%1 = OpFunction %25 None %27 +%7 = OpFunctionParameter %26 +%8 = OpFunctionParameter %26 +%22 = OpLabel +%2 = OpVariable %28 Function +%3 = OpVariable %28 Function +%4 = OpVariable %28 Function +%5 = OpVariable %28 Function +%6 = OpVariable %30 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %26 %2 +%9 = OpCopyObject %26 %10 +OpStore %4 %9 +%12 = OpLoad %26 %3 +%11 = OpCopyObject %26 %12 +OpStore %5 %11 +%14 = OpLoad %26 %4 +%20 = OpConvertUToPtr %31 %14 +%13 = OpLoad %29 %20 +OpStore %6 %13 +%16 = OpLoad %29 %6 +%15 = OpFMul %29 %16 %19 +OpStore %6 %15 +%17 = OpLoad %26 %5 +%18 = OpLoad %29 %6 +%21 = OpConvertUToPtr %31 %17 +OpStore %21 %18 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/constant_negative.ptx b/ptx/src/test/spirv_run/constant_negative.ptx new file mode 100644 index 0000000..c723c38 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_negative.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry constant_negative(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp, [in_addr];
+ mul.lo.s32 temp, temp, -1;
+ st.s32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt new file mode 100644 index 0000000..39e5d19 --- /dev/null +++ b/ptx/src/test/spirv_run/constant_negative.spvtxt @@ -0,0 +1,56 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 32 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%24 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "constant_negative" +%25 = OpTypeVoid +%26 = OpTypeInt 64 0 +%27 = OpTypeFunction %25 %26 %26 +%28 = OpTypePointer Function %26 +%29 = OpTypeInt 32 0 +%30 = OpTypePointer Function %29 +%31 = OpTypePointer Generic %29 +%19 = OpConstant %29 4294967295 +%1 = OpFunction %25 None %27 +%7 = OpFunctionParameter %26 +%8 = OpFunctionParameter %26 +%22 = OpLabel +%2 = OpVariable %28 Function +%3 = OpVariable %28 Function +%4 = OpVariable %28 Function +%5 = OpVariable %28 Function +%6 = OpVariable %30 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %26 %2 +%9 = OpCopyObject %26 %10 +OpStore %4 %9 +%12 = OpLoad %26 %3 +%11 = OpCopyObject %26 %12 +OpStore %5 %11 +%14 = OpLoad %26 %4 +%20 = OpConvertUToPtr %31 %14 +%13 = OpLoad %29 %20 +OpStore %6 %13 +%16 = OpLoad %29 %6 +%15 = OpIMul %29 %16 %19 +OpStore %6 %15 +%17 = OpLoad %26 %5 +%18 = OpLoad %29 %6 +%21 = OpConvertUToPtr %31 %17 +OpStore %21 %18 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 658d2ef..40acd46 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -87,6 +87,8 @@ test_ptx!(rcp, [2f32], [0.5f32]); // TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
+test_ptx!(constant_f32, [10f32], [5f32]);
+test_ptx!(constant_negative, [-101i32], [101i32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20b5159..c0ff8f0 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1681,7 +1681,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: -(offset as i64),
+ value: ast::ImmediateValue::S64(-(offset as i64)),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@@ -1697,7 +1697,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: offset as i64,
+ value: ast::ImmediateValue::S64(offset as i64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
@@ -1724,7 +1724,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn immediate(
&mut self,
- desc: ArgumentDescriptor<u32>,
+ desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
@@ -1736,7 +1736,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
- value: desc.op as i64,
+ value: desc.op,
}));
Ok(id)
}
@@ -2081,32 +2081,82 @@ fn emit_function_body_ops( }
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
- match cnst.typ {
- ast::ScalarType::B8 | ast::ScalarType::U8 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u8 as u32);
+ match (cnst.typ, cnst.value) {
+ (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
}
- ast::ScalarType::B16 | ast::ScalarType::U16 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u16 as u32);
+ (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
}
- ast::ScalarType::B32 | ast::ScalarType::U32 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u32);
+ (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
}
- ast::ScalarType::B64 | ast::ScalarType::U64 => {
- builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as u64);
+ (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value);
}
- ast::ScalarType::S8 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i8 as u32);
+ (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
}
- ast::ScalarType::S16 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i16 as u32);
+ (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
}
- ast::ScalarType::S32 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i32 as u32);
+ (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
}
- ast::ScalarType::S64 => {
- builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as i64 as u64);
+ (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64);
}
- _ => unreachable!(),
+ (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32());
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), value);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f64(typ_id, Some(cnst.dst), value as f64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32());
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), value as f32);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f64(typ_id, Some(cnst.dst), value);
+ }
+ _ => return Err(TranslateError::MismatchedType),
}
}
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
@@ -4371,7 +4421,7 @@ impl VisitVariableExpanded for CompositeRead { struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
- pub value: i64,
+ pub value: ast::ImmediateValue,
}
struct BrachCondition {
|