aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml4
-rw-r--r--ptx/Cargo.toml4
-rw-r--r--ptx/src/translate.rs189
3 files changed, 123 insertions, 74 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 9b5f261..e02e2fc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,5 +17,5 @@ members = [
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]
[patch.crates-io]
-rspirv = { git = 'https://github.com/vosen/rspirv', rev = '40f5aa4dedb0d9f1ec24bdd8b6019e01996d1d74' }
-spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '40f5aa4dedb0d9f1ec24bdd8b6019e01996d1d74' } \ No newline at end of file
+rspirv = { git = 'https://github.com/vosen/rspirv', rev = '9826e59a232c4a426482cda12f88d11bfda3ff9c' }
+spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '9826e59a232c4a426482cda12f88d11bfda3ff9c' }
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 1d51e8b..4087469 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -9,8 +9,8 @@ edition = "2018"
[dependencies]
lalrpop-util = "0.19"
regex = "1"
-rspirv = "0.6"
-spirv_headers = "~1.4.2"
+rspirv = "0.7"
+spirv_headers = "1.5"
quick-error = "1.2"
thiserror = "1.0"
bit-vec = "0.6"
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index de7de82..7efcaf6 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -161,7 +161,7 @@ impl From<ast::ScalarType> for SpirvScalarKey {
impl TypeWordMap {
fn new(b: &mut dr::Builder) -> TypeWordMap {
- let void = b.type_void();
+ let void = b.type_void(None);
TypeWordMap {
void: void,
complex: HashMap::<SpirvType, spirv::Word>::new(),
@@ -183,14 +183,14 @@ impl TypeWordMap {
.complex
.entry(SpirvType::Base(key))
.or_insert_with(|| match key {
- SpirvScalarKey::B8 => b.type_int(8, 0),
- SpirvScalarKey::B16 => b.type_int(16, 0),
- SpirvScalarKey::B32 => b.type_int(32, 0),
- SpirvScalarKey::B64 => b.type_int(64, 0),
- SpirvScalarKey::F16 => b.type_float(16),
- SpirvScalarKey::F32 => b.type_float(32),
- SpirvScalarKey::F64 => b.type_float(64),
- SpirvScalarKey::Pred => b.type_bool(),
+ SpirvScalarKey::B8 => b.type_int(None, 8, 0),
+ SpirvScalarKey::B16 => b.type_int(None, 16, 0),
+ SpirvScalarKey::B32 => b.type_int(None, 32, 0),
+ SpirvScalarKey::B64 => b.type_int(None, 64, 0),
+ SpirvScalarKey::F16 => b.type_float(None, 16),
+ SpirvScalarKey::F32 => b.type_float(None, 32),
+ SpirvScalarKey::F64 => b.type_float(None, 64),
+ SpirvScalarKey::Pred => b.type_bool(None),
SpirvScalarKey::F16x2 => todo!(),
})
}
@@ -210,7 +210,7 @@ impl TypeWordMap {
*self
.complex
.entry(t)
- .or_insert_with(|| b.type_vector(base, len as u32))
+ .or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
@@ -230,7 +230,7 @@ impl TypeWordMap {
*self
.complex
.entry(SpirvType::Array(typ, array_dimensions))
- .or_insert_with(|| b.type_array(base_type, length))
+ .or_insert_with(|| b.type_array(None, base_type, length))
}
SpirvType::Func(ref out_params, ref in_params) => {
let out_t = match out_params {
@@ -244,7 +244,7 @@ impl TypeWordMap {
*self
.complex
.entry(t)
- .or_insert_with(|| b.type_function(out_t, in_t))
+ .or_insert_with(|| b.type_function(None, out_t, in_t))
}
SpirvType::Struct(ref underlying) => {
let underlying_ids = underlying
@@ -254,7 +254,7 @@ impl TypeWordMap {
*self
.complex
.entry(t)
- .or_insert_with(|| b.type_struct(underlying_ids))
+ .or_insert_with(|| b.type_struct(None, underlying_ids))
}
}
}
@@ -371,7 +371,7 @@ impl TypeWordMap {
)
})
.collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
+ b.constant_composite(result_type, None, components.into_iter())
}
ast::Type::Array(typ, dims) => match dims.as_slice() {
[] => return Err(error_unreachable()),
@@ -388,7 +388,7 @@ impl TypeWordMap {
)
})
.collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
+ b.constant_composite(result_type, None, components.into_iter())
}
[first_dim, rest @ ..] => {
let result_type = self.get_or_add(
@@ -407,7 +407,7 @@ impl TypeWordMap {
)
})
.collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
+ b.constant_composite(result_type, None, components.into_iter())
}
},
ast::Type::Pointer(typ, state_space) => {
@@ -608,10 +608,12 @@ fn emit_directives<'input>(
builder.decorate(
*fn_id,
spirv::Decoration::LinkageAttributes,
- &[
+ [
dr::Operand::LiteralString(name.clone()),
dr::Operand::LinkageType(spirv::LinkageType::Import),
- ],
+ ]
+ .iter()
+ .cloned(),
);
}
}
@@ -1042,7 +1044,9 @@ fn emit_builtins(
builder.decorate(
id,
spirv::Decoration::BuiltIn,
- &[dr::Operand::BuiltIn(reg.get_builtin())],
+ [dr::Operand::BuiltIn(reg.get_builtin())]
+ .iter()
+ .cloned(),
);
}
}
@@ -1137,13 +1141,7 @@ fn emit_function_header<'a>(
*/
for input in &func_decl.input {
let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
- let inst = dr::Instruction::new(
- spirv::Op::FunctionParameter,
- Some(result_type),
- Some(input.name),
- Vec::new(),
- );
- builder.function.as_mut().unwrap().parameters.push(inst);
+ builder.function_parameter(Some(input.name), result_type)?;
}
Ok(())
}
@@ -2750,13 +2748,13 @@ fn emit_function_body_ops(
for s in func {
match s {
Statement::Label(id) => {
- if builder.block.is_some() {
+ if builder.selected_block().is_some() {
builder.branch(*id)?;
}
builder.begin_block(Some(*id))?;
}
_ => {
- if builder.block.is_none() && builder.function.is_some() {
+ if builder.selected_block().is_none() && builder.selected_function().is_some() {
builder.begin_block(None)?;
}
}
@@ -2880,7 +2878,12 @@ fn emit_function_body_ops(
}
Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
- builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+ builder.branch_conditional(
+ bra.predicate,
+ bra.if_true,
+ bra.if_false,
+ iter::empty(),
+ )?;
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?,
@@ -2902,7 +2905,9 @@ fn emit_function_body_ops(
Some(spirv::MemoryAccess::ALIGNED),
[dr::Operand::LiteralInt32(
ast::Type::from(data.typ.clone()).size_of() as u32,
- )],
+ )]
+ .iter()
+ .cloned(),
)?;
}
ast::Instruction::St(data, arg) => {
@@ -2915,7 +2920,9 @@ fn emit_function_body_ops(
Some(spirv::MemoryAccess::ALIGNED),
[dr::Operand::LiteralInt32(
ast::Type::from(data.typ.clone()).size_of() as u32,
- )],
+ )]
+ .iter()
+ .cloned(),
)?;
}
// SPIR-V does not support ret as guaranteed-converged
@@ -3130,7 +3137,7 @@ fn emit_function_body_ops(
Some(a.dst),
opencl,
spirv::CLOp::native_rsqrt as spirv::Word,
- &[a.src],
+ [dr::Operand::IdRef(a.src)].iter().cloned(),
)?;
}
ast::Instruction::Neg(details, arg) => {
@@ -3149,7 +3156,7 @@ fn emit_function_body_ops(
Some(arg.dst),
opencl,
spirv::CLOp::sin as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
ast::Instruction::Cos { arg, .. } => {
@@ -3159,7 +3166,7 @@ fn emit_function_body_ops(
Some(arg.dst),
opencl,
spirv::CLOp::cos as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
ast::Instruction::Lg2 { arg, .. } => {
@@ -3169,7 +3176,7 @@ fn emit_function_body_ops(
Some(arg.dst),
opencl,
spirv::CLOp::log2 as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
ast::Instruction::Ex2 { arg, .. } => {
@@ -3179,7 +3186,7 @@ fn emit_function_body_ops(
Some(arg.dst),
opencl,
spirv::CLOp::exp2 as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
ast::Instruction::Clz { typ, arg } => {
@@ -3189,7 +3196,7 @@ fn emit_function_body_ops(
Some(arg.dst),
opencl,
spirv::CLOp::clz as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
ast::Instruction::Brev { typ, arg } => {
@@ -3248,12 +3255,12 @@ fn emit_function_body_ops(
result_ptr_type,
None,
details.arg.src1,
- &[index_spirv],
+ [index_spirv].iter().copied(),
)?
}
None => details.arg.src1,
};
- builder.store(dst_ptr, details.arg.src2, None, [])?;
+ builder.store(dst_ptr, details.arg.src2, None, iter::empty())?;
}
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
@@ -3282,7 +3289,7 @@ fn emit_function_body_ops(
None,
ptr_src_u8,
*offset_src,
- &[],
+ iter::empty(),
)?;
builder.bitcast(result_type, Some(*dst), temp)?;
}
@@ -3294,7 +3301,7 @@ fn emit_function_body_ops(
scalar_type,
Some(*dst_id),
repack.packed,
- &[index as u32],
+ [index as u32].iter().copied(),
)?;
}
} else {
@@ -3312,7 +3319,7 @@ fn emit_function_body_ops(
None,
*src_id,
temp_vec,
- &[index as u32],
+ [index as u32].iter().copied(),
)?;
}
builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
@@ -3371,7 +3378,7 @@ fn emit_sqrt(
Some(a.dst),
opencl,
ocl_op as spirv::Word,
- &[a.src],
+ [dr::Operand::IdRef(a.src)].iter().cloned(),
)?;
emit_rounding_decoration(builder, a.dst, rounding);
Ok(())
@@ -3383,9 +3390,11 @@ fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind:
builder.decorate(
dst,
spirv::Decoration::FPFastMathMode,
- &[dr::Operand::FPFastMathMode(
+ [dr::Operand::FPFastMathMode(
spirv::FPFastMathMode::ALLOW_RECIP,
- )],
+ )]
+ .iter()
+ .cloned(),
);
}
ast::DivFloatKind::Rounding(rnd) => {
@@ -3521,9 +3530,11 @@ fn emit_rcp(
builder.decorate(
a.dst,
spirv::Decoration::FPFastMathMode,
- &[dr::Operand::FPFastMathMode(
+ [dr::Operand::FPFastMathMode(
spirv::FPFastMathMode::ALLOW_RECIP,
- )],
+ )]
+ .iter()
+ .cloned(),
);
Ok(())
}
@@ -3570,7 +3581,7 @@ fn emit_variable(
builder.decorate(
var.name,
spirv::Decoration::Alignment,
- &[dr::Operand::LiteralInt32(align)],
+ [dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
Ok(())
@@ -3595,7 +3606,13 @@ fn emit_mad_uint(
Some(arg.dst),
opencl,
spirv::CLOp::u_mad_hi as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
+ [
+ dr::Operand::IdRef(arg.src1),
+ dr::Operand::IdRef(arg.src2),
+ dr::Operand::IdRef(arg.src3),
+ ]
+ .iter()
+ .cloned(),
)?;
}
ast::MulIntControl::Wide => todo!(),
@@ -3622,7 +3639,13 @@ fn emit_mad_sint(
Some(arg.dst),
opencl,
spirv::CLOp::s_mad_hi as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
+ [
+ dr::Operand::IdRef(arg.src1),
+ dr::Operand::IdRef(arg.src2),
+ dr::Operand::IdRef(arg.src3),
+ ]
+ .iter()
+ .cloned(),
)?;
}
ast::MulIntControl::Wide => todo!(),
@@ -3643,7 +3666,13 @@ fn emit_mad_float(
Some(arg.dst),
opencl,
spirv::CLOp::mad as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
+ [
+ dr::Operand::IdRef(arg.src1),
+ dr::Operand::IdRef(arg.src2),
+ dr::Operand::IdRef(arg.src3),
+ ]
+ .iter()
+ .cloned(),
)?;
Ok(())
}
@@ -3690,7 +3719,9 @@ fn emit_min(
Some(arg.dst),
opencl,
cl_op as spirv::Word,
- [arg.src1, arg.src2],
+ [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)]
+ .iter()
+ .cloned(),
)?;
Ok(())
}
@@ -3713,7 +3744,9 @@ fn emit_max(
Some(arg.dst),
opencl,
cl_op as spirv::Word,
- [arg.src1, arg.src2],
+ [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)]
+ .iter()
+ .cloned(),
)?;
Ok(())
}
@@ -3740,7 +3773,7 @@ fn emit_cvt(
Some(arg.dst),
opencl,
spirv::CLOp::rint as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
Some(ast::RoundingMode::Zero) => {
@@ -3749,7 +3782,7 @@ fn emit_cvt(
Some(arg.dst),
opencl,
spirv::CLOp::trunc as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
Some(ast::RoundingMode::NegativeInf) => {
@@ -3758,7 +3791,7 @@ fn emit_cvt(
Some(arg.dst),
opencl,
spirv::CLOp::floor as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
Some(ast::RoundingMode::PositiveInf) => {
@@ -3767,7 +3800,7 @@ fn emit_cvt(
Some(arg.dst),
opencl,
spirv::CLOp::ceil as u32,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
None => {
@@ -3851,7 +3884,7 @@ fn emit_cvt(
fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) {
if saturate {
- builder.decorate(dst, spirv::Decoration::SaturatedConversion, []);
+ builder.decorate(dst, spirv::Decoration::SaturatedConversion, iter::empty());
}
}
@@ -3864,7 +3897,7 @@ fn emit_rounding_decoration(
builder.decorate(
dst,
spirv::Decoration::FPRoundingMode,
- [rounding.to_spirv()],
+ [rounding.to_spirv()].iter().cloned(),
);
}
}
@@ -3990,7 +4023,9 @@ fn emit_mul_sint(
Some(arg.dst),
opencl,
spirv::CLOp::s_mul_hi as spirv::Word,
- [arg.src1, arg.src2],
+ [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)]
+ .iter()
+ .cloned(),
)?;
}
ast::MulIntControl::Wide => {
@@ -4037,7 +4072,9 @@ fn emit_mul_uint(
Some(arg.dst),
opencl,
spirv::CLOp::u_mul_hi as spirv::Word,
- [arg.src1, arg.src2],
+ [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)]
+ .iter()
+ .cloned(),
)?;
}
ast::MulIntControl::Wide => {
@@ -4075,10 +4112,16 @@ fn struct2_bitcast_to_wide(
dst_type_id: spirv::Word,
src: spirv::Word,
) -> Result<(), dr::Error> {
- let low_bits = builder.composite_extract(instruction_type, None, src, [0])?;
- let high_bits = builder.composite_extract(instruction_type, None, src, [1])?;
+ let low_bits =
+ builder.composite_extract(instruction_type, None, src, [0].iter().copied())?;
+ let high_bits =
+ builder.composite_extract(instruction_type, None, src, [1].iter().copied())?;
let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
- let vector = builder.composite_construct(vector_type, None, [low_bits, high_bits])?;
+ let vector = builder.composite_construct(
+ vector_type,
+ None,
+ [low_bits, high_bits].iter().copied(),
+ )?;
builder.bitcast(dst_type_id, Some(dst), vector)?;
Ok(())
}
@@ -4102,7 +4145,7 @@ fn emit_abs(
Some(arg.dst),
opencl,
cl_abs as spirv::Word,
- [arg.src],
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
Ok(())
}
@@ -4249,12 +4292,18 @@ fn emit_load_var(
_ => return Err(TranslateError::MismatchedType),
};
let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
- let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?;
+ let vector_temp = builder.load(
+ vector_type_spirv,
+ None,
+ details.arg.src,
+ None,
+ iter::empty(),
+ )?;
builder.composite_extract(
result_type,
Some(details.arg.dst),
vector_temp,
- &[index as u32],
+ [index as u32].iter().copied(),
)?;
}
Some((index, None)) => {
@@ -4271,9 +4320,9 @@ fn emit_load_var(
result_ptr_type,
None,
details.arg.src,
- &[index_spirv],
+ [index_spirv].iter().copied(),
)?;
- builder.load(result_type, Some(details.arg.dst), src, None, [])?;
+ builder.load(result_type, Some(details.arg.dst), src, None, iter::empty())?;
}
None => {
builder.load(
@@ -4281,7 +4330,7 @@ fn emit_load_var(
Some(details.arg.dst),
details.arg.src,
None,
- [],
+ iter::empty(),
)?;
}
};