summaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs73
1 files changed, 35 insertions, 38 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7726040..22e16ff 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -299,6 +299,7 @@ fn emit_function_header<'a>(
pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, TranslateError> {
let module = to_spirv_module(ast)?;
+ eprintln!("{}", module.disassemble());
Ok(module.assemble())
}
@@ -309,6 +310,7 @@ fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::Kernel);
builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Int8);
+ builder.capability(spirv::Capability::Float64);
}
fn emit_extensions(_: &mut dr::Builder) {}
@@ -990,8 +992,13 @@ fn insert_implicit_conversions(
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?,
Statement::Instruction(inst) => match inst {
ast::Instruction::Ld(ld, arg) => {
- let pre_conv =
- get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src)?;
+ let pre_conv = get_implicit_conversions_ld_src(
+ id_def,
+ ld.typ,
+ ld.state_space,
+ arg.src,
+ false,
+ )?;
let post_conv = get_implicit_conversions_ld_dst(
id_def,
ld.typ,
@@ -1024,8 +1031,11 @@ fn insert_implicit_conversions(
st.typ,
st.state_space.to_ld_ss(),
arg.src1,
+ true,
)?;
- let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param {
+ let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param
+ || st.state_space == ast::StStateSpace::Local
+ {
(Vec::new(), post_conv)
} else {
(post_conv, Vec::new())
@@ -1667,7 +1677,7 @@ fn emit_implicit_conversion(
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.to));
if from_parts.scalar_kind != ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float
{
@@ -1714,7 +1724,7 @@ fn emit_implicit_conversion(
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default)
+ | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
@@ -2409,7 +2419,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
- let is_param = d.state_space == ast::LdStateSpace::Param;
+ let is_param = d.state_space == ast::LdStateSpace::Param
+ || d.state_space == ast::LdStateSpace::Local;
ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?)
}
ast::Instruction::Mov(d, a) => {
@@ -2432,7 +2443,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?)
+ }
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -2459,7 +2472,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
- let is_param = d.state_space == ast::StStateSpace::Param;
+ let is_param = d.state_space == ast::StStateSpace::Param
+ || d.state_space == ast::StStateSpace::Local;
ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?)
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
@@ -3419,8 +3433,8 @@ fn get_implicit_conversions_ld_dst<
Ok(Some(ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
- from: if !in_reverse { dst_type } else { instr_type },
- to: if !in_reverse { instr_type } else { dst_type },
+ from: if !in_reverse { instr_type } else { dst_type },
+ to: if !in_reverse { dst_type } else { instr_type },
kind: conv,
}))
} else {
@@ -3433,6 +3447,7 @@ fn get_implicit_conversions_ld_src(
instr_type: ast::Type,
state_space: ast::LdStateSpace,
src: spirv::Word,
+ in_reverse_param_local: bool,
) -> Result<Vec<ImplicitConversion>, TranslateError> {
let src_type = id_def.get_typed(src)?;
match state_space {
@@ -3442,8 +3457,16 @@ fn get_implicit_conversions_ld_src(
ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
- from: src_type,
- to: instr_type,
+ from: if !in_reverse_param_local {
+ src_type
+ } else {
+ instr_type
+ },
+ to: if !in_reverse_param_local {
+ instr_type
+ } else {
+ src_type
+ },
kind: ConversionKind::Default,
};
1
@@ -3512,32 +3535,6 @@ fn insert_conversion_src(
temp_src
}
-/*
-fn insert_with_implicit_conversion_dst<
- T,
- ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option<ConversionKind>,
- Setter: Fn(&mut T) -> &mut spirv::Word,
- ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
->(
- func: &mut Vec<ExpandedStatement>,
- instr_type: ast::Type,
- id_def: &mut NumericIdResolver,
- should_convert: ShouldConvert,
- mut t: T,
- setter: Setter,
- to_inst: ToInstruction,
-) {
- let dst = setter(&mut t);
- let dst_type = id_def.get_type(*dst);
- let dst_coercion = should_convert(dst_type.unwrap(), instr_type)
- .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv));
- func.push(Statement::Instruction(to_inst(t)));
- if let Some(conv) = dst_coercion {
- func.push(conv);
- }
-}
-*/
-
#[must_use]
fn get_conversion_dst(
id_def: &mut MutableNumericIdResolver,