diff options
Diffstat (limited to 'ptx/src/pass/extract_globals.rs')
-rw-r--r-- | ptx/src/pass/extract_globals.rs | 282 |
1 files changed, 282 insertions, 0 deletions
diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs new file mode 100644 index 0000000..680a5ee --- /dev/null +++ b/ptx/src/pass/extract_globals.rs @@ -0,0 +1,282 @@ +use super::*;
+
+pub(super) fn run<'input, 'b>(
+ sorted_statements: Vec<ExpandedStatement>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ id_def: &mut NumericIdResolver,
+) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
+ let mut local = Vec::with_capacity(sorted_statements.len());
+ let mut global = Vec::new();
+ for statement in sorted_statements {
+ match statement {
+ Statement::Variable(
+ var @ ast::Variable {
+ state_space: ast::StateSpace::Shared,
+ ..
+ },
+ )
+ | Statement::Variable(
+ var @ ast::Variable {
+ state_space: ast::StateSpace::Global,
+ ..
+ },
+ ) => global.push(var),
+ Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfe { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfi { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
+ let fn_name: String =
+ [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Brev { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Activemask { arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::IncrementWrap,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_inc",
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::DecrementWrap,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_dec",
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::FloatAdd,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let scalar_type = match data.type_ {
+ ptx_parser::Type::Scalar(scalar) => scalar,
+ _ => return Err(error_unreachable()),
+ };
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_add_",
+ scalar_to_ptx_name(scalar_type),
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ s => local.push(s),
+ }
+ }
+ Ok((local, global))
+}
+
+fn instruction_to_fn_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ inst: ast::Instruction<SpirvWord>,
+ fn_name: String,
+) -> Result<ExpandedStatement, TranslateError> {
+ let mut arguments = Vec::new();
+ ast::visit_map(inst, &mut |operand,
+ type_space: Option<(
+ &ast::Type,
+ ast::StateSpace,
+ )>,
+ is_dst,
+ _| {
+ let (typ, space) = match type_space {
+ Some((typ, space)) => (typ.clone(), space),
+ None => return Err(error_unreachable()),
+ };
+ arguments.push((operand, is_dst, typ, space));
+ Ok(SpirvWord(0))
+ })?;
+ let return_arguments_count = arguments
+ .iter()
+ .position(|(desc, is_dst, _, _)| !is_dst)
+ .unwrap_or(arguments.len());
+ let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
+ let fn_id = register_external_fn_call(
+ id_defs,
+ ptx_impl_imports,
+ fn_name,
+ return_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ, *state)),
+ input_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ, *state)),
+ )?;
+ Ok(Statement::Instruction(ast::Instruction::Call {
+ data: ast::CallDetails {
+ uniform: false,
+ return_arguments: return_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ.clone(), *state))
+ .collect::<Vec<_>>(),
+ input_arguments: input_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ.clone(), *state))
+ .collect::<Vec<_>>(),
+ },
+ arguments: ast::CallArgs {
+ return_arguments: return_arguments
+ .iter()
+ .map(|(name, _, _, _)| *name)
+ .collect::<Vec<_>>(),
+ func: fn_id,
+ input_arguments: input_arguments
+ .iter()
+ .map(|(name, _, _, _)| *name)
+ .collect::<Vec<_>>(),
+ },
+ }))
+}
+
+fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
+ match this {
+ ast::ScalarType::B8 => "b8",
+ ast::ScalarType::B16 => "b16",
+ ast::ScalarType::B32 => "b32",
+ ast::ScalarType::B64 => "b64",
+ ast::ScalarType::B128 => "b128",
+ ast::ScalarType::U8 => "u8",
+ ast::ScalarType::U16 => "u16",
+ ast::ScalarType::U16x2 => "u16x2",
+ ast::ScalarType::U32 => "u32",
+ ast::ScalarType::U64 => "u64",
+ ast::ScalarType::S8 => "s8",
+ ast::ScalarType::S16 => "s16",
+ ast::ScalarType::S16x2 => "s16x2",
+ ast::ScalarType::S32 => "s32",
+ ast::ScalarType::S64 => "s64",
+ ast::ScalarType::F16 => "f16",
+ ast::ScalarType::F16x2 => "f16x2",
+ ast::ScalarType::F32 => "f32",
+ ast::ScalarType::F64 => "f64",
+ ast::ScalarType::BF16 => "bf16",
+ ast::ScalarType::BF16x2 => "bf16x2",
+ ast::ScalarType::Pred => "pred",
+ }
+}
+
+fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
+ match this {
+ ast::AtomSemantics::Relaxed => "relaxed",
+ ast::AtomSemantics::Acquire => "acquire",
+ ast::AtomSemantics::Release => "release",
+ ast::AtomSemantics::AcqRel => "acq_rel",
+ }
+}
+
+fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
+ match this {
+ ast::MemScope::Cta => "cta",
+ ast::MemScope::Gpu => "gpu",
+ ast::MemScope::Sys => "sys",
+ ast::MemScope::Cluster => "cluster",
+ }
+}
+
+fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
+ match this {
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
+ ast::StateSpace::SharedCluster => "shared_cluster",
+ ast::StateSpace::ParamEntry => "param_entry",
+ ast::StateSpace::SharedCta => "shared_cta",
+ ast::StateSpace::ParamFunc => "param_func",
+ }
+}
|