summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2020-11-18 01:50:29 +0100
committer Andrzej Janik <[email protected]> 2020-11-18 01:50:29 +0100
commit ac45c2bde42cef74d68b572920b0749f8eea6837 (patch)
tree 1a425dcce73e5caafcb891e68e843ea3072282f2
parent 106ed74cb79b09a24964967ed7a9b4b13aff1b98 (diff)
download ZLUDA-ac45c2bde42cef74d68b572920b0749f8eea6837.tar.gz
ZLUDA-ac45c2bde42cef74d68b572920b0749f8eea6837.zip
Optimize operations involving the statefuls
-rw-r--r-- ptx/src/ast.rs 6
-rw-r--r-- ptx/src/translate.rs 129
2 files changed, 135 insertions, 0 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 5a5f6be..367f060 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -766,6 +766,8 @@ sub_type! {
LdStType {
Scalar(LdStScalarType),
Vector(LdStScalarType, u8),
+ // Used in generated code
+ Pointer(PointerType, LdStateSpace),
}
}
@@ -774,6 +776,10 @@ impl From<LdStType> for PointerType {
match t {
LdStType::Scalar(t) => PointerType::Scalar(t.into()),
LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
+ LdStType::Pointer(PointerType::Scalar(scalar_type), space) => {
+ PointerType::Pointer(scalar_type, space)
+ }
+ LdStType::Pointer(..) => unreachable!(),
}
}
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index f644a27..29fa93e 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -4397,6 +4397,105 @@ fn convert_to_stateful_memory_access<'a>(
result.push(Statement::Variable(var));
}
}
+ Statement::Instruction(ast::Instruction::Cvta(
+ ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Global,
+ size: ast::CvtaSize::U64,
+ from: ast::CvtaStateSpace::Generic,
+ },
+ arg,
+ )) if is_cvta_ptr_direct(&remapped_ids, &arg) => {
+ let new_dst = *remapped_ids.get(&arg.dst).unwrap();
+ let new_src = *remapped_ids.get(&arg.src.underlying().unwrap()).unwrap();
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ ast::MovDetails {
+ typ: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ },
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal {
+ dst: ast::IdOrVector::Reg(new_dst),
+ src: ast::OperandOrVector::Reg(new_src),
+ }),
+ )));
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ details
+ @
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ ..
+ },
+ arg,
+ )) if is_param_ld_ptr_direct(&remapped_ids, &func_args_ptr, &arg) => {
+ let new_dst = if let ast::IdOrVector::Reg(dst) = arg.dst {
+ *remapped_ids.get(&dst).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ typ: ast::LdStType::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ ..details
+ },
+ ast::Arg2Ld {
+ src: arg.src,
+ dst: ast::IdOrVector::Reg(new_dst),
+ },
+ )));
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ details
+ @
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Global,
+ ..
+ },
+ arg,
+ )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src) => {
+ let new_src = if let ast::Operand::Reg(src) = arg.src {
+ *remapped_ids.get(&src).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ details,
+ ast::Arg2Ld {
+ src: ast::Operand::Reg(new_src),
+ ..arg
+ },
+ )));
+ }
+ Statement::Instruction(ast::Instruction::St(
+ details
+ @
+ ast::StData {
+ state_space: ast::StStateSpace::Global,
+ ..
+ },
+ arg,
+ )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src1) => {
+ let new_src1 = if let ast::Operand::Reg(src1) = arg.src1 {
+ *remapped_ids.get(&src1).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::St(
+ details,
+ ast::Arg2St {
+ src1: ast::Operand::Reg(new_src1),
+ ..arg
+ },
+ )));
+ }
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
let new_statement =
@@ -4489,6 +4588,36 @@ fn convert_to_stateful_memory_access<'a>(
Ok(result)
}
+fn is_param_ld_ptr_direct(
+ remapped_ids: &HashMap<u32, u32>,
+ func_args_ptr: &HashSet<u32>,
+ arg: &ast::Arg2Ld<TypedArgParams>,
+) -> bool {
+ match (arg.src.underlying(), &arg.dst) {
+ (Some(src), ast::IdOrVector::Reg(dst)) => {
+ func_args_ptr.contains(src) && remapped_ids.contains_key(dst)
+ }
+ _ => false,
+ }
+}
+
+fn is_cvta_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg2<TypedArgParams>) -> bool {
+ match arg.src.underlying() {
+ Some(src) => remapped_ids.contains_key(src) && remapped_ids.contains_key(&arg.dst),
+ None => false,
+ }
+}
+
+fn is_ldst_global_ptr_direct(
+ remapped_ids: &HashMap<u32, u32>,
+ src: &ast::Operand<spirv::Word>,
+) -> bool {
+ match src.underlying() {
+ Some(src) => remapped_ids.contains_key(src),
+ None => false,
+ }
+}
+
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
Ok((ast::Type::Scalar(ast::ScalarType::U64), _))