diff options
Diffstat (limited to 'ptx/src/pass/expand_operands.rs')
-rw-r--r-- | ptx/src/pass/expand_operands.rs | 50 |
1 files changed, 29 insertions, 21 deletions
diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index 3dabf40..f2de786 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_member(
&mut self,
- vector_src: SpirvWord,
+ vector_ident: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
- if is_dst {
- return Err(error_mismatched_type());
- }
- let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
+ let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
@@ -206,35 +203,46 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
- self.result.push(Statement::VectorAccess(VectorAccess {
- scalar_type,
- vector_width,
- dst: temporary,
- src: vector_src,
- member: member,
- }));
+ if is_dst {
+ self.post_stmts.push(Statement::VectorWrite(VectorWrite {
+ scalar_type,
+ vector_width,
+ vector_dst: vector_ident,
+ vector_src: vector_ident,
+ scalar_src: temporary,
+ member,
+ }));
+ } else {
+ self.result.push(Statement::VectorRead(VectorRead {
+ scalar_type,
+ vector_width,
+ scalar_dst: temporary,
+ vector_src: vector_ident,
+ member,
+ }));
+ }
Ok(temporary)
}
fn vec_pack(
&mut self,
- vecs: Vec<SpirvWord>,
+ vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
- let (scalar_t, state_space) = match type_space {
- Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
+ let (width, scalar_t, state_space) = match type_space {
+ Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()),
};
- let temp_vec = self
+ let temporary_vector = self
.resolver
- .register_unnamed(Some((scalar_t.into(), state_space)));
+ .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
- packed: temp_vec,
- unpacked: vecs,
+ packed: temporary_vector,
+ unpacked: vector_elements,
relaxed_type_check,
});
if is_dst {
@@ -242,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { } else {
self.result.push(statement);
}
- Ok(temp_vec)
+ Ok(temporary_vector)
}
}
@@ -273,7 +281,7 @@ impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, Translate fn visit_ident(
&mut self,
- name: <TypedOperand as ast::Operand>::Ident,
+ name: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
|