aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-09-23 06:02:28 +0200
committerAndrzej Janik <[email protected]>2024-09-23 06:02:28 +0200
commit78a9f22cf7e6c819f04991c1624578c969c1a146 (patch)
tree89bab98e3071aedd12f755bfde8a7c7382138ed7
parent7bd4179d1dd24f81b56e66fd13c16631b518495f (diff)
downloadZLUDA-78a9f22cf7e6c819f04991c1624578c969c1a146.tar.gz
ZLUDA-78a9f22cf7e6c819f04991c1624578c969c1a146.zip
Refactor implicit conversions, explicit ld/st and global hoistingrepass
-rw-r--r--ptx/src/pass/emit_llvm.rs80
-rw-r--r--ptx/src/pass/hoist_globals.rs45
-rw-r--r--ptx/src/pass/insert_explicit_load_store.rs101
-rw-r--r--ptx/src/pass/insert_implicit_conversions2.rs426
-rw-r--r--ptx/src/pass/mod.rs19
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
6 files changed, 627 insertions, 46 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 3060335..235ad7d 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
}
pub(super) fn run<'input>(
- id_defs: &GlobalStringIdResolver<'input>,
- call_map: MethodsCallMap<'input>,
- directives: Vec<Directive<'input>>,
+ id_defs: GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
- let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
+ let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive::Variable(..) => todo!(),
- Directive::Method(method) => emit_ctx.emit_method(method)?,
+ Directive2::Variable(..) => todo!(),
+ Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
module.write_to_stderr();
@@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
@@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &Context,
module: &Module,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
@@ -215,20 +214,27 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCallConv::LLVMCCallConv as u32
}
- fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
- let func_decl = method.func_decl.borrow();
+ fn emit_method(
+ &mut self,
+ method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
- .unwrap_or_else(|| match func_decl.name {
- ast::MethodName::Kernel(name) => name,
- ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
- });
+ .or_else(|| match func_decl.name {
+ ast::MethodName::Kernel(name) => Some(name),
+ ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
+ })
+ .ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl.input_arguments.iter().map(|v| &v.v_type),
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
if let ast::MethodName::Func(name) = func_decl.name {
@@ -239,6 +245,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
+ if func_decl.name.is_kernel() {
+ let attr_kind = unsafe {
+ LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
+ };
+ let attr = unsafe {
+ LLVMCreateTypeAttribute(
+ self.context,
+ attr_kind,
+ get_type(self.context, &param.v_type)?,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
+ }
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
@@ -264,12 +283,26 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
}
+fn get_input_argument_type(
+ context: LLVMContextRef,
+ v_type: &ptx_parser::Type,
+ state_space: ptx_parser::StateSpace,
+) -> Result<LLVMTypeRef, TranslateError> {
+ match state_space {
+ ptx_parser::StateSpace::ParamEntry => {
+ Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
+ }
+ ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ _ => return Err(error_unreachable()),
+ }
+}
+
struct MethodEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
@@ -533,7 +566,9 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
let type_ = get_function_type(
self.context,
data.return_arguments.iter().map(|(type_, space)| type_),
- data.input_arguments.iter().map(|(type_, space)| type_),
+ data.input_arguments
+ .iter()
+ .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
)?;
let mut input_arguments = arguments
.input_arguments
@@ -633,11 +668,10 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+ input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
- let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
- .map(|type_| get_type(context, type_))
- .collect::<Result<Vec<_>, _>>()?;
+ let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
+ input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
@@ -658,7 +692,7 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
- ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
+ ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@@ -675,7 +709,7 @@ struct ResolveIdent {
}
impl ResolveIdent {
- fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
+ fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),
diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs
new file mode 100644
index 0000000..753172a
--- /dev/null
+++ b/ptx/src/pass/hoist_globals.rs
@@ -0,0 +1,45 @@
+use super::*;
+
+pub(super) fn run<'input>(
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ let mut result = Vec::with_capacity(directives.len());
+ for mut directive in directives.into_iter() {
+ run_directive(&mut result, &mut directive);
+ result.push(directive);
+ }
+ Ok(result)
+}
+
+fn run_directive<'input>(
+ result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
+ directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<(), TranslateError> {
+ match directive {
+ Directive2::Variable(..) => {}
+ Directive2::Method(function2) => run_function(result, function2),
+ }
+ Ok(())
+}
+
+fn run_function<'input>(
+ result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
+ function: &mut Function2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
+) {
+ function.body = function.body.take().map(|statements| {
+ statements
+ .into_iter()
+ .filter_map(|statement| match statement {
+ Statement::Variable(var @ ast::Variable {
+ state_space:
+ ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared,
+ ..
+ }) => {
+ result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
+ None
+ }
+ s => Some(s),
+ })
+ .collect()
+ });
+}
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
index e8f01cd..ec6498c 100644
--- a/ptx/src/pass/insert_explicit_load_store.rs
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -41,10 +41,9 @@ fn run_method<'a, 'input>(
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
for arg in func_decl.return_arguments.iter_mut() {
- visitor.visit_variable(arg);
+ visitor.visit_variable(arg)?;
}
let is_kernel = func_decl.name.is_kernel();
- // let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
let old_name = arg.name;
@@ -85,23 +84,29 @@ fn run_statement<'a, 'input>(
) -> Result<(), TranslateError> {
match statement {
Statement::Variable(mut var) => {
- visitor.visit_variable(&mut var);
+ visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
- Statement::Instruction(ast::Instruction::St {
- data,
- mut arguments,
- }) => {
+ Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
+ s => {
+ let new_statement = s.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(new_statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
- s => result.push(s.visit_map(visitor)?),
}
Ok(())
}
@@ -109,6 +114,8 @@ fn run_statement<'a, 'input>(
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
+ pre: Vec<ast::Instruction<SpirvWord>>,
+ post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
@@ -116,6 +123,8 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Self {
resolver,
variables: FxHashMap::default(),
+ pre: Vec::new(),
+ post: Vec::new(),
}
}
@@ -141,14 +150,20 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn variable(
&mut self,
+ type_: &ast::Type,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
- self.variables
- .insert(old_name, RemapAction::PreLdPostSt(new_name));
+ self.variables.insert(
+ old_name,
+ RemapAction::PreLdPostSt {
+ name: new_name,
+ type_: type_.clone(),
+ },
+ );
}
ast::StateSpace::Param => {
self.variables.insert(
@@ -182,7 +197,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
- RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@@ -206,7 +221,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
- RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@@ -223,7 +238,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Ok(ast::Instruction::Ld { data, arguments })
}
- fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
+ fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
if var.state_space != ast::StateSpace::Local {
let old_name = var.name;
let old_space = var.state_space;
@@ -231,10 +246,11 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
- self.variable(old_name, new_name, old_space);
+ self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
}
+ Ok(())
}
}
@@ -243,12 +259,58 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
{
fn visit(
&mut self,
- args: SpirvWord,
+ ident: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
- todo!()
+ if let Some(remap) = self.variables.get(&ident) {
+ match remap {
+ RemapAction::PreLdPostSt { name, type_ } => {
+ if is_dst {
+ let temp = self
+ .resolver
+ .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
+ self.post.push(ast::Instruction::St {
+ data: ast::StData {
+ state_space: ast::StateSpace::Local,
+ qualifier: ast::LdStQualifier::Weak,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: type_.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: *name,
+ src2: temp,
+ },
+ });
+ Ok(temp)
+ } else {
+ let temp = self
+ .resolver
+ .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
+ self.pre.push(ast::Instruction::Ld {
+ data: ast::LdDetails {
+ state_space: ast::StateSpace::Local,
+ qualifier: ast::LdStQualifier::Weak,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.clone(),
+ non_coherent: false,
+ },
+ arguments: ast::LdArgs {
+ dst: temp,
+ src: *name,
+ },
+ });
+ Ok(temp)
+ }
+ }
+ RemapAction::LDStSpaceChange { .. } => {
+ return Err(error_mismatched_type());
+ }
+ }
+ } else {
+ Ok(ident)
+ }
}
fn visit_ident(
@@ -262,9 +324,12 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
}
}
-#[derive(Clone, Copy)]
+#[derive(Clone)]
enum RemapAction {
- PreLdPostSt(SpirvWord),
+ PreLdPostSt {
+ name: SpirvWord,
+ type_: ast::Type,
+ },
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,
diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs
new file mode 100644
index 0000000..4f738f5
--- /dev/null
+++ b/ptx/src/pass/insert_implicit_conversions2.rs
@@ -0,0 +1,426 @@
+use std::mem;
+
+use super::*;
+use ptx_parser as ast;
+
+/*
+ There are several kinds of implicit conversions in PTX:
+ * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+ * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
+*/
+pub(super) fn run<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(mut method) => {
+ method.body = method
+ .body
+ .map(|statements| run_statements(resolver, statements))
+ .transpose()?;
+ Directive2::Method(method)
+ }
+ })
+}
+
+fn run_statements<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ func: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func.into_iter() {
+ insert_implicit_conversions_impl(resolver, &mut result, s)?;
+ }
+ Ok(result)
+}
+
+fn insert_implicit_conversions_impl<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ func: &mut Vec<ExpandedStatement>,
+ stmt: ExpandedStatement,
+) -> Result<(), TranslateError> {
+ let mut post_conv = Vec::new();
+ let statement = stmt.visit_map::<SpirvWord, TranslateError>(
+ &mut |operand,
+ type_state: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst,
+ relaxed_type_check| {
+ let (instr_type, instruction_space) = match type_state {
+ None => return Ok(operand),
+ Some(t) => t,
+ };
+ let (operand_type, operand_space) = resolver.get_typed(operand)?;
+ let conversion_fn = if relaxed_type_check {
+ if is_dst {
+ should_convert_relaxed_dst_wrapper
+ } else {
+ should_convert_relaxed_src_wrapper
+ }
+ } else {
+ default_implicit_conversion
+ };
+ match conversion_fn(
+ (*operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
+ Some(conv_kind) => {
+ let conv_output = if is_dst { &mut post_conv } else { &mut *func };
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type.clone();
+ let mut to_space = *operand_space;
+ let mut src =
+ resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
+ let mut dst = operand;
+ let result = Ok::<_, TranslateError>(src);
+ if !is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
+ kind: conv_kind,
+ }));
+ result
+ }
+ None => Ok(operand),
+ }
+ },
+ )?;
+ func.push(statement);
+ func.append(&mut post_conv);
+ Ok(())
+}
+
+pub(crate) fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instruction_space == ast::StateSpace::Reg {
+ if operand_space == ast::StateSpace::Reg {
+ if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
+ {
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ } else if is_addressable(operand_space) {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ }
+ if instruction_space != operand_space {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
+}
+
+fn is_addressable(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg => false,
+ ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => todo!(),
+ }
+}
+
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
+ || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space == ast::StateSpace::Reg {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(error_mismatched_type()),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(error_mismatched_type()),
+ },
+ _ => Err(error_mismatched_type()),
+ }
+ } else if instruction_space == ast::StateSpace::Reg {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ _ => Err(error_mismatched_type()),
+ }
+ } else {
+ Err(error_mismatched_type())
+ }
+}
+
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if space == ast::StateSpace::Reg {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
+ }
+}
+
+fn coerces_to_generic(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ptx_parser::StateSpace::SharedCta
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::Generic => false,
+ }
+}
+
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
+ match (instr, operand) {
+ (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
+ if inst.size_of() != operand.size_of() {
+ return false;
+ }
+ match inst.kind() {
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
+ }
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
+ }
+ ast::ScalarKind::Pred => false,
+ }
+ }
+ (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
+ | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
+ should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
+ }
+ _ => false,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_dst_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if operand_space != instruction_space {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+fn should_convert_relaxed_dst(
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if dst_type == instr_type {
+ return None;
+ }
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
+ if instr_type.size_of() == dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else if instr_type.size_of() < dst_type.size_of() {
+ Some(ConversionKind::SignExtend)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_dst(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_src_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if operand_space != instruction_space {
+ return Err(error_mismatched_type());
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_src(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(error_mismatched_type()),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+fn should_convert_relaxed_src(
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if src_type == instr_type {
+ return None;
+ }
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= src_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_src(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index b82d3c5..0e233ed 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -27,8 +27,10 @@ mod expand_operands;
mod extract_globals;
mod fix_special_registers;
mod fix_special_registers2;
+mod hoist_globals;
mod insert_explicit_load_store;
mod insert_implicit_conversions;
+mod insert_implicit_conversions2;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_identifiers2;
@@ -67,11 +69,13 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
})?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
- let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
+ todo!()
+ /*
+ let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
- })
+ }) */
}
pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
@@ -82,10 +86,17 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
- let directives = expand_operands::run(&mut flat_resolver, directives)?;
+ let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
+ expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
- todo!()
+ let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
+ let directives = hoist_globals::run(directives)?;
+ let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
+ Ok(Module {
+ llvm_ir,
+ kernel_info: HashMap::new(),
+ })
}
fn translate_directive<'input, 'a>(
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 69dd206..e15d6ea 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -236,7 +236,7 @@ fn test_hip_assert<
output: &mut [Output],
) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
- let llvm_ir = pass::to_llvm_module(ast).unwrap();
+ let llvm_ir = pass::to_llvm_module2(ast).unwrap();
let name = CString::new(name)?;
let result =
run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?;