about | summary | refs | log | tree | commit | diff | homepage
path: root/ptx/src
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src')
-rw-r--r-- ptx/src/pass/convert_to_stateful_memory_access.rs 2
-rw-r--r-- ptx/src/pass/deparamize_functions.rs 141
-rw-r--r-- ptx/src/pass/emit_llvm.rs 224
-rw-r--r-- ptx/src/pass/emit_spirv.rs 3
-rw-r--r-- ptx/src/pass/expand_arguments.rs 4
-rw-r--r-- ptx/src/pass/expand_operands.rs 289
-rw-r--r-- ptx/src/pass/extract_globals.rs 1
-rw-r--r-- ptx/src/pass/fix_special_registers2.rs 209
-rw-r--r-- ptx/src/pass/hoist_globals.rs 45
-rw-r--r-- ptx/src/pass/insert_explicit_load_store.rs 338
-rw-r--r-- ptx/src/pass/insert_implicit_conversions.rs 26
-rw-r--r-- ptx/src/pass/insert_implicit_conversions2.rs 426
-rw-r--r-- ptx/src/pass/insert_mem_ssa_statements.rs 2
-rw-r--r-- ptx/src/pass/mod.rs 365
-rw-r--r-- ptx/src/pass/normalize_identifiers2.rs 199
-rw-r--r-- ptx/src/pass/normalize_labels.rs 1
-rw-r--r-- ptx/src/pass/normalize_predicates2.rs 84
-rw-r--r-- ptx/src/pass/resolve_function_pointers.rs 82
-rw-r--r-- ptx/src/test/spirv_run/mod.rs 2
19 files changed, 2335 insertions, 108 deletions
diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs
index 455a8c2..3b8fa93 100644
--- a/ptx/src/pass/convert_to_stateful_memory_access.rs
+++ b/ptx/src/pass/convert_to_stateful_memory_access.rs
@@ -489,7 +489,7 @@ fn convert_to_stateful_memory_access_postprocess(
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
- let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
+ let kind = if new_operand_space == ast::StateSpace::Reg {
ConversionKind::Default
} else {
ConversionKind::PtrToPtr
diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs
new file mode 100644
index 0000000..04c8831
--- /dev/null
+++ b/ptx/src/pass/deparamize_functions.rs
@@ -0,0 +1,141 @@
+use std::collections::BTreeMap;
+
+use super::*;
+
+/// Runs the pass over a whole module, rewriting each directive in order.
+///
+/// Stops at, and returns, the first error reported for any directive.
+pub(super) fn run<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut rewritten = Vec::with_capacity(directives.len());
+    for directive in directives {
+        rewritten.push(run_directive(resolver, directive)?);
+    }
+    Ok(rewritten)
+}
+
+/// Rewrites a single directive: methods go through `run_method`, variables pass through
+/// unchanged.
+fn run_directive<'input>(
+    resolver: &mut GlobalStringIdentResolver2,
+    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => run_method(resolver, method).map(Directive2::Method),
+        variable @ Directive2::Variable(..) => Ok(variable),
+    }
+}
+
+/// Turns the `.param` arguments of a non-kernel function into `.reg` arguments.
+///
+/// Each `.param` return or input argument receives a fresh unnamed identifier and is moved to
+/// the `.reg` state space in the declaration. When the function has a body, the original
+/// identifier is re-declared as a `.param` variable at the top of that body, so statements
+/// that already refer to it stay valid:
+/// * input arguments are copied into the re-declared variable by an `st` in the prologue,
+/// * return arguments are copied back out in front of every `ret` by `run_statement`.
+///
+/// Kernels are returned untouched. An argument in any state space other than `.param` or
+/// `.reg` yields `error_unreachable()`.
+fn run_method<'input>(
+    resolver: &mut GlobalStringIdentResolver2,
+    mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    if method.func_decl.name.is_kernel() {
+        return Ok(method);
+    }
+    let is_declaration = method.body.is_none();
+    // Prologue statements; the original body is appended after them further down.
+    let mut body = Vec::new();
+    // (original `.param` name, new `.reg` name, type) for every remapped return argument.
+    let mut remap_returns = Vec::new();
+    for arg in method.func_decl.return_arguments.iter_mut() {
+        match arg.state_space {
+            ptx_parser::StateSpace::Param => {
+                arg.state_space = ptx_parser::StateSpace::Reg;
+                let old_name = arg.name;
+                arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+                if is_declaration {
+                    // Nothing to patch in a body-less declaration.
+                    continue;
+                }
+                remap_returns.push((old_name, arg.name, arg.v_type.clone()));
+                body.push(Statement::Variable(ast::Variable {
+                    align: None,
+                    name: old_name,
+                    v_type: arg.v_type.clone(),
+                    state_space: ptx_parser::StateSpace::Param,
+                    array_init: Vec::new(),
+                }));
+            }
+            ptx_parser::StateSpace::Reg => {}
+            _ => return Err(error_unreachable()),
+        }
+    }
+    for arg in method.func_decl.input_arguments.iter_mut() {
+        match arg.state_space {
+            ptx_parser::StateSpace::Param => {
+                arg.state_space = ptx_parser::StateSpace::Reg;
+                let old_name = arg.name;
+                arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+                if is_declaration {
+                    continue;
+                }
+                body.push(Statement::Variable(ast::Variable {
+                    align: None,
+                    name: old_name,
+                    v_type: arg.v_type.clone(),
+                    state_space: ptx_parser::StateSpace::Param,
+                    array_init: Vec::new(),
+                }));
+                // Copy the incoming register value into the re-declared `.param` variable.
+                body.push(Statement::Instruction(ast::Instruction::St {
+                    data: ast::StData {
+                        qualifier: ast::LdStQualifier::Weak,
+                        state_space: ast::StateSpace::Param,
+                        caching: ast::StCacheOperator::Writethrough,
+                        typ: arg.v_type.clone(),
+                    },
+                    arguments: ast::StArgs {
+                        src1: old_name,
+                        src2: arg.name,
+                    },
+                }));
+            }
+            ptx_parser::StateSpace::Reg => {}
+            _ => return Err(error_unreachable()),
+        }
+    }
+    // The prologue must be kept even when no return argument was remapped: input arguments
+    // have already been renamed in the declaration, while the body still refers to the old
+    // names, which now exist only as the variables declared in the prologue.
+    let body = method
+        .body
+        .map(|statements| {
+            for statement in statements {
+                run_statement(&remap_returns, &mut body, statement)?;
+            }
+            Ok::<_, TranslateError>(body)
+        })
+        .transpose()?;
+    Ok(Function2 {
+        func_decl: method.func_decl,
+        globals: method.globals,
+        body,
+        import_as: method.import_as,
+        tuning: method.tuning,
+        linkage: method.linkage,
+    })
+}
+
+/// Appends `statement` to `result`, first copying remapped return values in front of a `ret`.
+///
+/// `remap_returns` holds `(old .param name, new .reg name, type)` triples. Before every `ret`
+/// one `ld` per triple moves the value from the `.param` variable into the register that the
+/// rewritten declaration returns. Every other statement is forwarded unchanged.
+fn run_statement<'input>(
+    remap_returns: &[(SpirvWord, SpirvWord, ast::Type)],
+    result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+    statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<(), TranslateError> {
+    if matches!(
+        statement,
+        Statement::Instruction(ast::Instruction::Ret { .. })
+    ) {
+        for (old_name, new_name, type_) in remap_returns.iter().cloned() {
+            result.push(Statement::Instruction(ast::Instruction::Ld {
+                data: ast::LdDetails {
+                    qualifier: ast::LdStQualifier::Weak,
+                    // The source is the variable declared in `.param` space, matching the
+                    // `st` emitted for input arguments.
+                    state_space: ast::StateSpace::Param,
+                    caching: ast::LdCacheOperator::Cached,
+                    typ: type_,
+                    non_coherent: false,
+                },
+                arguments: ast::LdArgs {
+                    dst: new_name,
+                    src: old_name,
+                },
+            }));
+        }
+    }
+    result.push(statement);
+    Ok(())
+}
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 44debba..235ad7d 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
}
pub(super) fn run<'input>(
- id_defs: &GlobalStringIdResolver<'input>,
- call_map: MethodsCallMap<'input>,
- directives: Vec<Directive<'input>>,
+ id_defs: GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
- let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
+ let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive::Variable(..) => todo!(),
- Directive::Method(method) => emit_ctx.emit_method(method)?,
+ Directive2::Variable(..) => todo!(),
+ Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
module.write_to_stderr();
@@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
@@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &Context,
module: &Module,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
@@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCallConv::LLVMCCallConv as u32
}
- fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
- let func_decl = method.func_decl.borrow();
+ fn emit_method(
+ &mut self,
+ method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
- .unwrap_or_else(|| match func_decl.name {
- ast::MethodName::Kernel(name) => name,
- ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
- });
+ .or_else(|| match func_decl.name {
+ ast::MethodName::Kernel(name) => Some(name),
+ ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
+ })
+ .ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
- let fn_type = self.function_type(
+ let fn_type = get_function_type(
+ self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl.input_arguments.iter().map(|v| &v.v_type),
- );
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
+ )?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ if let ast::MethodName::Func(name) = func_decl.name {
+ self.resolver.register(name, fn_);
+ }
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
+ if func_decl.name.is_kernel() {
+ let attr_kind = unsafe {
+ LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
+ };
+ let attr = unsafe {
+ LLVMCreateTypeAttribute(
+ self.context,
+ attr_kind,
+ get_type(self.context, &param.v_type)?,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
+ }
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
@@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
Ok(())
}
+}
- fn function_type(
- &self,
- return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- ) -> LLVMTypeRef {
- if return_args.len() == 0 {
- let mut input_args = input_args
- .map(|type_| match type_ {
- ast::Type::Scalar(scalar) => match scalar {
- ast::ScalarType::Pred => {
- unsafe { LLVMInt1TypeInContext(self.context) }
- }
- ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
- unsafe { LLVMInt8TypeInContext(self.context) }
- }
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
- unsafe { LLVMInt16TypeInContext(self.context) }
- }
- ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
- unsafe { LLVMInt32TypeInContext(self.context) }
- }
- ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
- unsafe { LLVMInt64TypeInContext(self.context) }
- }
- ast::ScalarType::B128 => {
- unsafe { LLVMInt128TypeInContext(self.context) }
- }
- ast::ScalarType::F16 => {
- unsafe { LLVMHalfTypeInContext(self.context) }
- }
- ast::ScalarType::F32 => {
- unsafe { LLVMFloatTypeInContext(self.context) }
- }
- ast::ScalarType::F64 => {
- unsafe { LLVMDoubleTypeInContext(self.context) }
- }
- ast::ScalarType::BF16 => {
- unsafe { LLVMBFloatTypeInContext(self.context) }
- }
- ast::ScalarType::U16x2 => todo!(),
- ast::ScalarType::S16x2 => todo!(),
- ast::ScalarType::F16x2 => todo!(),
- ast::ScalarType::BF16x2 => todo!(),
- },
- ast::Type::Vector(_, _) => todo!(),
- ast::Type::Array(_, _, _) => todo!(),
- ast::Type::Pointer(_, _) => todo!(),
- })
- .collect::<Vec<_>>();
- return unsafe {
- LLVMFunctionType(
- LLVMVoidTypeInContext(self.context),
- input_args.as_mut_ptr(),
- input_args.len() as u32,
- 0,
- )
- };
+fn get_input_argument_type(
+ context: LLVMContextRef,
+ v_type: &ptx_parser::Type,
+ state_space: ptx_parser::StateSpace,
+) -> Result<LLVMTypeRef, TranslateError> {
+ match state_space {
+ ptx_parser::StateSpace::ParamEntry => {
+ Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
- todo!()
+ ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ _ => return Err(error_unreachable()),
}
}
@@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> {
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
@@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
Statement::PtrAccess(_) => todo!(),
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
+ Statement::VectorAccess(_) => todo!(),
})
}
@@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
- ast::Instruction::Mov { data, arguments } => todo!(),
+ ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Or { data, arguments } => todo!(),
ast::Instruction::And { data, arguments } => todo!(),
ast::Instruction::Bra { arguments } => todo!(),
- ast::Instruction::Call { data, arguments } => todo!(),
+ ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(),
@@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
+
+ fn emit_call(
+        &mut self,
+        data: ptx_parser::CallDetails,
+        arguments: ptx_parser::CallArgs<SpirvWord>,
+    ) -> Result<(), TranslateError> {
+        // Emits an LLVM call for a PTX `call`. Zero or one return value is supported; more
+        // than one hits `todo!()`. Debug builds check that every argument is in `.reg` space.
+        debug_assert!(
+            data.return_arguments
+                .iter()
+                .all(|(_, space)| *space == ast::StateSpace::Reg),
+            "call return arguments must be in the .reg state space"
+        );
+        debug_assert!(
+            data.input_arguments
+                .iter()
+                .all(|(_, space)| *space == ast::StateSpace::Reg),
+            "call input arguments must be in the .reg state space"
+        );
+        // Name of the resulting LLVM value: unnamed for void calls, otherwise the name the
+        // resolver hands out for the destination identifier.
+        let name = match (&*data.return_arguments, &*arguments.return_arguments) {
+            ([], []) => LLVM_UNNAMED.as_ptr(),
+            ([_], [dst]) => self.resolver.get_or_add_raw(*dst),
+            _ => todo!(),
+        };
+        let type_ = get_function_type(
+            self.context,
+            data.return_arguments.iter().map(|(type_, _)| type_),
+            data.input_arguments
+                .iter()
+                .map(|(type_, space)| get_input_argument_type(self.context, type_, *space)),
+        )?;
+        let mut input_arguments = arguments
+            .input_arguments
+            .iter()
+            .map(|arg| self.resolver.value(*arg))
+            .collect::<Result<Vec<_>, _>>()?;
+        let call_result = unsafe {
+            LLVMBuildCall2(
+                self.builder,
+                type_,
+                self.resolver.value(arguments.func)?,
+                input_arguments.as_mut_ptr(),
+                input_arguments.len() as u32,
+                name,
+            )
+        };
+        // Make the call's value available under the destination identifier.
+        match &*arguments.return_arguments {
+            [] => {}
+            [dst] => {
+                self.resolver.register(*dst, call_result);
+            }
+            _ => todo!(),
+        }
+        Ok(())
+    }
+
+ fn emit_mov(
+        &mut self,
+        _data: ptx_parser::MovDetails,
+        arguments: ptx_parser::MovArgs<SpirvWord>,
+    ) -> Result<(), TranslateError> {
+        // A `mov` emits no instruction: the destination identifier is registered with the
+        // value already known for the source.
+        let source = self.resolver.value(arguments.src)?;
+        self.resolver.register(arguments.dst, source);
+        Ok(())
+    }
}
fn get_pointer_type<'ctx>(
@@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
}
}
+/// Builds the LLVM function type for the given return and input argument types.
+///
+/// Input types arrive already converted (or as the error from converting them) and are
+/// resolved first, so an input error is reported before the return type is looked at.
+/// No return argument maps to `void`, one maps through `get_type`; more than one is not
+/// implemented yet (`todo!()`).
+fn get_function_type<'a>(
+    context: LLVMContextRef,
+    mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+    input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
+) -> Result<LLVMTypeRef, TranslateError> {
+    let mut parameter_types = input_args.collect::<Result<Vec<LLVMTypeRef>, _>>()?;
+    let parameter_count = parameter_types.len() as u32;
+    let return_type = match return_args.len() {
+        0 => unsafe { LLVMVoidTypeInContext(context) },
+        1 => {
+            let only_return = return_args.next().unwrap();
+            get_type(context, only_return)?
+        }
+        _ => todo!(),
+    };
+    let function_type = unsafe {
+        LLVMFunctionType(
+            return_type,
+            parameter_types.as_mut_ptr(),
+            parameter_count,
+            0,
+        )
+    };
+    Ok(function_type)
+}
+
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
- ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
- ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
+ ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@@ -647,7 +709,7 @@ struct ResolveIdent {
}
impl ResolveIdent {
- fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
+ fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),
diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs
index 5147b79..120a477 100644
--- a/ptx/src/pass/emit_spirv.rs
+++ b/ptx/src/pass/emit_spirv.rs
@@ -469,7 +469,6 @@ fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
ast::StateSpace::Param => spirv::StorageClass::Function,
ast::StateSpace::Reg => spirv::StorageClass::Function,
- ast::StateSpace::Sreg => spirv::StorageClass::Input,
ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::SharedCluster
@@ -693,7 +692,6 @@ fn emit_variable<'input>(
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
ast::StateSpace::Generic => todo!(),
- ast::StateSpace::Sreg => todo!(),
ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::SharedCluster
@@ -1563,6 +1561,7 @@ fn emit_function_body_ops<'input>(
builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
}
}
+ Statement::VectorAccess(vector_access) => todo!(),
}
}
Ok(())
diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs
index d0c7c98..e496c75 100644
--- a/ptx/src/pass/expand_arguments.rs
+++ b/ptx/src/pass/expand_arguments.rs
@@ -63,9 +63,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
} else {
return Err(TranslateError::UntypedSymbol);
};
- if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
+ if state_space == ast::StateSpace::Reg {
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
- if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
+ if reg_space != ast::StateSpace::Reg {
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs
new file mode 100644
index 0000000..3dabf40
--- /dev/null
+++ b/ptx/src/pass/expand_operands.rs
@@ -0,0 +1,289 @@
+use super::*;
+
+/// Expands the operands of every directive in the module, in order.
+///
+/// Returns the first error produced while converting a directive.
+pub(super) fn run<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut expanded = Vec::with_capacity(directives.len());
+    for directive in directives {
+        expanded.push(run_directive(resolver, directive)?);
+    }
+    Ok(expanded)
+}
+
+/// Converts one directive: variables are carried over as they are, methods have their
+/// bodies expanded by `run_method`.
+fn run_directive<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directive: Directive2<
+        'input,
+        ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+        ast::ParsedOperand<SpirvWord>,
+    >,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => run_method(resolver, method).map(Directive2::Method),
+        Directive2::Variable(linking, var) => Ok(Directive2::Variable(linking, var)),
+    }
+}
+
+/// Expands the operands of every statement in a method body.
+///
+/// Declarations (no body) keep `body: None`; all other fields are moved over unchanged.
+fn run_method<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    method: Function2<
+        'input,
+        ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+        ast::ParsedOperand<SpirvWord>,
+    >,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    let Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    } = method;
+    let body = match body {
+        Some(statements) => {
+            let mut expanded = Vec::with_capacity(statements.len());
+            for statement in statements {
+                run_statement(resolver, &mut expanded, statement)?;
+            }
+            Some(expanded)
+        }
+        None => None,
+    };
+    Ok(Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    })
+}
+
+/// Flattens the operands of one statement and appends the outcome to `result`.
+///
+/// Helper statements produced while visiting operands land in `result` ahead of the
+/// converted statement; statements buffered for afterwards are appended when the visitor
+/// is dropped at the end of this function.
+fn run_statement<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+    statement: UnconditionalStatement,
+) -> Result<(), TranslateError> {
+    let mut flattener = FlattenArguments::new(resolver, result);
+    let expanded = statement.visit_map(&mut flattener)?;
+    flattener.result.push(expanded);
+    Ok(())
+}
+
+/// Visitor state for turning parsed operands into plain identifiers.
+///
+/// Helper statements that must precede the visited statement are pushed straight into
+/// `result`. Statements that must follow it are buffered in `post_stmts` and appended to
+/// `result` when the visitor is dropped.
+struct FlattenArguments<'a, 'input> {
+    result: &'a mut Vec<ExpandedStatement>,
+    resolver: &'a mut GlobalStringIdentResolver2<'input>,
+    post_stmts: Vec<ExpandedStatement>,
+}
+
+impl<'a, 'input> FlattenArguments<'a, 'input> {
+    /// Creates a visitor that emits into `result` and allocates temporaries from `resolver`.
+    fn new(
+        resolver: &'a mut GlobalStringIdentResolver2<'input>,
+        result: &'a mut Vec<ExpandedStatement>,
+    ) -> Self {
+        FlattenArguments {
+            result,
+            resolver,
+            post_stmts: Vec::new(),
+        }
+    }
+
+    /// A plain register operand is already an identifier.
+    fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
+        Ok(name)
+    }
+
+    /// Lowers `reg+offset` into a temporary.
+    ///
+    /// When the expected state space is `.reg`, emits a constant holding `offset` followed by
+    /// an integer `add`; otherwise emits an `S64` constant followed by a `PtrAccess`.
+    /// Fails with `UntypedSymbol` when no expected type is supplied and with a mismatched-type
+    /// error when `reg` is not a scalar register in the `.reg` case.
+    fn reg_offset(
+        &mut self,
+        reg: SpirvWord,
+        offset: i32,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        _is_dst: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (type_, state_space) = type_space.ok_or(TranslateError::UntypedSymbol)?;
+        if state_space == ast::StateSpace::Reg {
+            let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
+            if *reg_space != ast::StateSpace::Reg {
+                return Err(error_mismatched_type());
+            }
+            let reg_scalar_type = match reg_type {
+                ast::Type::Scalar(underlying_type) => *underlying_type,
+                _ => return Err(error_mismatched_type()),
+            };
+            // Clone before the resolver is borrowed mutably below.
+            let reg_type = reg_type.clone();
+            let id_constant_stmt = self
+                .resolver
+                .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
+            self.result.push(Statement::Constant(ConstantDefinition {
+                dst: id_constant_stmt,
+                typ: reg_scalar_type,
+                value: ast::ImmediateValue::S64(offset as i64),
+            }));
+            // All integer kinds use the same non-saturating integer add.
+            let arith_details = match reg_scalar_type.kind() {
+                ast::ScalarKind::Signed | ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+                    ast::ArithDetails::Integer(ast::ArithInteger {
+                        type_: reg_scalar_type,
+                        saturate: false,
+                    })
+                }
+                _ => return Err(error_unreachable()),
+            };
+            let id_add_result = self
+                .resolver
+                .register_unnamed(Some((reg_type, state_space)));
+            self.result
+                .push(Statement::Instruction(ast::Instruction::Add {
+                    data: arith_details,
+                    arguments: ast::AddArgs {
+                        dst: id_add_result,
+                        src1: reg,
+                        src2: id_constant_stmt,
+                    },
+                }));
+            Ok(id_add_result)
+        } else {
+            let id_constant_stmt = self.resolver.register_unnamed(Some((
+                ast::Type::Scalar(ast::ScalarType::S64),
+                ast::StateSpace::Reg,
+            )));
+            self.result.push(Statement::Constant(ConstantDefinition {
+                dst: id_constant_stmt,
+                typ: ast::ScalarType::S64,
+                value: ast::ImmediateValue::S64(offset as i64),
+            }));
+            let dst = self
+                .resolver
+                .register_unnamed(Some((type_.clone(), state_space)));
+            self.result.push(Statement::PtrAccess(PtrAccess {
+                underlying_type: type_.clone(),
+                state_space,
+                dst,
+                ptr_src: reg,
+                offset_src: id_constant_stmt,
+            }));
+            Ok(dst)
+        }
+    }
+
+    /// Materializes an immediate as a constant definition and returns its identifier.
+    ///
+    /// Fails with `UntypedSymbol` unless a scalar type is expected for the operand.
+    fn immediate(
+        &mut self,
+        value: ast::ImmediateValue,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (scalar_t, state_space) = match type_space {
+            Some((ast::Type::Scalar(scalar), state_space)) => (*scalar, state_space),
+            _ => return Err(TranslateError::UntypedSymbol),
+        };
+        let id = self
+            .resolver
+            .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
+        self.result.push(Statement::Constant(ConstantDefinition {
+            dst: id,
+            typ: scalar_t,
+            value,
+        }));
+        Ok(id)
+    }
+
+    /// Reads one member of a vector into a scalar temporary.
+    ///
+    /// Destination operands and non-vector sources are rejected with a mismatched-type error.
+    fn vec_member(
+        &mut self,
+        vector_src: SpirvWord,
+        member: u8,
+        _type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        if is_dst {
+            return Err(error_mismatched_type());
+        }
+        let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
+            (ast::Type::Vector(vector_width, scalar_t), space) => {
+                (*vector_width, *scalar_t, *space)
+            }
+            _ => return Err(error_mismatched_type()),
+        };
+        let temporary = self
+            .resolver
+            .register_unnamed(Some((scalar_type.into(), space)));
+        self.result.push(Statement::VectorAccess(VectorAccess {
+            scalar_type,
+            vector_width,
+            dst: temporary,
+            src: vector_src,
+            member,
+        }));
+        Ok(temporary)
+    }
+
+    /// Replaces a `{a, b, ...}` operand with a single vector temporary.
+    ///
+    /// For a source operand the repack runs before the statement; for a destination it is
+    /// buffered to run after it. The expected type must be a vector.
+    fn vec_pack(
+        &mut self,
+        vecs: Vec<SpirvWord>,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+        relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (width, scalar_t, state_space) = match type_space {
+            Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
+            _ => return Err(error_mismatched_type()),
+        };
+        // The temporary stands in for the whole packed operand, so it carries the vector
+        // type, not the element type.
+        let temp_vec = self
+            .resolver
+            .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
+        let statement = Statement::RepackVector(RepackVectorDetails {
+            is_extract: is_dst,
+            typ: scalar_t,
+            packed: temp_vec,
+            unpacked: vecs,
+            relaxed_type_check,
+        });
+        if is_dst {
+            self.post_stmts.push(statement);
+        } else {
+            self.result.push(statement);
+        }
+        Ok(temp_vec)
+    }
+}
+
+impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
+    for FlattenArguments<'a, 'b>
+{
+    /// Dispatches on the operand shape and returns the identifier that replaces it.
+    fn visit(
+        &mut self,
+        args: ast::ParsedOperand<SpirvWord>,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+        relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        use ast::ParsedOperand as Operand;
+        match args {
+            Operand::Reg(ident) => self.reg(ident),
+            Operand::RegOffset(base, offset) => self.reg_offset(base, offset, type_space, is_dst),
+            Operand::Imm(value) => self.immediate(value, type_space),
+            Operand::VecMember(vector, index) => {
+                self.vec_member(vector, index, type_space, is_dst)
+            }
+            Operand::VecPack(members) => {
+                self.vec_pack(members, type_space, is_dst, relaxed_type_check)
+            }
+        }
+    }
+
+    /// Bare identifiers are handled exactly like register operands.
+    fn visit_ident(
+        &mut self,
+        name: <TypedOperand as ast::Operand>::Ident,
+        _type_space: Option<(&ast::Type, ast::StateSpace)>,
+        _is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
+        self.reg(name)
+    }
+}
+
+impl Drop for FlattenArguments<'_, '_> {
+    /// Flushes the statements buffered for after the visited statement into the output.
+    fn drop(&mut self) {
+        self.result.append(&mut self.post_stmts);
+    }
+}
diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs
index 680a5ee..2912366 100644
--- a/ptx/src/pass/extract_globals.rs
+++ b/ptx/src/pass/extract_globals.rs
@@ -273,7 +273,6 @@ fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
- ast::StateSpace::Sreg => "sreg",
ast::StateSpace::SharedCluster => "shared_cluster",
ast::StateSpace::ParamEntry => "param_entry",
ast::StateSpace::SharedCta => "shared_cta",
diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs
new file mode 100644
index 0000000..97f6356
--- /dev/null
+++ b/ptx/src/pass/fix_special_registers2.rs
@@ -0,0 +1,209 @@
+use super::*;
+
+/// Replaces special-register reads with calls to generated extern functions.
+///
+/// The returned module starts with one body-less extern declaration per generated
+/// special-register function, followed by the input directives with every special-register
+/// operand rewritten into a call to the matching declaration.
+///
+/// Fails if a generated declaration is not a plain function or if any directive fails to
+/// convert.
+pub(super) fn run<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    special_registers: &'a SpecialRegistersMap2,
+    directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+    let declarations = SpecialRegistersMap2::generate_declarations(resolver);
+    let mut result = Vec::with_capacity(declarations.len() + directives.len());
+    let mut sreg_to_function =
+        FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
+    for (sreg, declaration) in declarations {
+        let name = match declaration.name {
+            ast::MethodName::Func(name) => name,
+            _ => return Err(error_unreachable()),
+        };
+        result.push(UnconditionalDirective::Method(UnconditionalFunction {
+            func_decl: declaration,
+            globals: Vec::new(),
+            body: None,
+            import_as: None,
+            tuning: Vec::new(),
+            linkage: ast::LinkingDirective::EXTERN,
+        }));
+        sreg_to_function.insert(sreg, name);
+    }
+    let mut visitor = SpecialRegisterResolver {
+        resolver,
+        special_registers,
+        sreg_to_function,
+        result: Vec::new(),
+    };
+    // Append to `result` rather than collecting into a fresh vector: the declarations pushed
+    // above must be part of the output, since the inserted calls refer to them.
+    for directive in directives {
+        result.push(run_directive(&mut visitor, directive)?);
+    }
+    Ok(result)
+}
+
+/// Rewrites a single directive; only methods contain operands that may need replacing.
+fn run_directive<'a, 'input>(
+    visitor: &mut SpecialRegisterResolver<'a, 'input>,
+    directive: UnconditionalDirective<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => run_method(visitor, method).map(Directive2::Method),
+        variable @ Directive2::Variable(..) => Ok(variable),
+    }
+}
+
+/// Rewrites every statement of a method body; declarations without a body stay body-less.
+fn run_method<'a, 'input>(
+    visitor: &mut SpecialRegisterResolver<'a, 'input>,
+    method: UnconditionalFunction<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+    let Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    } = method;
+    let body = match body {
+        Some(statements) => {
+            let mut rewritten = Vec::with_capacity(statements.len());
+            for statement in statements {
+                run_statement(visitor, &mut rewritten, statement)?;
+            }
+            Some(rewritten)
+        }
+        None => None,
+    };
+    Ok(Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    })
+}
+
+/// Converts one statement and appends it to `result`, preceded by the helper statements
+/// (constants and calls) the visitor queued while rewriting its operands.
+fn run_statement<'a, 'input>(
+    visitor: &mut SpecialRegisterResolver<'a, 'input>,
+    result: &mut Vec<UnconditionalStatement>,
+    statement: UnconditionalStatement,
+) -> Result<(), TranslateError> {
+    let rewritten = statement.visit_map(visitor)?;
+    // Moves the queued helpers out and leaves the visitor's queue empty for the next call.
+    result.append(&mut visitor.result);
+    result.push(rewritten);
+    Ok(())
+}
+
+/// Visitor that swaps special-register operands for the results of generated calls.
+///
+/// `sreg_to_function` maps each special register to the identifier of its extern function.
+/// `result` queues the statements that must be emitted before the statement being visited.
+struct SpecialRegisterResolver<'a, 'input> {
+    resolver: &'a mut GlobalStringIdentResolver2<'input>,
+    special_registers: &'a SpecialRegistersMap2,
+    sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
+    result: Vec<UnconditionalStatement>,
+}
+
+impl<'a, 'b, 'input>
+    ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
+    for SpecialRegisterResolver<'a, 'input>
+{
+    /// Rewrites the identifiers inside an operand, replacing special registers.
+    fn visit(
+        &mut self,
+        operand: ast::ParsedOperand<SpirvWord>,
+        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
+        match operand {
+            // `map_operand` turns every `VecMember` into a plain `Reg`. That is only right
+            // when the member index has been folded into a special-register call; a member
+            // access on an ordinary vector must leave this pass unchanged.
+            ast::ParsedOperand::VecMember(vector, member)
+                if self.special_registers.get(vector).is_none() =>
+            {
+                Ok(ast::ParsedOperand::VecMember(vector, member))
+            }
+            operand => map_operand(operand, &mut |ident, vector_index| {
+                self.replace_sreg(ident, vector_index, is_dst)
+            }),
+        }
+    }
+
+    /// Rewrites a bare identifier; it never carries a vector index.
+    fn visit_ident(
+        &mut self,
+        args: SpirvWord,
+        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        self.replace_sreg(args, None, is_dst)
+    }
+}
+
+impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
+    /// Returns the identifier to use in place of `name`.
+    ///
+    /// Identifiers that are not special registers come back unchanged. For a special
+    /// register, a call to its generated function is queued (preceded by a constant holding
+    /// `vector_index` when the function takes one) and the identifier of the call result is
+    /// returned. Writing to a special register, or a mismatch between `vector_index` and the
+    /// function's input, is a mismatched-type error.
+    fn replace_sreg(
+        &mut self,
+        name: SpirvWord,
+        vector_index: Option<u8>,
+        is_dst: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        let sreg = match self.special_registers.get(name) {
+            Some(sreg) => sreg,
+            None => return Ok(name),
+        };
+        if is_dst {
+            return Err(error_mismatched_type());
+        }
+        // (identifier, type, state space) of each call input: at most the vector index.
+        let mut input_arguments = Vec::new();
+        match (vector_index, sreg.get_function_input_type()) {
+            (Some(idx), Some(inp_type)) => {
+                if inp_type != ast::ScalarType::U8 {
+                    return Err(TranslateError::Unreachable);
+                }
+                let index_type = ast::Type::Scalar(inp_type);
+                let constant = self
+                    .resolver
+                    .register_unnamed(Some((index_type.clone(), ast::StateSpace::Reg)));
+                self.result.push(Statement::Constant(ConstantDefinition {
+                    dst: constant,
+                    typ: inp_type,
+                    value: ast::ImmediateValue::U64(idx as u64),
+                }));
+                input_arguments.push((constant, index_type, ast::StateSpace::Reg));
+            }
+            (None, None) => {}
+            _ => return Err(error_mismatched_type()),
+        }
+        let return_type = ast::Type::Scalar(sreg.get_function_return_type());
+        let fn_result = self
+            .resolver
+            .register_unnamed(Some((return_type.clone(), ast::StateSpace::Reg)));
+        let return_arguments = vec![(fn_result, return_type, ast::StateSpace::Reg)];
+        let data = ast::CallDetails {
+            uniform: false,
+            return_arguments: return_arguments
+                .iter()
+                .map(|(_, typ, space)| (typ.clone(), *space))
+                .collect(),
+            input_arguments: input_arguments
+                .iter()
+                .map(|(_, typ, space)| (typ.clone(), *space))
+                .collect(),
+        };
+        let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
+            return_arguments: return_arguments.iter().map(|(id, _, _)| *id).collect(),
+            func: self.sreg_to_function[&sreg],
+            input_arguments: input_arguments
+                .iter()
+                .map(|(id, _, _)| ast::ParsedOperand::Reg(*id))
+                .collect(),
+        };
+        let call = ast::Instruction::Call { data, arguments };
+        self.result.push(Statement::Instruction(call));
+        Ok(fn_result)
+    }
+}
+
+/// Maps every identifier inside a parsed operand through `fn_`.
+///
+/// `fn_` receives the identifier plus `Some(member)` for a vector-member operand and `None`
+/// otherwise. Immediates are copied as they are. Note that a `VecMember` always comes back
+/// as a plain `Reg` wrapping whatever `fn_` returned. The first error from `fn_` aborts the
+/// mapping.
+pub fn map_operand<T, U, Err>(
+    this: ast::ParsedOperand<T>,
+    fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
+) -> Result<ast::ParsedOperand<U>, Err> {
+    let mapped = match this {
+        ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
+        ast::ParsedOperand::Reg(ident) => {
+            let ident = fn_(ident, None)?;
+            ast::ParsedOperand::Reg(ident)
+        }
+        ast::ParsedOperand::RegOffset(ident, offset) => {
+            let ident = fn_(ident, None)?;
+            ast::ParsedOperand::RegOffset(ident, offset)
+        }
+        ast::ParsedOperand::VecMember(ident, member) => {
+            let ident = fn_(ident, Some(member))?;
+            ast::ParsedOperand::Reg(ident)
+        }
+        ast::ParsedOperand::VecPack(idents) => {
+            let mut members = Vec::with_capacity(idents.len());
+            for ident in idents {
+                members.push(fn_(ident, None)?);
+            }
+            ast::ParsedOperand::VecPack(members)
+        }
+    };
+    Ok(mapped)
+}
diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs
new file mode 100644
index 0000000..753172a
--- /dev/null
+++ b/ptx/src/pass/hoist_globals.rs
@@ -0,0 +1,45 @@
+use super::*;
+
+/// Moves `.global`, `.const` and `.shared` variables declared inside method
+/// bodies up to module level.
+///
+/// Each hoisted variable is emitted as a `Directive2::Variable` placed
+/// immediately before the method it was declared in; all other directives keep
+/// their relative order.
+///
+/// # Errors
+///
+/// Propagates any error reported while processing a directive.
+pub(super) fn run<'input>(
+    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut result = Vec::with_capacity(directives.len());
+    for mut directive in directives {
+        // The `Result` used to be dropped here; propagate it so a failing
+        // directive aborts the pass instead of being silently ignored.
+        run_directive(&mut result, &mut directive)?;
+        result.push(directive);
+    }
+    Ok(result)
+}
+
+/// Processes one directive: a method has its body scanned by `run_function`,
+/// while a module-level variable is left untouched.
+fn run_directive<'input>(
+    result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
+    directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<(), TranslateError> {
+    if let Directive2::Method(function2) = directive {
+        run_function(result, function2);
+    }
+    Ok(())
+}
+
+/// Removes every `.global`, `.const` and `.shared` variable statement from the
+/// body of `function` and appends it to `result` as a module-level variable
+/// with `LinkingDirective::NONE`. A method without a body is left as is.
+fn run_function<'input>(
+    result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
+    function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
+) {
+    if let Some(statements) = function.body.take() {
+        let mut kept = Vec::new();
+        for statement in statements {
+            match statement {
+                Statement::Variable(
+                    var @ ast::Variable {
+                        state_space:
+                            ast::StateSpace::Global
+                            | ast::StateSpace::Const
+                            | ast::StateSpace::Shared,
+                        ..
+                    },
+                ) => result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)),
+                other => kept.push(other),
+            }
+        }
+        function.body = Some(kept);
+    }
+}
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
new file mode 100644
index 0000000..ec6498c
--- /dev/null
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -0,0 +1,338 @@
+use super::*;
+use ptx_parser::VisitorMap;
+use rustc_hash::FxHashSet;
+
+// What this pass does:
+// * In-body .local, .param and .reg variables all become .local variables
+//   (unless they are input method arguments)
+// * Every use of a converted .reg variable gets an explicit `ld` before or
+//   `st` after the instruction that uses it
+// * Existing `ld`/`st` instructions that address a converted variable have
+//   their state space fixed up
+// * `.entry` input arguments become param::entry, and the `.param` loads
+//   reading them become `param::entry` loads
+// * `.func` input arguments are turned into `.reg` arguments by another pass,
+//   so nothing is done for them here
+pub(super) fn run<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut converted = Vec::new();
+    for directive in directives {
+        // Stop at the first directive that fails to convert.
+        converted.push(run_directive(resolver, directive)?);
+    }
+    Ok(converted)
+}
+
+/// Converts a single directive. Methods are rewritten by `run_method` using a
+/// fresh `InsertMemSSAVisitor`; module-level variables pass through unchanged.
+fn run_directive<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => {
+            let visitor = InsertMemSSAVisitor::new(resolver);
+            let rewritten = run_method(visitor, method)?;
+            Ok(Directive2::Method(rewritten))
+        }
+        variable @ Directive2::Variable(..) => Ok(variable),
+    }
+}
+
+/// Rewrites one method: its return arguments, its kernel input arguments and
+/// every statement of its body.
+///
+/// * Return arguments go through `InsertMemSSAVisitor::visit_variable`.
+/// * For kernels, every input argument is given a fresh name in the
+///   `ParamEntry` state space and the old name is registered with the visitor
+///   so that loads from it can be redirected.
+/// * Body statements are processed in order by `run_statement`.
+///
+/// # Errors
+///
+/// Returns the first error produced by the visitor. This includes the error
+/// from `input_argument` when a kernel argument is not in the `.param` state
+/// space, which was previously discarded.
+fn run_method<'a, 'input>(
+    mut visitor: InsertMemSSAVisitor<'a, 'input>,
+    method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    let mut func_decl = method.func_decl;
+    for arg in func_decl.return_arguments.iter_mut() {
+        visitor.visit_variable(arg)?;
+    }
+    if func_decl.name.is_kernel() {
+        for arg in func_decl.input_arguments.iter_mut() {
+            let old_name = arg.name;
+            let old_space = arg.state_space;
+            let new_space = ast::StateSpace::ParamEntry;
+            let new_name = visitor
+                .resolver
+                .register_unnamed(Some((arg.v_type.clone(), new_space)));
+            // Must be checked: `input_argument` rejects non-`.param` arguments.
+            visitor.input_argument(old_name, new_name, old_space)?;
+            arg.name = new_name;
+            arg.state_space = new_space;
+        }
+    }
+    let body = method
+        .body
+        .map(move |statements| {
+            let mut result = Vec::with_capacity(statements.len());
+            for statement in statements {
+                run_statement(&mut visitor, &mut result, statement)?;
+            }
+            Ok::<_, TranslateError>(result)
+        })
+        .transpose()?;
+    Ok(Function2 {
+        func_decl,
+        globals: method.globals,
+        body,
+        import_as: method.import_as,
+        tuning: method.tuning,
+        linkage: method.linkage,
+    })
+}
+
+/// Rewrites a single body statement and appends the outcome to `result`.
+///
+/// Variable declarations are converted by `visit_variable` and pushed as is.
+/// `ld` and `st` instructions first get their state space and address fixed by
+/// `visit_ld`/`visit_st` and are then visited like any other statement. For
+/// every non-variable statement, the instructions the visitor queued in `pre`
+/// are emitted before it and those queued in `post` after it.
+fn run_statement<'a, 'input>(
+    visitor: &mut InsertMemSSAVisitor<'a, 'input>,
+    result: &mut Vec<ExpandedStatement>,
+    statement: ExpandedStatement,
+) -> Result<(), TranslateError> {
+    let rewritten = match statement {
+        Statement::Variable(mut var) => {
+            visitor.visit_variable(&mut var)?;
+            result.push(Statement::Variable(var));
+            return Ok(());
+        }
+        Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
+            let load = visitor.visit_ld(data, arguments)?;
+            let load = ast::visit_map(load, visitor)?;
+            Statement::Instruction(load)
+        }
+        Statement::Instruction(ast::Instruction::St { data, arguments }) => {
+            let store = visitor.visit_st(data, arguments)?;
+            let store = ast::visit_map(store, visitor)?;
+            Statement::Instruction(store)
+        }
+        other => other.visit_map(visitor)?,
+    };
+    // Queued loads go in front of the statement, queued stores behind it.
+    result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+    result.push(rewritten);
+    result.extend(visitor.post.drain(..).map(Statement::Instruction));
+    Ok(())
+}
+
+/// Per-method state of the explicit load/store insertion pass.
+struct InsertMemSSAVisitor<'a, 'input> {
+ // Used to register the fresh, unnamed identifiers this pass creates.
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ // Maps the original name of a converted variable to the action that must be
+ // taken wherever that name is used.
+ variables: FxHashMap<SpirvWord, RemapAction>,
+ // Instructions queued while visiting a statement; `run_statement` drains
+ // them and emits them before that statement.
+ pre: Vec<ast::Instruction<SpirvWord>>,
+ // Instructions queued while visiting a statement; `run_statement` drains
+ // them and emits them after that statement.
+ post: Vec<ast::Instruction<SpirvWord>>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+    /// Creates a visitor with no recorded remappings and empty `pre`/`post` queues.
+    fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
+        Self {
+            resolver,
+            variables: FxHashMap::default(),
+            pre: Vec::new(),
+            post: Vec::new(),
+        }
+    }
+
+    /// Records that the input argument `old_name` is now `new_name` in the
+    /// `ParamEntry` state space, so `ld`/`st` through the old name can be
+    /// redirected.
+    ///
+    /// # Errors
+    ///
+    /// Returns `error_unreachable()` if `old_space` is not `Param`.
+    fn input_argument(
+        &mut self,
+        old_name: SpirvWord,
+        new_name: SpirvWord,
+        old_space: ast::StateSpace,
+    ) -> Result<(), TranslateError> {
+        if old_space != ast::StateSpace::Param {
+            return Err(error_unreachable());
+        }
+        self.variables.insert(
+            old_name,
+            RemapAction::LDStSpaceChange {
+                name: new_name,
+                old_space,
+                new_space: ast::StateSpace::ParamEntry,
+            },
+        );
+        Ok(())
+    }
+
+    /// Records how uses of the variable `old_name` must be rewritten.
+    ///
+    /// * `Reg`: uses are replaced by a temporary that is loaded from / stored
+    ///   to `new_name` (`RemapAction::PreLdPostSt`).
+    /// * `Param`: `ld`/`st` through the variable switch to the `Local` state
+    ///   space and to `new_name` (`RemapAction::LDStSpaceChange`).
+    /// * `Local` and the global-like spaces: nothing is recorded.
+    ///
+    /// # Errors
+    ///
+    /// Returns `error_unreachable()` for `ParamEntry` and `ParamFunc`.
+    fn variable(
+        &mut self,
+        type_: &ast::Type,
+        old_name: SpirvWord,
+        new_name: SpirvWord,
+        old_space: ast::StateSpace,
+    ) -> Result<(), TranslateError> {
+        Ok(match old_space {
+            ast::StateSpace::Reg => {
+                self.variables.insert(
+                    old_name,
+                    RemapAction::PreLdPostSt {
+                        name: new_name,
+                        type_: type_.clone(),
+                    },
+                );
+            }
+            ast::StateSpace::Param => {
+                self.variables.insert(
+                    old_name,
+                    RemapAction::LDStSpaceChange {
+                        old_space,
+                        new_space: ast::StateSpace::Local,
+                        name: new_name,
+                    },
+                );
+            }
+            // Good as-is
+            ast::StateSpace::Local => {}
+            // Will be pulled into global scope later
+            ast::StateSpace::Generic
+            | ast::StateSpace::SharedCluster
+            | ast::StateSpace::Global
+            | ast::StateSpace::Const
+            | ast::StateSpace::SharedCta
+            | ast::StateSpace::Shared => {}
+            ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
+                return Err(error_unreachable())
+            }
+        })
+    }
+
+    /// Fixes up a `st` whose `src1` operand is a remapped variable.
+    ///
+    /// For a `LDStSpaceChange` remap, the instruction's state space and `src1`
+    /// are replaced with the recorded ones. `PreLdPostSt` remaps and unknown
+    /// operands are left for the generic visitor.
+    ///
+    /// # Errors
+    ///
+    /// Returns `error_mismatched_type()` if the instruction's state space is
+    /// not the one the variable was originally declared in.
+    fn visit_st(
+        &self,
+        mut data: ast::StData,
+        mut arguments: ast::StArgs<SpirvWord>,
+    ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
+        if let Some(remap) = self.variables.get(&arguments.src1) {
+            match remap {
+                RemapAction::PreLdPostSt { .. } => {}
+                RemapAction::LDStSpaceChange {
+                    old_space,
+                    new_space,
+                    name,
+                } => {
+                    if data.state_space != *old_space {
+                        return Err(error_mismatched_type());
+                    }
+                    data.state_space = *new_space;
+                    arguments.src1 = *name;
+                }
+            }
+        }
+        Ok(ast::Instruction::St { data, arguments })
+    }
+
+    /// Fixes up a `ld` whose `src` operand is a remapped variable.
+    ///
+    /// Mirrors `visit_st`: a `LDStSpaceChange` remap replaces the instruction's
+    /// state space and `src`; everything else is left untouched.
+    ///
+    /// # Errors
+    ///
+    /// Returns `error_mismatched_type()` if the instruction's state space is
+    /// not the one the variable was originally declared in.
+    fn visit_ld(
+        &self,
+        mut data: ast::LdDetails,
+        mut arguments: ast::LdArgs<SpirvWord>,
+    ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
+        if let Some(remap) = self.variables.get(&arguments.src) {
+            match remap {
+                RemapAction::PreLdPostSt { .. } => {}
+                RemapAction::LDStSpaceChange {
+                    old_space,
+                    new_space,
+                    name,
+                } => {
+                    if data.state_space != *old_space {
+                        return Err(error_mismatched_type());
+                    }
+                    data.state_space = *new_space;
+                    arguments.src = *name;
+                }
+            }
+        }
+        Ok(ast::Instruction::Ld { data, arguments })
+    }
+
+    /// Converts a `.reg` or `.param` variable declaration into a freshly named
+    /// `.local` variable and records the remapping for its old name.
+    ///
+    /// `.local` variables are already in their final form. Variables in the
+    /// global-like state spaces are left untouched: no remapping is recorded
+    /// for them, so renaming them or moving them to `.local` would leave their
+    /// uses pointing at the old name and would hide them from the pass that
+    /// later hoists such variables to module level.
+    ///
+    /// # Errors
+    ///
+    /// Returns `error_unreachable()` for `ParamEntry` and `ParamFunc` variables.
+    fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
+        let old_space = match var.state_space {
+            space @ (ast::StateSpace::Reg | ast::StateSpace::Param) => space,
+            // Good as-is
+            ast::StateSpace::Local => return Ok(()),
+            // Will be pulled into global scope later
+            ast::StateSpace::Generic
+            | ast::StateSpace::SharedCluster
+            | ast::StateSpace::Global
+            | ast::StateSpace::Const
+            | ast::StateSpace::SharedCta
+            | ast::StateSpace::Shared => return Ok(()),
+            ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
+                return Err(error_unreachable())
+            }
+        };
+        let old_name = var.name;
+        let new_space = ast::StateSpace::Local;
+        let new_name = self
+            .resolver
+            .register_unnamed(Some((var.v_type.clone(), new_space)));
+        self.variable(&var.v_type, old_name, new_name, old_space)?;
+        var.name = new_name;
+        var.state_space = new_space;
+        Ok(())
+    }
+}
+
+impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
+    for InsertMemSSAVisitor<'a, 'input>
+{
+    /// Rewrites one identifier operand.
+    ///
+    /// Identifiers without a recorded remapping are returned unchanged. For a
+    /// `PreLdPostSt` remapping a fresh `.reg` temporary is returned instead:
+    /// when the operand is a destination, a `st` of the temporary into the
+    /// backing `.local` variable is queued in `post`; otherwise a `ld` from
+    /// that variable into the temporary is queued in `pre`. An identifier with
+    /// a `LDStSpaceChange` remapping is rejected with `error_mismatched_type()`.
+    fn visit(
+        &mut self,
+        ident: SpirvWord,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+        relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (backing, type_) = match self.variables.get(&ident) {
+            None => return Ok(ident),
+            Some(RemapAction::LDStSpaceChange { .. }) => return Err(error_mismatched_type()),
+            Some(RemapAction::PreLdPostSt { name, type_ }) => (*name, type_),
+        };
+        let temp = self
+            .resolver
+            .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
+        if is_dst {
+            // Written operand: store the temporary back after the instruction.
+            self.post.push(ast::Instruction::St {
+                data: ast::StData {
+                    state_space: ast::StateSpace::Local,
+                    qualifier: ast::LdStQualifier::Weak,
+                    caching: ast::StCacheOperator::Writethrough,
+                    typ: type_.clone(),
+                },
+                arguments: ast::StArgs {
+                    src1: backing,
+                    src2: temp,
+                },
+            });
+        } else {
+            // Read operand: load into the temporary before the instruction.
+            self.pre.push(ast::Instruction::Ld {
+                data: ast::LdDetails {
+                    state_space: ast::StateSpace::Local,
+                    qualifier: ast::LdStQualifier::Weak,
+                    caching: ast::LdCacheOperator::Cached,
+                    typ: type_.clone(),
+                    non_coherent: false,
+                },
+                arguments: ast::LdArgs {
+                    dst: temp,
+                    src: backing,
+                },
+            });
+        }
+        Ok(temp)
+    }
+
+    /// Identifiers are handled exactly like operands; forwards to `visit`.
+    fn visit_ident(
+        &mut self,
+        args: SpirvWord,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+        relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        self.visit(args, type_space, is_dst, relaxed_type_check)
+    }
+}
+
+/// What has to happen wherever the original name of a converted variable is used.
+#[derive(Clone)]
+enum RemapAction {
+ // Replace the use with a temporary register of type `type_`: loaded from
+ // `name` before the instruction when read, stored to `name` after the
+ // instruction when written.
+ PreLdPostSt {
+ name: SpirvWord,
+ type_: ast::Type,
+ },
+ // Only `ld`/`st` may address the variable: the instruction must currently
+ // use `old_space`, and is rewritten to use `new_space` and `name`.
+ LDStSpaceChange {
+ old_space: ast::StateSpace,
+ new_space: ast::StateSpace,
+ name: SpirvWord,
+ },
+}
diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs
index 25e80f0..c04fa09 100644
--- a/ptx/src/pass/insert_implicit_conversions.rs
+++ b/ptx/src/pass/insert_implicit_conversions.rs
@@ -45,6 +45,13 @@ pub(super) fn run(
Statement::RepackVector(repack),
)?;
}
+ Statement::VectorAccess(vector_access) => {
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ Statement::VectorAccess(vector_access),
+ )?;
+ }
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
@@ -128,7 +135,7 @@ pub(crate) fn default_implicit_conversion(
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg {
- if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+ if operand_space == ast::StateSpace::Reg {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type)
{
@@ -142,7 +149,7 @@ pub(crate) fn default_implicit_conversion(
return Ok(Some(ConversionKind::AddressOf));
}
}
- if !space_is_compatible(instruction_space, operand_space) {
+ if instruction_space != operand_space {
default_implicit_conversion_space(
(operand_space, operand_type),
(instruction_space, instruction_type),
@@ -161,7 +168,7 @@ fn is_addressable(this: ast::StateSpace) -> bool {
| ast::StateSpace::Global
| ast::StateSpace::Local
| ast::StateSpace::Shared => true,
- ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
+ ast::StateSpace::Param | ast::StateSpace::Reg => false,
ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
@@ -178,7 +185,7 @@ fn default_implicit_conversion_space(
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
- } else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+ } else if operand_space == ast::StateSpace::Reg {
match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space =>
@@ -210,7 +217,7 @@ fn default_implicit_conversion_space(
},
_ => Err(error_mismatched_type()),
}
- } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
+ } else if instruction_space == ast::StateSpace::Reg {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
@@ -234,7 +241,7 @@ fn default_implicit_conversion_type(
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
- if space_is_compatible(space, ast::StateSpace::Reg) {
+ if space == ast::StateSpace::Reg {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
@@ -257,8 +264,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool {
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
- | ast::StateSpace::Generic
- | ast::StateSpace::Sreg => false,
+ | ast::StateSpace::Generic => false,
}
}
@@ -294,7 +300,7 @@ pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if !space_is_compatible(operand_space, instruction_space) {
+ if operand_space != instruction_space {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
@@ -371,7 +377,7 @@ pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if !space_is_compatible(operand_space, instruction_space) {
+ if operand_space != instruction_space {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs
new file mode 100644
index 0000000..4f738f5
--- /dev/null
+++ b/ptx/src/pass/insert_implicit_conversions2.rs
@@ -0,0 +1,426 @@
+use std::mem;
+
+use super::*;
+use ptx_parser as ast;
+
+/*
+    PTX has several kinds of implicit conversions:
+    * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+    * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+      - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+        the semantics are to first zext/chop/bitcast `y` as needed and then
+        apply the documented special ld/st/cvt conversion rules for destination
+        operands
+      - st.param [x] y (used as function return arguments): the same rule as
+        above applies
+      - generic/global ld: for instruction `ld x, [y]`, y must be of type
+        b64/u64/s64, which is bitcast to a pointer and dereferenced, after
+        which the documented special ld/st/cvt conversion rules are applied
+        to dst
+      - generic/global st: for instruction `st [x], y`, x must be of type
+        b64/u64/s64, which is bitcast to a pointer
+*/
+pub(super) fn run<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+    let mut converted = Vec::new();
+    for directive in directives {
+        // Stop at the first directive that fails to convert.
+        converted.push(run_directive(resolver, directive)?);
+    }
+    Ok(converted)
+}
+
+/// Inserts implicit conversions into the body of a method, if it has one.
+/// Module-level variables are returned unchanged.
+fn run_directive<'a, 'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+    match directive {
+        Directive2::Method(mut method) => {
+            method.body = match method.body {
+                Some(statements) => Some(run_statements(resolver, statements)?),
+                None => None,
+            };
+            Ok(Directive2::Method(method))
+        }
+        variable @ Directive2::Variable(..) => Ok(variable),
+    }
+}
+
+/// Processes the statements of one method body in order, returning the body
+/// with all required conversion statements interleaved.
+fn run_statements<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    func: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+    let statement_count = func.len();
+    let mut converted = Vec::with_capacity(statement_count);
+    for statement in func {
+        insert_implicit_conversions_impl(resolver, &mut converted, statement)?;
+    }
+    Ok(converted)
+}
+
+/// Visits every typed operand of `stmt` and, where the operand's declared
+/// type/state space differs from what the instruction expects, routes it
+/// through a fresh temporary plus an explicit `Statement::Conversion`.
+///
+/// Conversions of source operands are pushed to `func` before the statement;
+/// conversions of destination operands are pushed after it. Operands visited
+/// without type information are left untouched.
+fn insert_implicit_conversions_impl<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    func: &mut Vec<ExpandedStatement>,
+    stmt: ExpandedStatement,
+) -> Result<(), TranslateError> {
+    let mut post_conv = Vec::new();
+    let statement = stmt.visit_map::<SpirvWord, TranslateError>(
+        &mut |operand,
+              type_state: Option<(&ast::Type, ast::StateSpace)>,
+              is_dst,
+              relaxed_type_check| {
+            let (instr_type, instruction_space) = match type_state {
+                Some(typed) => typed,
+                None => return Ok(operand),
+            };
+            let (operand_type, operand_space) = resolver.get_typed(operand)?;
+            // Pick the rule set that applies to this operand.
+            let conversion_fn = match (relaxed_type_check, is_dst) {
+                (true, true) => should_convert_relaxed_dst_wrapper,
+                (true, false) => should_convert_relaxed_src_wrapper,
+                (false, _) => default_implicit_conversion,
+            };
+            let conv_kind = match conversion_fn(
+                (*operand_space, &operand_type),
+                (instruction_space, instr_type),
+            )? {
+                Some(kind) => kind,
+                None => return Ok(operand),
+            };
+            let declared_type = operand_type.clone();
+            let declared_space = *operand_space;
+            // The temporary has the type and space the instruction expects and
+            // takes the operand's place in the instruction.
+            let temp = resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
+            if is_dst {
+                // Instruction writes `temp`; afterwards convert it into the operand.
+                post_conv.push(Statement::Conversion(ImplicitConversion {
+                    src: temp,
+                    dst: operand,
+                    from_type: instr_type.clone(),
+                    from_space: instruction_space,
+                    to_type: declared_type,
+                    to_space: declared_space,
+                    kind: conv_kind,
+                }));
+            } else {
+                // Convert the operand into `temp` before the instruction reads it.
+                func.push(Statement::Conversion(ImplicitConversion {
+                    src: operand,
+                    dst: temp,
+                    from_type: declared_type,
+                    from_space: declared_space,
+                    to_type: instr_type.clone(),
+                    to_space: instruction_space,
+                    kind: conv_kind,
+                }));
+            }
+            Ok::<_, TranslateError>(temp)
+        },
+    )?;
+    func.push(statement);
+    func.append(&mut post_conv);
+    Ok(())
+}
+
+/// Decides which conversion, if any, is needed for an operand under the
+/// default (non-relaxed) type checking rules.
+///
+/// When the instruction expects a register operand:
+/// * a register vector passed where a bit-typed scalar of the same total size
+///   is expected is converted with `ConversionKind::Default`;
+/// * an operand living in an addressable state space is converted with
+///   `ConversionKind::AddressOf`.
+///
+/// Otherwise a state space mismatch is resolved by
+/// `default_implicit_conversion_space`, a type mismatch by
+/// `default_implicit_conversion_type`, and identical space and type need no
+/// conversion.
+pub(crate) fn default_implicit_conversion(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    let wants_register = instruction_space == ast::StateSpace::Reg;
+    if wants_register && operand_space == ast::StateSpace::Reg {
+        if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
+            (operand_type, instruction_type)
+        {
+            let same_size = scalar.size_of() == (vec_underlying_type.size_of() * vec_len);
+            if scalar.kind() == ast::ScalarKind::Bit && same_size {
+                return Ok(Some(ConversionKind::Default));
+            }
+        }
+    } else if wants_register && is_addressable(operand_space) {
+        return Ok(Some(ConversionKind::AddressOf));
+    }
+    if instruction_space != operand_space {
+        return default_implicit_conversion_space(
+            (operand_space, operand_type),
+            (instruction_space, instruction_type),
+        );
+    }
+    if instruction_type != operand_type {
+        return default_implicit_conversion_type(instruction_space, operand_type, instruction_type);
+    }
+    Ok(None)
+}
+
+/// Tells whether a variable in state space `this` can have its address taken
+/// when it is used as a register operand.
+///
+/// # Panics
+///
+/// Not implemented (`todo!()`) for `SharedCluster`, `SharedCta`, `ParamEntry`
+/// and `ParamFunc`.
+fn is_addressable(this: ast::StateSpace) -> bool {
+    match this {
+        ast::StateSpace::Reg | ast::StateSpace::Param => false,
+        ast::StateSpace::Global
+        | ast::StateSpace::Generic
+        | ast::StateSpace::Const
+        | ast::StateSpace::Shared
+        | ast::StateSpace::Local => true,
+        ast::StateSpace::ParamEntry
+        | ast::StateSpace::ParamFunc
+        | ast::StateSpace::SharedCta
+        | ast::StateSpace::SharedCluster => todo!(),
+    }
+}
+
+// Operand and instruction are in different state spaces.
+//
+// * generic <-> any space that coerces to generic: pointer-to-pointer cast
+// * register operand used as an address:
+//   - a pointer into the instruction's space needs a pointer cast only if the
+//     pointee type differs
+//   - a 64-bit integer can address global/generic/const/local/shared
+//   - a 32-bit integer can address const/local/shared
+// * instruction expecting a register pointer into the operand's space: pointer
+//   cast only if the pointee type differs
+// Everything else is a type mismatch.
+fn default_implicit_conversion_space(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    let into_generic =
+        instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space);
+    if into_generic
+        || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
+    {
+        return Ok(Some(ConversionKind::PtrToPtr));
+    }
+    if operand_space == ast::StateSpace::Reg {
+        return match operand_type {
+            ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+                if *operand_ptr_space == instruction_space =>
+            {
+                if instruction_type == &ast::Type::Scalar(*operand_ptr_type) {
+                    Ok(None)
+                } else {
+                    Ok(Some(ConversionKind::PtrToPtr))
+                }
+            }
+            // TODO: 32 bit
+            ast::Type::Scalar(
+                ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64,
+            ) => match instruction_space {
+                ast::StateSpace::Global
+                | ast::StateSpace::Generic
+                | ast::StateSpace::Const
+                | ast::StateSpace::Local
+                | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+                _ => Err(error_mismatched_type()),
+            },
+            ast::Type::Scalar(
+                ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32,
+            ) => match instruction_space {
+                ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+                    Ok(Some(ConversionKind::BitToPtr))
+                }
+                _ => Err(error_mismatched_type()),
+            },
+            _ => Err(error_mismatched_type()),
+        };
+    }
+    if instruction_space == ast::StateSpace::Reg {
+        return match instruction_type {
+            ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+                if operand_space == *instruction_ptr_space =>
+            {
+                if operand_type == &ast::Type::Scalar(*instruction_ptr_type) {
+                    Ok(None)
+                } else {
+                    Ok(Some(ConversionKind::PtrToPtr))
+                }
+            }
+            _ => Err(error_mismatched_type()),
+        };
+    }
+    Err(error_mismatched_type())
+}
+
+// Operand and instruction share a state space but disagree on the type.
+// Outside of registers this is always a pointer cast; in registers it must be
+// a permitted bitcast, otherwise it is a type mismatch.
+fn default_implicit_conversion_type(
+    space: ast::StateSpace,
+    operand_type: &ast::Type,
+    instruction_type: &ast::Type,
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if space != ast::StateSpace::Reg {
+        return Ok(Some(ConversionKind::PtrToPtr));
+    }
+    if !should_bitcast(instruction_type, operand_type) {
+        return Err(TranslateError::MismatchedType);
+    }
+    Ok(Some(ConversionKind::Default))
+}
+
+/// Tells whether a pointer into state space `this` may be implicitly cast to
+/// or from a generic pointer. `Generic` itself does not count.
+fn coerces_to_generic(this: ast::StateSpace) -> bool {
+    match this {
+        ast::StateSpace::Generic
+        | ast::StateSpace::Reg
+        | ast::StateSpace::Param
+        | ast::StateSpace::ParamEntry
+        | ast::StateSpace::ParamFunc => false,
+        ast::StateSpace::Global
+        | ast::StateSpace::Const
+        | ast::StateSpace::Local
+        | ast::StateSpace::Shared
+        | ast::StateSpace::SharedCta
+        | ast::StateSpace::SharedCluster => true,
+    }
+}
+
+/// Tells whether an operand of type `operand` may be bitcast to the type
+/// `instr` expected by the instruction.
+///
+/// Scalars must have equal size. A bit-typed instruction accepts any
+/// non-bit operand; a float instruction accepts only bit operands; a signed
+/// instruction accepts bit or unsigned operands; an unsigned instruction
+/// accepts bit or signed operands; predicates are never bitcast. Vectors and
+/// arrays are judged by their element types only.
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
+    match (instr, operand) {
+        (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
+            let operand_kind = operand.kind();
+            inst.size_of() == operand.size_of()
+                && match inst.kind() {
+                    ast::ScalarKind::Pred => false,
+                    ast::ScalarKind::Bit => operand_kind != ast::ScalarKind::Bit,
+                    ast::ScalarKind::Float => operand_kind == ast::ScalarKind::Bit,
+                    ast::ScalarKind::Signed => matches!(
+                        operand_kind,
+                        ast::ScalarKind::Bit | ast::ScalarKind::Unsigned
+                    ),
+                    ast::ScalarKind::Unsigned => matches!(
+                        operand_kind,
+                        ast::ScalarKind::Bit | ast::ScalarKind::Signed
+                    ),
+                }
+        }
+        (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
+        | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
+            let inst_element = ast::Type::Scalar(*inst);
+            let operand_element = ast::Type::Scalar(*operand);
+            should_bitcast(&inst_element, &operand_element)
+        }
+        _ => false,
+    }
+}
+
+/// Relaxed type checking for a destination operand.
+///
+/// The state spaces must match. Identical types need no conversion; otherwise
+/// the conversion chosen by `should_convert_relaxed_dst` is returned, and a
+/// pair of types it rejects is reported as `TranslateError::MismatchedType`.
+pub(crate) fn should_convert_relaxed_dst_wrapper(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if operand_space != instruction_space {
+        return Err(TranslateError::MismatchedType);
+    }
+    if operand_type == instruction_type {
+        return Ok(None);
+    }
+    should_convert_relaxed_dst(operand_type, instruction_type)
+        .map(Some)
+        .ok_or(TranslateError::MismatchedType)
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+//
+// Returns the conversion needed to write a value of the instruction type into
+// a destination of type `dst_type`, or `None` when the types are identical or
+// the combination is not allowed. Vectors and arrays are judged by their
+// element types only.
+fn should_convert_relaxed_dst(
+    dst_type: &ast::Type,
+    instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+    if dst_type == instr_type {
+        return None;
+    }
+    match (dst_type, instr_type) {
+        (ast::Type::Scalar(dst_scalar), ast::Type::Scalar(instr_scalar)) => {
+            let dst_size = dst_scalar.size_of();
+            let instr_size = instr_scalar.size_of();
+            let fits = instr_size <= dst_size;
+            match instr_scalar.kind() {
+                ast::ScalarKind::Pred => None,
+                ast::ScalarKind::Bit => fits.then_some(ConversionKind::Default),
+                ast::ScalarKind::Signed => {
+                    if dst_scalar.kind() == ast::ScalarKind::Float {
+                        None
+                    } else if instr_size == dst_size {
+                        Some(ConversionKind::Default)
+                    } else if instr_size < dst_size {
+                        // Narrower signed results are sign-extended to the destination.
+                        Some(ConversionKind::SignExtend)
+                    } else {
+                        None
+                    }
+                }
+                ast::ScalarKind::Unsigned => (fits
+                    && dst_scalar.kind() != ast::ScalarKind::Float)
+                    .then_some(ConversionKind::Default),
+                ast::ScalarKind::Float => (fits && dst_scalar.kind() == ast::ScalarKind::Bit)
+                    .then_some(ConversionKind::Default),
+            }
+        }
+        (ast::Type::Vector(_, dst_element), ast::Type::Vector(_, instr_element))
+        | (ast::Type::Array(_, dst_element, _), ast::Type::Array(_, instr_element, _)) => {
+            let dst_element = ast::Type::Scalar(*dst_element);
+            let instr_element = ast::Type::Scalar(*instr_element);
+            should_convert_relaxed_dst(&dst_element, &instr_element)
+        }
+        _ => None,
+    }
+}
+
+/// Relaxed type checking for a source operand.
+///
+/// The state spaces must match. Identical types need no conversion; otherwise
+/// the conversion chosen by `should_convert_relaxed_src` is returned, and a
+/// pair of types it rejects is reported via `error_mismatched_type()`.
+pub(crate) fn should_convert_relaxed_src_wrapper(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if operand_space != instruction_space {
+        return Err(error_mismatched_type());
+    }
+    if operand_type == instruction_type {
+        return Ok(None);
+    }
+    // Build the error lazily so it is only created on an actual mismatch.
+    should_convert_relaxed_src(operand_type, instruction_type)
+        .map(Some)
+        .ok_or_else(error_mismatched_type)
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+//
+// Returns the conversion needed to read a source of type `src_type` as the
+// instruction type, or `None` when the types are identical or the combination
+// is not allowed. Vectors and arrays are judged by their element types only.
+fn should_convert_relaxed_src(
+    src_type: &ast::Type,
+    instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+    if src_type == instr_type {
+        return None;
+    }
+    match (src_type, instr_type) {
+        (ast::Type::Scalar(src_scalar), ast::Type::Scalar(instr_scalar)) => {
+            // The source must be at least as wide as the instruction type.
+            let fits = instr_scalar.size_of() <= src_scalar.size_of();
+            let allowed = match instr_scalar.kind() {
+                ast::ScalarKind::Pred => false,
+                ast::ScalarKind::Bit => fits,
+                ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
+                    fits && src_scalar.kind() != ast::ScalarKind::Float
+                }
+                ast::ScalarKind::Float => fits && src_scalar.kind() == ast::ScalarKind::Bit,
+            };
+            allowed.then_some(ConversionKind::Default)
+        }
+        (ast::Type::Vector(_, src_element), ast::Type::Vector(_, instr_element))
+        | (ast::Type::Array(_, src_element, _), ast::Type::Array(_, instr_element, _)) => {
+            let src_element = ast::Type::Scalar(*src_element);
+            let instr_element = ast::Type::Scalar(*instr_element);
+            should_convert_relaxed_src(&src_element, &instr_element)
+        }
+        _ => None,
+    }
+}
diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs
index e314b05..150109b 100644
--- a/ptx/src/pass/insert_mem_ssa_statements.rs
+++ b/ptx/src/pass/insert_mem_ssa_statements.rs
@@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
return Ok(symbol);
};
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
- if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
+ if var_space != ast::StateSpace::Reg || !is_variable {
return Ok(symbol);
};
let member_index = match member_index {
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index 3aa3b0a..0e233ed 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -1,5 +1,6 @@
use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
+use rustc_hash::FxHashMap;
use std::hash::Hash;
use std::num::NonZeroU8;
use std::{
@@ -12,20 +13,31 @@ use std::{
mem,
rc::Rc,
};
+use strum::IntoEnumIterator;
+use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
+mod deparamize_functions;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
+mod expand_operands;
mod extract_globals;
mod fix_special_registers;
+mod fix_special_registers2;
+mod hoist_globals;
+mod insert_explicit_load_store;
mod insert_implicit_conversions;
+mod insert_implicit_conversions2;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
+mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
+mod normalize_predicates2;
+mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
})?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
- let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
+ todo!()
+ /*
+ let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
+ Ok(Module {
+ llvm_ir,
+ kernel_info: HashMap::new(),
+ }) */
+}
+
+pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
+ let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
+ let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
+ let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
+ let directives = resolve_function_pointers::run(directives)?;
+ let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
+ expand_operands::run(&mut flat_resolver, directives)?;
+ let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+ let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
+ let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
+ let directives = hoist_globals::run(directives)?;
+ let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
@@ -319,7 +354,7 @@ pub struct KernelInfo {
pub uses_shared_mem: bool,
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister {
Tid,
Ntid,
@@ -342,6 +377,17 @@ impl PtxSpecialRegister {
}
}
+ /// Returns the PTX source spelling of this special register, including the
+ /// leading `%`.
+ ///
+ /// `SpecialRegistersMap2::new` registers this text as the identifier name
+ /// of each special register.
+ fn as_str(self) -> &'static str {
+ match self {
+ Self::Tid => "%tid",
+ Self::Ntid => "%ntid",
+ Self::Ctaid => "%ctaid",
+ Self::Nctaid => "%nctaid",
+ Self::Clock => "%clock",
+ Self::LanemaskLt => "%lanemask_lt",
+ }
+ }
+
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
@@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -722,6 +768,7 @@ enum Statement<I, P: ast::Operand> {
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
+ VectorAccess(VectorAccess),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@@ -890,6 +937,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
offset_src,
})
}
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src: vector_src,
+ member,
+ }) => {
+ let dst: SpirvWord = visitor.visit_ident(
+ dst,
+ Some((&scalar_type.into(), ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ vector_src,
+ Some((
+ &ast::Type::Vector(vector_width, scalar_type),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src,
+ member,
+ })
+ }
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
@@ -1207,12 +1284,6 @@ impl<
}
}
-fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
- this == other
- || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
- || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
-}
-
fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
@@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>(
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
}
}
}
+
+/// A top-level module item after identifiers have been replaced by
+/// `SpirvWord` ids.
+///
+/// `Instruction` and `Operand` describe the statement representation used
+/// inside function bodies; they change as the passes progress.
+pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
+ /// A module-level variable together with its linking directive.
+ Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
+ /// A function or kernel.
+ Method(Function2<'input, Instruction, Operand>),
+}
+
+/// A function or kernel as carried between the passes of the new pipeline.
+pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
+ /// Declaration (name, return and input arguments) with numeric ids.
+ pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
+ /// Variables attached to this function; `normalize_identifiers2` creates
+ /// it empty.
+ pub globals: Vec<ast::Variable<SpirvWord>>,
+ /// Statements of the body; `None` when the parsed method had no body
+ /// (presumably a declaration only -- TODO confirm).
+ pub body: Option<Vec<Statement<Instruction, Operand>>>,
+ /// `normalize_identifiers2` sets this to `None`; the passes shown here
+ /// forward it unchanged.
+ import_as: Option<String>,
+ /// Tuning directives copied from the parsed method.
+ tuning: Vec<ast::TuningDirective>,
+ /// Linking directive copied from the parsed directive.
+ linkage: ast::LinkingDirective,
+}
+
+/// Directive as produced by `normalize_identifiers2::run`: every instruction
+/// is paired with its optional predicate and operands are still
+/// `ParsedOperand`s.
+type NormalizedDirective2<'input> = Directive2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+/// Function counterpart of `NormalizedDirective2`: instructions still carry
+/// their optional predicate.
+type NormalizedFunction2<'input> = Function2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+/// Directive as produced by `normalize_predicates2::run`: instructions no
+/// longer carry a predicate.
+type UnconditionalDirective<'input> = Directive2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+/// Function counterpart of `UnconditionalDirective`.
+type UnconditionalFunction<'input> = Function2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+/// Flat (scope-less) table of every identifier known so far, keyed by its
+/// numeric id.
+struct GlobalStringIdentResolver2<'input> {
+ /// Next id that will be handed out.
+ pub(crate) current_id: SpirvWord,
+ /// Name and optional type/state-space for each registered id.
+ pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+}
+
+impl<'input> GlobalStringIdentResolver2<'input> {
+    /// Creates an empty resolver whose first allocated id is `spirv_word`.
+    fn new(spirv_word: SpirvWord) -> Self {
+        GlobalStringIdentResolver2 {
+            ident_map: FxHashMap::default(),
+            current_id: spirv_word,
+        }
+    }
+
+    /// Allocates a fresh id for an identifier that has a source-level name.
+    fn register_named(
+        &mut self,
+        name: Cow<'input, str>,
+        type_space: Option<(ast::Type, ast::StateSpace)>,
+    ) -> SpirvWord {
+        self.register_entry(IdentEntry {
+            name: Some(name),
+            type_space,
+        })
+    }
+
+    /// Allocates a fresh id for a compiler-generated, nameless identifier.
+    fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
+        self.register_entry(IdentEntry {
+            name: None,
+            type_space,
+        })
+    }
+
+    /// Stores `entry` under the next free id, advances the counter and
+    /// returns the id that was used.
+    fn register_entry(&mut self, entry: IdentEntry<'input>) -> SpirvWord {
+        let allocated = self.current_id;
+        self.ident_map.insert(allocated, entry);
+        self.current_id.0 += 1;
+        allocated
+    }
+
+    /// Looks up the type and state space recorded for `id`.
+    ///
+    /// Fails with the unknown-symbol error when `id` is not registered or
+    /// was registered without type information.
+    fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
+        self.ident_map
+            .get(&id)
+            .and_then(|entry| entry.type_space.as_ref())
+            .ok_or_else(|| error_unknown_symbol())
+    }
+}
+
+/// What is recorded about a single identifier.
+struct IdentEntry<'input> {
+ /// Name of the identifier; `None` for ids created by `register_unnamed`.
+ name: Option<Cow<'input, str>>,
+ /// Type and state space, when the identifier has them (function names and
+ /// labels are registered with `None`).
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+}
+
+/// Name resolver with lexical scopes layered on top of a flat resolver.
+struct ScopedResolver<'input, 'b> {
+ /// Supplies fresh ids and receives the entries of each scope when it ends.
+ flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
+ /// Stack of open scopes; the last element is the innermost one.
+ scopes: Vec<ScopeMarker<'input>>,
+}
+
+impl<'input, 'b> ScopedResolver<'input, 'b> {
+    /// Wraps `flat_resolver` and opens the outermost scope.
+    ///
+    /// Entries added to that outermost scope reach the flat resolver only if
+    /// the caller closes it with an extra `end_scope` call.
+    fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
+        Self {
+            flat_resolver,
+            scopes: vec![ScopeMarker::new()],
+        }
+    }
+
+    /// Opens a new innermost scope.
+    fn start_scope(&mut self) {
+        self.scopes.push(ScopeMarker::new());
+    }
+
+    /// Closes the innermost scope and moves its entries into the flat
+    /// resolver; the names declared in it stop being resolvable.
+    ///
+    /// Panics if no scope is open.
+    fn end_scope(&mut self) {
+        let scope = self.scopes.pop().unwrap();
+        scope.flush(self.flat_resolver);
+    }
+
+    /// Declares `name` in the innermost scope and returns its fresh id.
+    ///
+    /// Returns an error when `name` is already declared in the innermost
+    /// scope; shadowing a name from an outer scope is allowed.
+    /// NOTE(review): the redefinition is reported through
+    /// `error_unknown_symbol()`; no dedicated error constructor is visible
+    /// from here -- confirm whether one exists.
+    fn add(
+        &mut self,
+        name: Cow<'input, str>,
+        type_space: Option<(ast::Type, ast::StateSpace)>,
+    ) -> Result<SpirvWord, TranslateError> {
+        // The id is taken before the duplicate check, so a failed `add`
+        // still consumes one id.
+        let result = self.flat_resolver.current_id;
+        self.flat_resolver.current_id.0 += 1;
+        let current_scope = self.scopes.last_mut().unwrap();
+        if current_scope
+            .name_to_ident
+            .insert(name.clone(), result)
+            .is_some()
+        {
+            return Err(error_unknown_symbol());
+        }
+        current_scope.ident_map.insert(
+            result,
+            IdentEntry {
+                name: Some(name),
+                type_space,
+            },
+        );
+        Ok(result)
+    }
+
+    /// Resolves `name`, searching from the innermost scope outwards.
+    ///
+    /// A name that is not declared in any open scope is an error in the
+    /// input (instruction operands and predicate names are resolved through
+    /// this method), so it is reported as an unknown symbol rather than as
+    /// an unreachable internal state.
+    fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
+        self.scopes
+            .iter()
+            .rev()
+            .find_map(|resolver| resolver.name_to_ident.get(name).copied())
+            .ok_or_else(|| error_unknown_symbol())
+    }
+
+    /// Resolves `label` in the innermost scope only.
+    ///
+    /// The visible caller registers every label of a statement list before
+    /// calling this, which is why a miss is treated as unreachable.
+    fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
+        let current_scope = self.scopes.last().unwrap();
+        current_scope
+            .name_to_ident
+            .get(label)
+            .copied()
+            .ok_or_else(|| error_unreachable())
+    }
+}
+
+/// Identifiers declared inside one lexical scope.
+struct ScopeMarker<'input> {
+ /// Entries that are moved into the flat resolver when the scope ends.
+ ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+ /// Name lookup table; dropped when the scope ends.
+ name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
+}
+
+impl<'input> ScopeMarker<'input> {
+    /// Creates a scope with no declarations.
+    fn new() -> Self {
+        ScopeMarker {
+            name_to_ident: FxHashMap::default(),
+            ident_map: FxHashMap::default(),
+        }
+    }
+
+    /// Consumes the scope, handing its id entries over to `resolver`. The
+    /// name lookup table is discarded.
+    fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
+        let ScopeMarker { ident_map, .. } = self;
+        resolver.ident_map.extend(ident_map);
+    }
+}
+
+/// Two-way mapping between PTX special registers and their numeric ids.
+struct SpecialRegistersMap2 {
+ /// Special register -> id.
+ reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
+ /// Id -> special register.
+ id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap2 {
+ /// Declares every `PtxSpecialRegister` in the innermost scope of `resolver`
+ /// under its `as_str()` name, typed with `get_type()` in the `Reg` state
+ /// space, and records the id of each in both maps.
+ ///
+ /// Fails if `resolver.add` rejects one of the names.
+ fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
+ let mut result = SpecialRegistersMap2 {
+ reg_to_id: FxHashMap::default(),
+ id_to_reg: FxHashMap::default(),
+ };
+ for sreg in PtxSpecialRegister::iter() {
+ let text = sreg.as_str();
+ let id = resolver.add(
+ Cow::Borrowed(text),
+ Some((sreg.get_type(), ast::StateSpace::Reg)),
+ )?;
+ result.reg_to_id.insert(sreg, id);
+ result.id_to_reg.insert(id, sreg);
+ }
+ Ok(result)
+ }
+
+ /// Returns the special register that `id` stands for, if any.
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ /// Returns the id of `reg`, allocating one from `current_id` (and
+ /// advancing it) when `reg` has none yet.
+ ///
+ /// A newly allocated id is recorded only in this map, not in any resolver.
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+
+ /// Builds, for every special register, the declaration of the external
+ /// function named `ZLUDA_PTX_PREFIX` + `get_unprefixed_function_name()`.
+ ///
+ /// Each declaration has one return argument and one input argument, both
+ /// in the `Reg` state space, with freshly registered unnamed ids.
+ /// The returned iterator is lazy: ids are registered in `resolver` only
+ /// as items are consumed.
+ fn generate_declarations<'a, 'input>(
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ ) -> impl ExactSizeIterator<
+ Item = (
+ PtxSpecialRegister,
+ ast::MethodDeclaration<'input, SpirvWord>,
+ ),
+ > + 'a {
+ PtxSpecialRegister::iter().map(|sreg| {
+ let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ // The function name itself is registered without type information.
+ let name =
+ ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
+ let return_type = sreg.get_function_return_type();
+ // NOTE(review): the input type is taken from
+ // `get_function_return_type()`, same as `return_type` above. This
+ // looks like a copy-paste slip -- confirm whether a dedicated
+ // input-type accessor should be used instead.
+ let input_type = sreg.get_function_return_type();
+ (
+ sreg,
+ ast::MethodDeclaration {
+ return_arguments: vec![ast::Variable {
+ align: None,
+ v_type: return_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ name: name,
+ input_arguments: vec![ast::Variable {
+ align: None,
+ v_type: input_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ shared_mem: None,
+ },
+ )
+ })
+ }
+}
+
+/// Statement payload relating a scalar `dst` to a vector `src`.
+///
+/// The statement visitor treats `dst` as a `scalar_type` register that is
+/// written and `src` as a `vector_width` x `scalar_type` vector register
+/// that is read.
+pub struct VectorAccess {
+ /// Element type of the vector and type of `dst`.
+ scalar_type: ast::ScalarType,
+ /// Number of elements in the `src` vector.
+ vector_width: u8,
+ /// Scalar destination register.
+ dst: SpirvWord,
+ /// Vector source register.
+ src: SpirvWord,
+ /// Presumably the index of the accessed vector component -- TODO confirm
+ /// against the code that emits this statement.
+ member: u8,
+}
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs
new file mode 100644
index 0000000..beaf08b
--- /dev/null
+++ b/ptx/src/pass/normalize_identifiers2.rs
@@ -0,0 +1,199 @@
+use super::*;
+use ptx_parser as ast;
+use rustc_hash::FxHashMap;
+
+/// Replaces every textual identifier in `directives` with a numeric id.
+///
+/// All module-level names are declared inside one scope that is opened at
+/// the start and closed once every directive has been processed. The first
+/// failing directive aborts the pass.
+pub(crate) fn run<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
+) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
+    resolver.start_scope();
+    let mut normalized = Vec::with_capacity(directives.len());
+    for directive in directives {
+        normalized.push(run_directive(resolver, directive)?);
+    }
+    resolver.end_scope();
+    Ok(normalized)
+}
+
+/// Normalizes one top-level directive: a variable is declared in the
+/// current scope, a method is handed to `run_method`.
+fn run_directive<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
+) -> Result<NormalizedDirective2<'input>, TranslateError> {
+    match directive {
+        ast::Directive::Method(linking, method) => {
+            let method = run_method(resolver, linking, method)?;
+            Ok(NormalizedDirective2::Method(method))
+        }
+        ast::Directive::Variable(linking, var) => {
+            let var = run_variable(resolver, var)?;
+            Ok(NormalizedDirective2::Variable(linking, var))
+        }
+    }
+}
+
+/// Normalizes a function or kernel.
+///
+/// A function name is declared in the enclosing scope; a kernel name is
+/// kept as text. Arguments and body are processed inside a scope of their
+/// own, which is closed before returning.
+fn run_method<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    linkage: ast::LinkingDirective,
+    method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<NormalizedFunction2<'input>, TranslateError> {
+    let name = match method.func_directive.name {
+        ast::MethodName::Func(text) => {
+            let id = resolver.add(Cow::Borrowed(text), None)?;
+            ast::MethodName::Func(id)
+        }
+        ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+    };
+    resolver.start_scope();
+    let func_decl = run_function_decl(resolver, method.func_directive, name)?;
+    let body = match method.body {
+        Some(statements) => {
+            let mut normalized = Vec::with_capacity(statements.len());
+            run_statements(resolver, &mut normalized, statements)?;
+            Some(normalized)
+        }
+        None => None,
+    };
+    resolver.end_scope();
+    Ok(Function2 {
+        func_decl,
+        globals: Vec::new(),
+        body,
+        import_as: None,
+        tuning: method.tuning,
+        linkage,
+    })
+}
+
+/// Declares the return arguments and then the input arguments of a method
+/// in the current scope and rebuilds the declaration around `name`.
+///
+/// Panics if the parsed declaration carries `shared_mem`.
+fn run_function_decl<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    func_directive: ast::MethodDeclaration<'input, &'input str>,
+    name: ast::MethodName<'input, SpirvWord>,
+) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
+    assert!(func_directive.shared_mem.is_none());
+    let mut return_arguments = Vec::with_capacity(func_directive.return_arguments.len());
+    for var in func_directive.return_arguments {
+        return_arguments.push(run_variable(resolver, var)?);
+    }
+    let mut input_arguments = Vec::with_capacity(func_directive.input_arguments.len());
+    for var in func_directive.input_arguments {
+        input_arguments.push(run_variable(resolver, var)?);
+    }
+    Ok(ast::MethodDeclaration {
+        return_arguments,
+        name,
+        input_arguments,
+        shared_mem: None,
+    })
+}
+
+/// Declares `variable` in the current scope with its type and state space
+/// and returns the same variable carrying the new numeric id as its name.
+fn run_variable<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    variable: ast::Variable<&'input str>,
+) -> Result<ast::Variable<SpirvWord>, TranslateError> {
+    let ast::Variable {
+        name,
+        align,
+        v_type,
+        state_space,
+        array_init,
+    } = variable;
+    let id = resolver.add(Cow::Borrowed(name), Some((v_type.clone(), state_space)))?;
+    Ok(ast::Variable {
+        name: id,
+        align,
+        v_type,
+        state_space,
+        array_init,
+    })
+}
+
+/// Normalizes a statement list, appending the outcome to `result`.
+///
+/// Labels of this list are declared up front so that instructions can refer
+/// to a label that appears later. A nested block is processed inside a
+/// scope of its own and its statements are appended to the same `result`.
+fn run_statements<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    result: &mut Vec<NormalizedStatement>,
+    statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<(), TranslateError> {
+    // First pass: declare every label that sits directly in this list.
+    for statement in statements.iter() {
+        if let ast::Statement::Label(label) = statement {
+            resolver.add(Cow::Borrowed(*label), None)?;
+        }
+    }
+    // Second pass: translate the statements in order.
+    for statement in statements {
+        match statement {
+            ast::Statement::Block(block) => {
+                resolver.start_scope();
+                run_statements(resolver, result, block)?;
+                resolver.end_scope();
+            }
+            ast::Statement::Instruction(predicate, instruction) => {
+                let predicate = match predicate {
+                    Some(pred) => Some(ast::PredAt {
+                        not: pred.not,
+                        label: resolver.get(pred.label)?,
+                    }),
+                    None => None,
+                };
+                let instruction = run_instruction(resolver, instruction)?;
+                result.push(Statement::Instruction((predicate, instruction)));
+            }
+            ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
+            ast::Statement::Label(label) => {
+                let id = resolver.get_in_current_scope(label)?;
+                result.push(Statement::Label(id));
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Rewrites every identifier inside `instruction` to its numeric id.
+///
+/// Each name is resolved with `resolver.get`; the other arguments that
+/// `visit_map` passes to the callback are ignored.
+fn run_instruction<'input, 'b>(
+ resolver: &mut ScopedResolver<'input, 'b>,
+ instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
+) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
+ ast::visit_map(instruction, &mut |name: &'input str,
+ _: Option<(
+ &ast::Type,
+ ast::StateSpace,
+ )>,
+ _,
+ _| {
+ resolver.get(&name)
+ })
+}
+
+/// Declares a local variable statement and appends the resulting
+/// `Statement::Variable`s to `result`.
+///
+/// With a count `n`, the declaration expands into `n` variables named
+/// `<name>0` .. `<name>{n-1}`; without one, a single variable keeps the
+/// original name. Every expanded variable shares the type, state space,
+/// alignment and initializer of the declaration.
+fn run_multivariable<'input, 'b>(
+    resolver: &mut ScopedResolver<'input, 'b>,
+    result: &mut Vec<NormalizedStatement>,
+    variable: ast::MultiVariable<&'input str>,
+) -> Result<(), TranslateError> {
+    let names: Vec<Cow<'input, str>> = match variable.count {
+        Some(count) => (0..count)
+            .map(|i| Cow::Owned(format!("{}{}", variable.var.name, i)))
+            .collect(),
+        None => vec![Cow::Borrowed(variable.var.name)],
+    };
+    for name in names {
+        let ident = resolver.add(
+            name,
+            Some((variable.var.v_type.clone(), variable.var.state_space)),
+        )?;
+        result.push(Statement::Variable(ast::Variable {
+            align: variable.var.align,
+            v_type: variable.var.v_type.clone(),
+            state_space: variable.var.state_space,
+            name: ident,
+            array_init: variable.var.array_init.clone(),
+        }));
+    }
+    Ok(())
+}
diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs
index 097d87c..037e918 100644
--- a/ptx/src/pass/normalize_labels.rs
+++ b/ptx/src/pass/normalize_labels.rs
@@ -26,6 +26,7 @@ pub(super) fn run(
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
+ | Statement::VectorAccess { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}
diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs
new file mode 100644
index 0000000..d91e23c
--- /dev/null
+++ b/ptx/src/pass/normalize_predicates2.rs
@@ -0,0 +1,84 @@
+use super::*;
+use ptx_parser as ast;
+
+/// Turns predicated instructions into explicit conditional branches in
+/// every method of `directives`. The first failing directive aborts the
+/// pass.
+pub(crate) fn run<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directives: Vec<NormalizedDirective2<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+    let mut converted = Vec::with_capacity(directives.len());
+    for directive in directives {
+        converted.push(run_directive(resolver, directive)?);
+    }
+    Ok(converted)
+}
+
+/// Passes variables through unchanged and converts methods with
+/// `run_method`.
+fn run_directive<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    directive: NormalizedDirective2<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => {
+            let method = run_method(resolver, method)?;
+            Ok(Directive2::Method(method))
+        }
+        Directive2::Variable(linking, var) => Ok(Directive2::Variable(linking, var)),
+    }
+}
+
+/// Converts the body of `method` statement by statement; every other field
+/// is carried over unchanged.
+fn run_method<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    method: NormalizedFunction2<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+    let Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    } = method;
+    let body = match body {
+        Some(statements) => {
+            let mut converted = Vec::with_capacity(statements.len());
+            for statement in statements {
+                run_statement(resolver, &mut converted, statement)?;
+            }
+            Some(converted)
+        }
+        None => None,
+    };
+    Ok(Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    })
+}
+
+/// Appends the predicate-free form of `statement` to `result`.
+///
+/// Labels, variables and unpredicated instructions are copied over. A
+/// predicated instruction becomes a conditional branch over two fresh
+/// labels:
+///
+/// * a predicated `bra` is folded into the conditional, whose taken target
+///   is the `bra` destination;
+/// * any other instruction is placed between the "true" and "false" labels.
+///
+/// A negated predicate swaps the two branch targets. Any other statement
+/// kind is reported as unreachable.
+fn run_statement<'input>(
+    resolver: &mut GlobalStringIdentResolver2<'input>,
+    result: &mut Vec<UnconditionalStatement>,
+    statement: NormalizedStatement,
+) -> Result<(), TranslateError> {
+    match statement {
+        Statement::Instruction((None, instruction)) => {
+            result.push(Statement::Instruction(instruction));
+        }
+        Statement::Instruction((Some(pred), instruction)) => {
+            // Both labels are always allocated, even when the "true" label
+            // ends up unused because the branch is folded.
+            let if_true = resolver.register_unnamed(None);
+            let if_false = resolver.register_unnamed(None);
+            let folded_bra = if let ast::Instruction::Bra { arguments, .. } = &instruction {
+                Some(arguments.src)
+            } else {
+                None
+            };
+            let taken = folded_bra.unwrap_or(if_true);
+            let branch = if pred.not {
+                BrachCondition {
+                    predicate: pred.label,
+                    if_true: if_false,
+                    if_false: taken,
+                }
+            } else {
+                BrachCondition {
+                    predicate: pred.label,
+                    if_true: taken,
+                    if_false,
+                }
+            };
+            result.push(Statement::Conditional(branch));
+            if folded_bra.is_none() {
+                result.push(Statement::Label(if_true));
+                result.push(Statement::Instruction(instruction));
+            }
+            result.push(Statement::Label(if_false));
+        }
+        Statement::Label(label) => result.push(Statement::Label(label)),
+        Statement::Variable(var) => result.push(Statement::Variable(var)),
+        _ => return Err(error_unreachable()),
+    }
+    Ok(())
+}
diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs
new file mode 100644
index 0000000..eb7abb1
--- /dev/null
+++ b/ptx/src/pass/resolve_function_pointers.rs
@@ -0,0 +1,82 @@
+use super::*;
+use ptx_parser as ast;
+use rustc_hash::FxHashSet;
+
+/// Rewrites `mov` instructions whose source is a function name into
+/// `FunctionPointer` statements.
+///
+/// Function names are collected while the directives are walked in order,
+/// so only functions seen up to and including the current directive are
+/// recognized.
+pub(crate) fn run<'input>(
+    directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+    let mut functions = FxHashSet::default();
+    let mut rewritten = Vec::with_capacity(directives.len());
+    for directive in directives {
+        rewritten.push(run_directive(&mut functions, directive)?);
+    }
+    Ok(rewritten)
+}
+
+/// Records the name of a non-kernel method in `functions` and then rewrites
+/// that method's body; variables pass through untouched.
+fn run_directive<'input>(
+    functions: &mut FxHashSet<SpirvWord>,
+    directive: UnconditionalDirective<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+    match directive {
+        Directive2::Method(method) => {
+            // Register the name first so that the body may refer to the
+            // function it belongs to.
+            if let ptx_parser::MethodName::Func(name) = method.func_decl.name {
+                functions.insert(name);
+            }
+            let method = run_method(functions, method)?;
+            Ok(Directive2::Method(method))
+        }
+        variable @ Directive2::Variable(..) => Ok(variable),
+    }
+}
+
+/// Applies `run_statement` to every statement of the body; every other
+/// field is carried over unchanged.
+fn run_method<'input>(
+    functions: &mut FxHashSet<SpirvWord>,
+    method: UnconditionalFunction<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+    let Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    } = method;
+    let body = match body {
+        Some(statements) => {
+            let mut rewritten = Vec::with_capacity(statements.len());
+            for statement in statements {
+                rewritten.push(run_statement(functions, statement)?);
+            }
+            Some(rewritten)
+        }
+        None => None,
+    };
+    Ok(Function2 {
+        func_decl,
+        globals,
+        body,
+        import_as,
+        tuning,
+        linkage,
+    })
+}
+
+/// Replaces a register-to-register `mov` whose source is a known function
+/// name with a `FunctionPointer` statement; any other statement is returned
+/// as is.
+///
+/// Such a `mov` must have type `u64`, otherwise a mismatched-type error is
+/// returned.
+fn run_statement<'input>(
+    functions: &mut FxHashSet<SpirvWord>,
+    statement: UnconditionalStatement,
+) -> Result<UnconditionalStatement, TranslateError> {
+    match statement {
+        Statement::Instruction(ast::Instruction::Mov {
+            data,
+            arguments:
+                ast::MovArgs {
+                    dst: ast::ParsedOperand::Reg(dst),
+                    src: ast::ParsedOperand::Reg(src),
+                },
+        }) if functions.contains(&src) => {
+            if data.typ == ast::Type::Scalar(ast::ScalarType::U64) {
+                Ok(UnconditionalStatement::FunctionPointer(
+                    FunctionPointerDetails { dst, src },
+                ))
+            } else {
+                Err(error_mismatched_type())
+            }
+        }
+        other => Ok(other),
+    }
+}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 69dd206..e15d6ea 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -236,7 +236,7 @@ fn test_hip_assert<
output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
- let llvm_ir = pass::to_llvm_module(ast).unwrap();
+ let llvm_ir = pass::to_llvm_module2(ast).unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;