aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs162
1 files changed, 100 insertions, 62 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 5fea075..6c2c594 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
- emit_builtins(&mut builder, &mut map, &id_defs);
+ //emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives(
@@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>(
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements =
+ fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>(
}
fn fix_special_registers(
+ ptx_impl_imports: &mut HashMap<String, Directive>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1276,7 +1278,6 @@ fn fix_special_registers(
for s in typed_statements {
match s {
Statement::LoadVar(
- mut
details
@
LoadVarDetails {
@@ -1285,48 +1286,53 @@ fn fix_special_registers(
},
) => {
let index = details.member_index.unwrap().0;
- if index == 3 {
- result.push(Statement::Constant(ConstantDefinition {
- dst: details.arg.dst,
- typ: ast::ScalarType::U32,
- value: ast::ImmediateValue::U64(0),
- }));
- } else {
- let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
- {
- Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
- None => None,
- };
- let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
- Some(sreg_and_type) => sreg_and_type,
- None => {
- result.push(Statement::LoadVar(details));
- continue;
- }
- };
- let temp_id = numeric_id_defs
- .register_intermediate(Some((details.typ.clone(), details.state_space)));
- let real_dst = details.arg.dst;
- details.arg.dst = temp_id;
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- src: sreg_src,
- dst: temp_id,
- },
- state_space: ast::StateSpace::Sreg,
- typ: ast::Type::Scalar(scalar_typ),
- member_index: Some((index, Some(vector_width))),
- }));
- result.push(Statement::Conversion(ImplicitConversion {
- src: temp_id,
- dst: real_dst,
- from_type: ast::Type::Scalar(scalar_typ),
- from_space: ast::StateSpace::Sreg,
- to_type: ast::Type::Scalar(ast::ScalarType::U32),
- to_space: ast::StateSpace::Sreg,
- kind: ConversionKind::Default,
- }));
- }
+ let sreg = numeric_id_defs
+ .special_registers
+ .get(details.arg.src)
+ .ok_or_else(|| error_unreachable())?;
+ let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
+ let index_constant = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::Constant(ConstantDefinition {
+ dst: index_constant,
+ typ: ast::ScalarType::U32,
+ value: ast::ImmediateValue::U64(index as u64),
+ }));
+ let fn_result = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ocl_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments =
+ vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
+ let input_arguments = vec![(
+ TypedOperand::Reg(index_constant),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = register_external_fn_call(
+ numeric_id_defs,
+ ptx_impl_imports,
+ ocl_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ result.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ return_arguments,
+ name: fn_call,
+ input_arguments,
+ }));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: fn_result,
+ dst: details.arg.dst,
+ from_type: ast::Type::Scalar(ocl_type),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Reg,
+ kind: ConversionKind::Default,
+ }));
}
s => result.push(s),
}
@@ -1721,8 +1727,8 @@ fn instruction_to_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
- return_arguments,
- input_arguments,
+ return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
)?;
Ok(Statement::Call(ResolvedCall {
uniform: false,
@@ -1732,12 +1738,12 @@ fn instruction_to_fn_call(
}))
}
-fn register_external_fn_call(
+fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
- return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
- input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
@@ -1770,19 +1776,18 @@ fn register_external_fn_call(
}
}
-fn fn_arguments_to_variables(
+fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
- args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<spirv::Word>> {
- args.iter()
- .map(|(_, typ, space)| ast::Variable {
- align: None,
- v_type: typ.clone(),
- state_space: *space,
- name: id_defs.register_intermediate(None),
- array_init: Vec::new(),
- })
- .collect::<Vec<_>>()
+ args.map(|(typ, space)| ast::Variable {
+ align: None,
+ v_type: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
}
fn arguments_to_resolved_arguments(
@@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>(
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Constant(_) => return Err(error_unreachable()),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
}
}
Ok(result)
@@ -4686,6 +4691,19 @@ impl PtxSpecialRegister {
}
}
+ fn get_scalar_type(self) -> ast::ScalarType {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Tid64
+ | PtxSpecialRegister::Ntid64
+ | PtxSpecialRegister::Ctaid64
+ | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
+ }
+ }
+
fn get_builtin(self) -> spirv::BuiltIn {
match self {
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
@@ -4701,6 +4719,23 @@ impl PtxSpecialRegister {
}
}
+ fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
+ match self {
+ PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
+ ("_Z12get_local_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
+ ("_Z14get_local_sizej", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
+ ("_Z12get_group_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
+ ("_Z14get_num_groupsj", ast::ScalarType::U64)
+ }
+ }
+ }
+
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
match self {
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
@@ -4743,6 +4778,8 @@ impl SpecialRegistersMap {
}
fn interface(&self) -> Vec<spirv::Word> {
+ return Vec::new();
+ /*
self.reg_to_id
.iter()
.filter_map(|(sreg, id)| {
@@ -4753,6 +4790,7 @@ impl SpecialRegistersMap {
}
})
.collect::<Vec<_>>()
+ */
}
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {