diff options
Diffstat (limited to 'ptx/src/ptx.lalrpop')
-rw-r--r-- | ptx/src/ptx.lalrpop | 71 |
1 files changed, 41 insertions, 30 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index a132705..163a233 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -740,17 +740,29 @@ InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = { }; SetpMode: ast::SetpData = { - <cmp_op:SetpCompareOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpData{ + <cmp_op:SetpCompareOp> <t:SetpTypeNoF32> => ast::SetpData { typ: t, - flush_to_zero: ftz.is_some(), + flush_to_zero: None, + cmp_op: cmp_op, + }, + <cmp_op:SetpCompareOp> <ftz:".ftz"?> ".f32" => ast::SetpData { + typ: ast::ScalarType::F32, + flush_to_zero: Some(ftz.is_some()), cmp_op: cmp_op, } + }; SetpBoolMode: ast::SetpBoolData = { - <cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpBoolData{ + <cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <t:SetpTypeNoF32> => ast::SetpBoolData { typ: t, - flush_to_zero: ftz.is_some(), + flush_to_zero: None, + cmp_op: cmp_op, + bool_op: bool_op, + }, + <cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> ".f32" => ast::SetpBoolData { + typ: ast::ScalarType::F32, + flush_to_zero: Some(ftz.is_some()), cmp_op: cmp_op, bool_op: bool_op, } @@ -783,7 +795,7 @@ SetpBoolPostOp: ast::SetpBoolPostOp = { ".xor" => ast::SetpBoolPostOp::Xor, }; -SetpType: ast::ScalarType = { +SetpTypeNoF32: ast::ScalarType = { ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, ".b64" => ast::ScalarType::B64, @@ -793,7 +805,6 @@ SetpType: ast::ScalarType = { ".s16" => ast::ScalarType::S16, ".s32" => ast::ScalarType::S32, ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, ".f64" => ast::ScalarType::F64, }; @@ -857,7 +868,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F16 @@ -868,7 +879,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F16 @@ -879,7 +890,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F16 @@ -890,7 +901,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F32 @@ -901,7 +912,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: f.is_some(), + flush_to_zero: Some(f.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F32 @@ -912,7 +923,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: None, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F32 @@ -923,7 +934,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F16, src: ast::FloatType::F64 @@ -934,7 +945,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: Some(r), - flush_to_zero: s.is_some(), + flush_to_zero: Some(s.is_some()), saturate: s.is_some(), dst: ast::FloatType::F32, src: ast::FloatType::F64 @@ -945,7 +956,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( ast::CvtDesc { rounding: r, - flush_to_zero: false, + flush_to_zero: None, saturate: s.is_some(), dst: ast::FloatType::F64, src: ast::FloatType::F64 @@ -1082,19 +1093,19 @@ InstCall: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = { "abs" <t:SignedIntType> <a:Arg2> => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: t }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: t }, a) }, "abs" <f:".ftz"?> ".f32" <a:Arg2> => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F32 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F32 }, a) }, "abs" ".f64" <a:Arg2> => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: ast::ScalarType::F64 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: ast::ScalarType::F64 }, a) }, "abs" <f:".ftz"?> ".f16" <a:Arg2> => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F16 }, a) + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16 }, a) }, "abs" <f:".ftz"?> ".f16x2" <a:Arg2> => { - todo!() + ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16x2 }, a) }, }; @@ -1128,7 +1139,7 @@ InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = { "rcp" <rounding:RcpRoundingMode> <ftz:".ftz"?> ".f32" <a:Arg2> => { let details = ast::RcpDetails { rounding, - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), is_f64: false, }; ast::Instruction::Rcp(details, a) @@ -1136,7 +1147,7 @@ InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = { "rcp" <rn:RoundingModeFloat> ".f64" <a:Arg2> => { let details = ast::RcpDetails { rounding: Some(rn), - flush_to_zero: false, + flush_to_zero: None, is_f64: true, }; ast::Instruction::Rcp(details, a) @@ -1173,16 +1184,16 @@ MinMaxDetails: ast::MinMaxDetails = { <t:UIntType> => ast::MinMaxDetails::Unsigned(t), <t:SIntType> => ast::MinMaxDetails::Signed(t), <ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } ), <ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } ), <ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } ) } @@ -1203,25 +1214,25 @@ ArithFloat: ast::ArithFloat = { <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { typ: ast::FloatType::F32, rounding: rn, - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, <rn:RoundingModeFloat?> ".f64" => ast::ArithFloat { typ: ast::FloatType::F64, rounding: rn, - flush_to_zero: false, + flush_to_zero: None, saturate: false, }, <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { typ: ast::FloatType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { typ: ast::FloatType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), + flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, } |