summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-10-26 01:49:25 +0100
committerAndrzej Janik <[email protected]>2020-10-26 01:49:25 +0100
commit40bdb83e6b80c169e9ab38e332dc3d633e8b0066 (patch)
tree64cc14dd06d3ed8ae0da18b728657b72487972cd
parent17b788f2a70fa78be945878b52ef497f5b76b5b1 (diff)
downloadZLUDA-40bdb83e6b80c169e9ab38e332dc3d633e8b0066.tar.gz
ZLUDA-40bdb83e6b80c169e9ab38e332dc3d633e8b0066.zip
Support float constants
-rw-r--r--ptx/src/ast.rs56
-rw-r--r--ptx/src/ptx.lalrpop158
-rw-r--r--ptx/src/test/spirv_run/constant_f32.ptx21
-rw-r--r--ptx/src/test/spirv_run/constant_f32.spvtxt57
-rw-r--r--ptx/src/test/spirv_run/constant_negative.ptx21
-rw-r--r--ptx/src/test/spirv_run/constant_negative.spvtxt56
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/translate.rs96
8 files changed, 385 insertions, 82 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index b045a83..d858d06 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -622,16 +622,24 @@ pub struct Arg5<P: ArgParams> {
}
#[derive(Copy, Clone)]
+pub enum ImmediateValue {
+ U64(u64),
+ S64(i64),
+ F32(f32),
+ F64(f64),
+}
+
+#[derive(Copy, Clone)]
pub enum Operand<ID> {
Reg(ID),
RegOffset(ID, i32),
- Imm(u32),
+ Imm(ImmediateValue),
}
#[derive(Copy, Clone)]
pub enum CallOperand<ID> {
Reg(ID),
- Imm(u32),
+ Imm(ImmediateValue),
}
pub enum IdOrVector<ID> {
@@ -642,7 +650,7 @@ pub enum IdOrVector<ID> {
pub enum OperandOrVector<ID> {
Reg(ID),
RegOffset(ID, i32),
- Imm(u32),
+ Imm(ImmediateValue),
Vec(Vec<ID>),
}
@@ -1028,7 +1036,7 @@ pub struct MinMaxFloat {
}
pub enum NumsOrArrays<'a> {
- Nums(Vec<&'a str>),
+ Nums(Vec<(&'a str, u32)>),
Arrays(Vec<NumsOrArrays<'a>>),
}
@@ -1076,8 +1084,8 @@ impl<'a> NumsOrArrays<'a> {
if vec.len() > *dim as usize {
return Err(PtxError::ZeroDimensionArray);
}
- for (idx, val) in vec.iter().enumerate() {
- Self::parse_and_copy_single(t, idx, val, result)?;
+ for (idx, (val, radix)) in vec.iter().enumerate() {
+ Self::parse_and_copy_single(t, idx, val, *radix, result)?;
}
}
NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray),
@@ -1107,42 +1115,43 @@ impl<'a> NumsOrArrays<'a> {
t: SizedScalarType,
idx: usize,
str_val: &str,
+ radix: u32,
output: &mut [u8],
) -> Result<(), PtxError> {
match t {
SizedScalarType::B8 | SizedScalarType::U8 => {
- Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
}
SizedScalarType::B16 | SizedScalarType::U16 => {
- Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
}
SizedScalarType::B32 | SizedScalarType::U32 => {
- Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
}
SizedScalarType::B64 | SizedScalarType::U64 => {
- Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
}
SizedScalarType::S8 => {
- Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
}
SizedScalarType::S16 => {
- Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
}
SizedScalarType::S32 => {
- Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
}
SizedScalarType::S64 => {
- Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
}
SizedScalarType::F16 => {
- Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
}
SizedScalarType::F16x2 => todo!(),
SizedScalarType::F32 => {
- Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
}
SizedScalarType::F64 => {
- Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?;
+ Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
}
}
Ok(())
@@ -1151,6 +1160,7 @@ impl<'a> NumsOrArrays<'a> {
fn parse_and_copy_single_t<T: Copy + FromStr>(
idx: usize,
str_val: &str,
+ _radix: u32, // TODO: use this to properly support hex literals
output: &mut [u8],
) -> Result<(), PtxError>
where
@@ -1200,8 +1210,8 @@ mod tests {
#[test]
fn array_auto_sizes_0_dimension() {
let inp = NumsOrArrays::Arrays(vec![
- NumsOrArrays::Nums(vec!["1", "2"]),
- NumsOrArrays::Nums(vec!["3", "4"]),
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]),
+ NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]),
]);
let mut dimensions = vec![0u32, 2];
assert_eq!(
@@ -1214,8 +1224,8 @@ mod tests {
#[test]
fn array_fails_wrong_structure() {
let inp = NumsOrArrays::Arrays(vec![
- NumsOrArrays::Nums(vec!["1", "2"]),
- NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]),
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]),
+ NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
@@ -1224,8 +1234,8 @@ mod tests {
#[test]
fn array_fails_too_long_component() {
let inp = NumsOrArrays::Arrays(vec![
- NumsOrArrays::Nums(vec!["1", "2", "3"]),
- NumsOrArrays::Nums(vec!["4", "5"]),
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]),
+ NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 163a233..d445baa 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -15,12 +15,16 @@ match {
r"\s+" => { },
r"//[^\n\r]*[\n\r]*" => { },
r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { },
- r"-?[?:0x]?[0-9]+" => Num,
+ r"0[fF][0-9a-zA-Z]{8}" => F32NumToken,
+ r"0[dD][0-9a-zA-Z]{16}" => F64NumToken,
+ r"0[xX][0-9a-zA-Z]+U?" => HexNumToken,
+ r"[0-9]+U?" => DecimalNumToken,
r#""[^"]*""# => String,
r"[0-9]+\.[0-9]+" => VersionNumber,
"!",
"(", ")",
"+",
+ "-",
",",
".",
":",
@@ -181,6 +185,74 @@ ExtendedID : &'input str = {
ID
}
+NumToken: (&'input str, u32, bool) = {
+ <s:HexNumToken> => {
+ if s.ends_with('U') {
+ (&s[2..s.len() - 1], 16, true)
+ } else {
+ (&s[2..], 16, false)
+ }
+ },
+ <s:DecimalNumToken> => {
+ let radix = if s.starts_with('0') { 8 } else { 10 };
+ if s.ends_with('U') {
+ (&s[..s.len() - 1], radix, true)
+ } else {
+ (s, radix, false)
+ }
+ }
+}
+
+F32Num: f32 = {
+ <s:F32NumToken> =>? {
+ match u32::from_str_radix(&s[2..], 16) {
+ Ok(x) => Ok(unsafe { std::mem::transmute::<_, f32>(x) }),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+
+ }
+}
+
+F64Num: f64 = {
+ <s:F64NumToken> =>? {
+ match u64::from_str_radix(&s[2..], 16) {
+ Ok(x) => Ok(unsafe { std::mem::transmute::<_, f64>(x) }),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+ }
+}
+
+U8Num: u8 = {
+ <x:NumToken> =>? {
+ let (text, radix, _) = x;
+ match u8::from_str_radix(text, radix) {
+ Ok(x) => Ok(x),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+ }
+}
+
+U32Num: u32 = {
+ <x:NumToken> =>? {
+ let (text, radix, _) = x;
+ match u32::from_str_radix(text, radix) {
+ Ok(x) => Ok(x),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+ }
+}
+
+// TODO: handle negative number properly
+S32Num: i32 = {
+ <sign:"-"?> <x:NumToken> =>? {
+ let (text, radix, _) = x;
+ match i32::from_str_radix(text, radix) {
+ Ok(x) => Ok(if sign.is_some() { -x } else { x }),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+ }
+}
+
pub Module: ast::Module<'input> = {
<v:Version> Target <d:Directive*> => {
ast::Module { version: v, directives: without_none(d) }
@@ -218,7 +290,7 @@ Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
};
AddressSize = {
- ".address_size" Num
+ ".address_size" U8Num
};
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
@@ -328,7 +400,7 @@ DebugDirective: () = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-loc
DebugLocation = {
- ".loc" Num Num Num
+ ".loc" U32Num U32Num U32Num
};
Label: &'input str = {
@@ -336,10 +408,7 @@ Label: &'input str = {
};
Align: u32 = {
- ".align" <a:Num> => {
- let align = a.parse::<u32>();
- align.unwrap_with(errors)
- }
+ ".align" <x:U32Num> => x
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
@@ -348,10 +417,7 @@ MultiVariable: ast::MultiVariable<&'input str> = {
}
VariableParam: u32 = {
- "<" <n:Num> ">" => {
- let size = n.parse::<u32>();
- size.unwrap_with(errors)
- }
+ "<" <n:U32Num> ">" => n
}
Variable: ast::Variable<ast::VariableType, &'input str> = {
@@ -1239,28 +1305,50 @@ ArithFloat: ast::ArithFloat = {
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
- <r:ExtendedID> "+" <o:Num> => {
- let offset = o.parse::<i32>();
- let offset = offset.unwrap_with(errors);
- ast::Operand::RegOffset(r, offset)
- },
- // TODO: start parsing whole constants sub-language:
- // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
- <o:Num> => {
- let offset = o.parse::<u32>();
- let offset = offset.unwrap_with(errors);
- ast::Operand::Imm(offset)
- }
+ <r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset),
+ <x:ImmediateValue> => ast::Operand::Imm(x)
};
CallOperand: ast::CallOperand<&'input str> = {
<r:ExtendedID> => ast::CallOperand::Reg(r),
- <o:Num> => {
- let offset = o.parse::<u32>();
- let offset = offset.unwrap_with(errors);
- ast::CallOperand::Imm(offset)
+ <x:ImmediateValue> => ast::CallOperand::Imm(x)
+};
+
+// TODO: start parsing whole constants sub-language:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
+ImmediateValue: ast::ImmediateValue = {
+ // TODO: treat negation correctly
+ <neg:"-"?> <x:NumToken> =>? {
+ let (num, radix, is_unsigned) = x;
+ if neg.is_some() {
+ match i64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) })
+ }
+ } else if is_unsigned {
+ match u64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::U64(x)),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) })
+ }
+ } else {
+ match i64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::S64(x)),
+ Err(_) => {
+ match u64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::U64(x)),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::ParseInt(err) })
+ }
+ }
+ }
+ }
+ },
+ <f:F32Num> => {
+ ast::ImmediateValue::F32(f)
+ },
+ <f:F64Num> => {
+ ast::ImmediateValue::F64(f)
}
-};
+}
Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>}
@@ -1332,7 +1420,7 @@ VectorPrefix: u8 = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file
File = {
- ".file" Num String ("," Num "," Num)?
+ ".file" U32Num String ("," U32Num "," U32Num)?
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-section
@@ -1341,11 +1429,11 @@ Section = {
};
SectionDwarfLines: () = {
- BitType Comma<Num>,
+ BitType Comma<U32Num>,
".b32" SectionLabel,
".b64" SectionLabel,
- ".b32" SectionLabel "+" Num,
- ".b64" SectionLabel "+" Num,
+ ".b32" SectionLabel "+" U32Num,
+ ".b64" SectionLabel "+" U32Num,
};
SectionLabel = {
@@ -1409,9 +1497,7 @@ ArrayEmptyDimension = {
}
ArrayDimension: u32 = {
- "[" <n:Num> "]" =>? {
- str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
- }
+ "[" <n:U32Num> "]" => n,
}
ArrayInitializer: ast::NumsOrArrays<'input> = {
@@ -1424,7 +1510,7 @@ NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
NumsOrArrays: ast::NumsOrArrays<'input> = {
<n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
- <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
+ <n:CommaNonEmpty<NumToken>> => ast::NumsOrArrays::Nums(n.into_iter().map(|(x,radix,_)| (x, radix)).collect()),
}
Comma<T>: Vec<T> = {
diff --git a/ptx/src/test/spirv_run/constant_f32.ptx b/ptx/src/test/spirv_run/constant_f32.ptx
new file mode 100644
index 0000000..8894658
--- /dev/null
+++ b/ptx/src/test/spirv_run/constant_f32.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry constant_f32(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp, [in_addr];
+ mul.f32 temp, temp, 0f3f000000; // 0.5
+ st.f32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt
new file mode 100644
index 0000000..905bec4
--- /dev/null
+++ b/ptx/src/test/spirv_run/constant_f32.spvtxt
@@ -0,0 +1,57 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 32
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+OpCapability FunctionFloatControlINTEL
+OpExtension "SPV_INTEL_float_controls2"
+%24 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "constant_f32"
+OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+%25 = OpTypeVoid
+%26 = OpTypeInt 64 0
+%27 = OpTypeFunction %25 %26 %26
+%28 = OpTypePointer Function %26
+%29 = OpTypeFloat 32
+%30 = OpTypePointer Function %29
+%31 = OpTypePointer Generic %29
+%19 = OpConstant %29 0.5
+%1 = OpFunction %25 None %27
+%7 = OpFunctionParameter %26
+%8 = OpFunctionParameter %26
+%22 = OpLabel
+%2 = OpVariable %28 Function
+%3 = OpVariable %28 Function
+%4 = OpVariable %28 Function
+%5 = OpVariable %28 Function
+%6 = OpVariable %30 Function
+OpStore %2 %7
+OpStore %3 %8
+%10 = OpLoad %26 %2
+%9 = OpCopyObject %26 %10
+OpStore %4 %9
+%12 = OpLoad %26 %3
+%11 = OpCopyObject %26 %12
+OpStore %5 %11
+%14 = OpLoad %26 %4
+%20 = OpConvertUToPtr %31 %14
+%13 = OpLoad %29 %20
+OpStore %6 %13
+%16 = OpLoad %29 %6
+%15 = OpFMul %29 %16 %19
+OpStore %6 %15
+%17 = OpLoad %26 %5
+%18 = OpLoad %29 %6
+%21 = OpConvertUToPtr %31 %17
+OpStore %21 %18
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/constant_negative.ptx b/ptx/src/test/spirv_run/constant_negative.ptx
new file mode 100644
index 0000000..c723c38
--- /dev/null
+++ b/ptx/src/test/spirv_run/constant_negative.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry constant_negative(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp, [in_addr];
+ mul.lo.s32 temp, temp, -1;
+ st.s32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt
new file mode 100644
index 0000000..39e5d19
--- /dev/null
+++ b/ptx/src/test/spirv_run/constant_negative.spvtxt
@@ -0,0 +1,56 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 32
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+OpCapability FunctionFloatControlINTEL
+OpExtension "SPV_INTEL_float_controls2"
+%24 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "constant_negative"
+%25 = OpTypeVoid
+%26 = OpTypeInt 64 0
+%27 = OpTypeFunction %25 %26 %26
+%28 = OpTypePointer Function %26
+%29 = OpTypeInt 32 0
+%30 = OpTypePointer Function %29
+%31 = OpTypePointer Generic %29
+%19 = OpConstant %29 4294967295
+%1 = OpFunction %25 None %27
+%7 = OpFunctionParameter %26
+%8 = OpFunctionParameter %26
+%22 = OpLabel
+%2 = OpVariable %28 Function
+%3 = OpVariable %28 Function
+%4 = OpVariable %28 Function
+%5 = OpVariable %28 Function
+%6 = OpVariable %30 Function
+OpStore %2 %7
+OpStore %3 %8
+%10 = OpLoad %26 %2
+%9 = OpCopyObject %26 %10
+OpStore %4 %9
+%12 = OpLoad %26 %3
+%11 = OpCopyObject %26 %12
+OpStore %5 %11
+%14 = OpLoad %26 %4
+%20 = OpConvertUToPtr %31 %14
+%13 = OpLoad %29 %20
+OpStore %6 %13
+%16 = OpLoad %29 %6
+%15 = OpIMul %29 %16 %19
+OpStore %6 %15
+%17 = OpLoad %26 %5
+%18 = OpLoad %29 %6
+%21 = OpConvertUToPtr %31 %17
+OpStore %21 %18
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 658d2ef..40acd46 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -87,6 +87,8 @@ test_ptx!(rcp, [2f32], [0.5f32]);
// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
+test_ptx!(constant_f32, [10f32], [5f32]);
+test_ptx!(constant_negative, [-101i32], [101i32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 20b5159..c0ff8f0 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1681,7 +1681,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: -(offset as i64),
+ value: ast::ImmediateValue::S64(-(offset as i64)),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@@ -1697,7 +1697,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: offset as i64,
+ value: ast::ImmediateValue::S64(offset as i64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
@@ -1724,7 +1724,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn immediate(
&mut self,
- desc: ArgumentDescriptor<u32>,
+ desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
@@ -1736,7 +1736,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
- value: desc.op as i64,
+ value: desc.op,
}));
Ok(id)
}
@@ -2081,32 +2081,82 @@ fn emit_function_body_ops(
}
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
- match cnst.typ {
- ast::ScalarType::B8 | ast::ScalarType::U8 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u8 as u32);
+ match (cnst.typ, cnst.value) {
+ (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
}
- ast::ScalarType::B16 | ast::ScalarType::U16 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u16 as u32);
+ (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
}
- ast::ScalarType::B32 | ast::ScalarType::U32 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as u32);
+ (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
}
- ast::ScalarType::B64 | ast::ScalarType::U64 => {
- builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as u64);
+ (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value);
}
- ast::ScalarType::S8 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i8 as u32);
+ (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
}
- ast::ScalarType::S16 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i16 as u32);
+ (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
}
- ast::ScalarType::S32 => {
- builder.constant_u32(typ_id, Some(cnst.dst), cnst.value as i32 as u32);
+ (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
}
- ast::ScalarType::S64 => {
- builder.constant_u64(typ_id, Some(cnst.dst), cnst.value as i64 as u64);
+ (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64);
}
- _ => unreachable!(),
+ (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32());
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), value);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f64(typ_id, Some(cnst.dst), value as f64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32());
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id, Some(cnst.dst), value as f32);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f64(typ_id, Some(cnst.dst), value);
+ }
+ _ => return Err(TranslateError::MismatchedType),
}
}
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
@@ -4371,7 +4421,7 @@ impl VisitVariableExpanded for CompositeRead {
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
- pub value: i64,
+ pub value: ast::ImmediateValue,
}
struct BrachCondition {