use super::*; use half::f16; use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; use std::{ collections::{HashMap, HashSet}, ffi::CString, mem, }; pub(super) fn run<'input>( mut builder: dr::Builder, id_defs: &GlobalStringIdResolver<'input>, call_map: MethodsCallMap<'input>, denorm_information: HashMap< ptx_parser::MethodName, HashMap, >, directives: Vec>, ) -> Result<(dr::Module, HashMap, CString), TranslateError> { builder.set_version(1, 3); emit_capabilities(&mut builder); emit_extensions(&mut builder); let opencl_id = emit_opencl_import(&mut builder); emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); //emit_builtins(&mut builder, &mut map, &id_defs); let mut kernel_info = HashMap::new(); let (build_options, should_flush_denorms) = emit_denorm_build_string(&call_map, &denorm_information); let (directives, globals_use_map) = get_globals_use_map(directives); emit_directives( &mut builder, &mut map, &id_defs, opencl_id, should_flush_denorms, &call_map, globals_use_map, directives, &mut kernel_info, )?; Ok((builder.module(), kernel_info, build_options)) } fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::GenericPointer); builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Addresses); builder.capability(spirv::Capability::Kernel); builder.capability(spirv::Capability::Int8); builder.capability(spirv::Capability::Int16); builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Float16); builder.capability(spirv::Capability::Float64); builder.capability(spirv::Capability::DenormFlushToZero); // TODO: re-enable when Intel float control extension works //builder.capability(spirv::Capability::FunctionFloatControlINTEL); } // http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html fn emit_extensions(builder: &mut dr::Builder) { // TODO: re-enable when Intel float 
control extension works //builder.extension("SPV_INTEL_float_controls2"); builder.extension("SPV_KHR_float_controls"); builder.extension("SPV_KHR_no_integer_wrap_decoration"); } fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { builder.ext_inst_import("OpenCL.std") } fn emit_memory_model(builder: &mut dr::Builder) { builder.memory_model( spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL, ); } struct TypeWordMap { void: spirv::Word, complex: HashMap, constants: HashMap<(SpirvType, u64), SpirvWord>, } impl TypeWordMap { fn new(b: &mut dr::Builder) -> TypeWordMap { let void = b.type_void(None); TypeWordMap { void: void, complex: HashMap::::new(), constants: HashMap::new(), } } fn void(&self) -> spirv::Word { self.void } fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { let key: SpirvScalarKey = t.into(); self.get_or_add_spirv_scalar(b, key) } fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { SpirvWord(match key { SpirvScalarKey::B8 => b.type_int(None, 8, 0), SpirvScalarKey::B16 => b.type_int(None, 16, 0), SpirvScalarKey::B32 => b.type_int(None, 32, 0), SpirvScalarKey::B64 => b.type_int(None, 64, 0), SpirvScalarKey::F16 => b.type_float(None, 16), SpirvScalarKey::F32 => b.type_float(None, 32), SpirvScalarKey::F64 => b.type_float(None, 64), SpirvScalarKey::Pred => b.type_bool(None), SpirvScalarKey::F16x2 => todo!(), }) }) } fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { match t { SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), SpirvType::Pointer(ref typ, storage) => { let base = self.get_or_add(b, *typ.clone()); *self .complex .entry(t) .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) } SpirvType::Vector(typ, len) => { let base = self.get_or_add_spirv_scalar(b, typ); *self .complex .entry(t) .or_insert_with(|| SpirvWord(b.type_vector(None, 
base.0, len as u32))) } SpirvType::Array(typ, array_dimensions) => { let (base_type, length) = match &*array_dimensions { &[] => { return self.get_or_add(b, SpirvType::Base(typ)); } &[len] => { let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self.get_or_add_spirv_scalar(b, typ); let len_const = b.constant_u32(u32_type.0, None, len); (base, len_const) } array_dimensions => { let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); (base, len_const) } }; *self .complex .entry(SpirvType::Array(typ, array_dimensions)) .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) } SpirvType::Func(ref out_params, ref in_params) => { let out_t = match out_params { Some(p) => self.get_or_add(b, *p.clone()), None => SpirvWord(self.void()), }; let in_t = in_params .iter() .map(|t| self.get_or_add(b, t.clone()).0) .collect::>(); *self .complex .entry(t) .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) } SpirvType::Struct(ref underlying) => { let underlying_ids = underlying .iter() .map(|t| self.get_or_add_spirv_scalar(b, *t).0) .collect::>(); *self .complex .entry(t) .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) } } } fn get_or_add_fn( &mut self, b: &mut dr::Builder, in_params: impl Iterator, mut out_params: impl ExactSizeIterator, ) -> (SpirvWord, SpirvWord) { let (out_args, out_spirv_type) = if out_params.len() == 0 { (None, SpirvWord(self.void())) } else if out_params.len() == 1 { let arg_as_key = out_params.next().unwrap(); ( Some(Box::new(arg_as_key.clone())), self.get_or_add(b, arg_as_key), ) } else { // TODO: support multiple return values todo!() }; ( out_spirv_type, self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), ) } fn get_or_add_constant( &mut self, b: &mut dr::Builder, typ: &ast::Type, init: &[u8], ) -> 
Result { Ok(match typ { ast::Type::Scalar(t) => match t { ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self .get_or_add_constant_single::( b, *t, init, |v| v as u64, |b, result_type, v| b.constant_u32(result_type, None, v as u32), ), ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self .get_or_add_constant_single::( b, *t, init, |v| v as u64, |b, result_type, v| b.constant_u32(result_type, None, v as u32), ), ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self .get_or_add_constant_single::( b, *t, init, |v| v as u64, |b, result_type, v| b.constant_u32(result_type, None, v), ), ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self .get_or_add_constant_single::( b, *t, init, |v| v, |b, result_type, v| b.constant_u64(result_type, None, v), ), ast::ScalarType::F16 => self.get_or_add_constant_single::( b, *t, init, |v| unsafe { mem::transmute::<_, u16>(v) } as u64, |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), ), ast::ScalarType::F32 => self.get_or_add_constant_single::( b, *t, init, |v| unsafe { mem::transmute::<_, u32>(v) } as u64, |b, result_type, v| b.constant_f32(result_type, None, v), ), ast::ScalarType::F64 => self.get_or_add_constant_single::( b, *t, init, |v| unsafe { mem::transmute::<_, u64>(v) }, |b, result_type, v| b.constant_f64(result_type, None, v), ), ast::ScalarType::F16x2 => return Err(TranslateError::Todo), ast::ScalarType::Pred => self.get_or_add_constant_single::( b, *t, init, |v| v as u64, |b, result_type, v| { if v == 0 { b.constant_false(result_type, None) } else { b.constant_true(result_type, None) } }, ), ast::ScalarType::S16x2 | ast::ScalarType::U16x2 | ast::ScalarType::BF16 | ast::ScalarType::BF16x2 | ast::ScalarType::B128 => todo!(), }, ast::Type::Vector(len, typ) => { let result_type = self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); let size_of_t = typ.size_of(); let components = (0..*len) .map(|x| { Ok::<_, 
TranslateError>( self.get_or_add_constant( b, &ast::Type::Scalar(*typ), &init[((size_of_t as usize) * (x as usize))..], )? .0, ) }) .collect::, _>>()?; SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) } ast::Type::Array(_, typ, dims) => match dims.as_slice() { [] => return Err(error_unreachable()), [dim] => { let result_type = self .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); let size_of_t = typ.size_of(); let components = (0..*dim) .map(|x| { Ok::<_, TranslateError>( self.get_or_add_constant( b, &ast::Type::Scalar(*typ), &init[((size_of_t as usize) * (x as usize))..], )? .0, ) }) .collect::, _>>()?; SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) } [first_dim, rest @ ..] => { let result_type = self.get_or_add( b, SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), ); let size_of_t = rest .iter() .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); let components = (0..*first_dim) .map(|x| { Ok::<_, TranslateError>( self.get_or_add_constant( b, &ast::Type::Array(None, *typ, rest.to_vec()), &init[((size_of_t as usize) * (x as usize))..], )? .0, ) }) .collect::, _>>()?; SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) } }, ast::Type::Pointer(..) 
=> return Err(error_unreachable()), }) } fn get_or_add_constant_single< T: Copy, CastAsU64: FnOnce(T) -> u64, InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, >( &mut self, b: &mut dr::Builder, key: ast::ScalarType, init: &[u8], cast: CastAsU64, f: InsertConstant, ) -> SpirvWord { let value = unsafe { *(init.as_ptr() as *const T) }; let value_64 = cast(value); let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); match self.constants.get(&ht_key) { Some(value) => *value, None => { let spirv_type = self.get_or_add_scalar(b, key); let result = SpirvWord(f(b, spirv_type.0, value)); self.constants.insert(ht_key, result); result } } } } #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), Vector(SpirvScalarKey, u8), Array(SpirvScalarKey, Vec), Pointer(Box, spirv::StorageClass), Func(Option>, Vec), Struct(Vec), } impl SpirvType { fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( Box::new(SpirvType::Base(pointer_t.into())), space_to_spirv(space), ), } } fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { let key = Self::new(t); SpirvType::Pointer(Box::new(key), outer_space) } } impl From for SpirvType { fn from(t: ast::ScalarType) -> Self { SpirvType::Base(t.into()) } } // SPIR-V integer type definitions are signless, more below: // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvScalarKey { B8, B16, B32, B64, F16, F32, F64, Pred, F16x2, } impl From for SpirvScalarKey { fn from(t: ast::ScalarType) -> Self { 
match t {
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
                SpirvScalarKey::B16
            }
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
                SpirvScalarKey::B32
            }
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
                SpirvScalarKey::B64
            }
            ast::ScalarType::F16 => SpirvScalarKey::F16,
            ast::ScalarType::F32 => SpirvScalarKey::F32,
            ast::ScalarType::F64 => SpirvScalarKey::F64,
            ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
            ast::ScalarType::Pred => SpirvScalarKey::Pred,
            ast::ScalarType::S16x2
            | ast::ScalarType::U16x2
            | ast::ScalarType::BF16
            | ast::ScalarType::BF16x2
            | ast::ScalarType::B128 => todo!(),
        }
    }
}

/// Maps a PTX state space onto the SPIR-V storage class used for it.
///
/// The entry/function parameter and cluster/CTA shared spaces are not
/// implemented.
fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
    match this {
        ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
        ast::StateSpace::Generic => spirv::StorageClass::Generic,
        ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
        ast::StateSpace::Local => spirv::StorageClass::Function,
        ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
        ast::StateSpace::Param => spirv::StorageClass::Function,
        ast::StateSpace::Reg => spirv::StorageClass::Function,
        ast::StateSpace::Sreg => spirv::StorageClass::Input,
        ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc
        | ast::StateSpace::SharedCluster
        | ast::StateSpace::SharedCta => todo!(),
    }
}

// TODO: remove this once we have per-function support for denorms
/// Builds the build-option string and decides whether denormals are flushed.
///
/// For every method the second tuple field recorded for 2-byte and for 4-byte
/// floats is summed; those per-method sums are then added up over each kernel
/// and the functions the call map lists for it. A strictly positive total
/// selects `-ze-denorms-are-zero` and returns `true`.
///
/// NOTE(review): the generic arguments of the inner map, of both `size_of`
/// calls and of `collect` were missing from the file. They were restored from
/// the variable names and the way the values are used; the `isize` counter
/// type in particular should be confirmed against version control.
fn emit_denorm_build_string<'input>(
    call_map: &MethodsCallMap,
    denorm_information: &HashMap<
        ast::MethodName<'input, SpirvWord>,
        HashMap<u8, (spirv::FPDenormMode, isize)>,
    >,
) -> (CString, bool) {
    let denorm_counts = denorm_information
        .iter()
        .map(|(method, meth_denorm)| {
            let f16_count = meth_denorm
                .get(&(mem::size_of::<f16>() as u8))
                .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
                .1;
            let f32_count = meth_denorm
                .get(&(mem::size_of::<f32>() as u8))
                .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
                .1;
            (method, (f16_count + f32_count))
        })
        .collect::<HashMap<_, _>>();
    let mut flush_over_preserve = 0;
    for (kernel, children) in call_map.kernels() {
        flush_over_preserve += *denorm_counts
            .get(&ast::MethodName::Kernel(kernel))
            .unwrap_or(&0);
        for child_fn in children {
            flush_over_preserve += *denorm_counts
                .get(&ast::MethodName::Func(*child_fn))
                .unwrap_or(&0);
        }
    }
    if flush_over_preserve > 0 {
        (
            CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(),
            true,
        )
    } else {
        (CString::new("-ze-take-global-address").unwrap(), false)
    }
}

/// Records, for every method with a body, which module-level variables its
/// statements reference.
///
/// Returns the directives (rebuilt, otherwise unchanged) together with a map
/// from method name to the set of global ids that method mentions.
fn get_globals_use_map<'input>(
    directives: Vec<Directive<'input>>,
) -> (
    Vec<Directive<'input>>,
    HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
) {
    // First pass: names of all module-level variables
    let mut known_globals = HashSet::new();
    for directive in directives.iter() {
        match directive {
            Directive::Variable(_, ast::Variable { name, .. }) => {
                known_globals.insert(*name);
            }
            Directive::Method(..) => {}
        }
    }
    // Second pass: visit every symbol of every method body
    let mut symbol_uses_map = HashMap::new();
    let directives = directives
        .into_iter()
        .map(|directive| match directive {
            Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive,
            Directive::Method(Function {
                func_decl,
                body: Some(mut statements),
                globals,
                import_as,
                tuning,
                linkage,
            }) => {
                let method_name = func_decl.borrow().name;
                statements = statements
                    .into_iter()
                    .map(|statement| {
                        statement.visit_map(
                            &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
                                if known_globals.contains(&symbol) {
                                    multi_hash_map_append(
                                        &mut symbol_uses_map,
                                        method_name,
                                        symbol,
                                    );
                                }
                                Ok::<_, TranslateError>(symbol)
                            },
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()
                    // The visitor closure above never returns `Err`
                    .unwrap();
                Directive::Method(Function {
                    func_decl,
                    body: Some(statements),
                    globals,
                    import_as,
                    tuning,
                    linkage,
                })
            }
        })
        .collect::<Vec<_>>();
    (directives, symbol_uses_map)
}

/// Emits every module-level variable and every method of the module.
fn emit_directives<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver<'input>,
    opencl_id: spirv::Word,
    should_flush_denorms: bool,
    call_map: &MethodsCallMap<'input>,
    globals_use_map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
    directives: Vec<Directive<'input>>,
    kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
    let empty_body = Vec::new();
    for
d in directives.iter() {
        match d {
            Directive::Variable(linking, var) => {
                emit_variable(builder, map, id_defs, *linking, &var)?;
            }
            Directive::Method(f) => {
                let f_body = match &f.body {
                    Some(f) => f,
                    None => {
                        // Bodiless methods are emitted (with an empty body)
                        // only when declared extern; others are skipped
                        if f.linkage.contains(ast::LinkingDirective::EXTERN) {
                            &empty_body
                        } else {
                            continue;
                        }
                    }
                };
                for var in f.globals.iter() {
                    emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
                }
                let func_decl = (*f.func_decl).borrow();
                let fn_id = emit_function_header(
                    builder,
                    map,
                    &id_defs,
                    &*func_decl,
                    call_map,
                    &globals_use_map,
                    kernel_info,
                )?;
                if matches!(func_decl.name, ast::MethodName::Kernel(_)) {
                    if should_flush_denorms {
                        builder.execution_mode(
                            fn_id.0,
                            spirv_headers::ExecutionMode::DenormFlushToZero,
                            [16],
                        );
                        builder.execution_mode(
                            fn_id.0,
                            spirv_headers::ExecutionMode::DenormFlushToZero,
                            [32],
                        );
                        builder.execution_mode(
                            fn_id.0,
                            spirv_headers::ExecutionMode::DenormFlushToZero,
                            [64],
                        );
                    }
                    // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
                    builder.execution_mode(
                        fn_id.0,
                        spirv_headers::ExecutionMode::ContractionOff,
                        [],
                    );
                    for t in f.tuning.iter() {
                        match *t {
                            ast::TuningDirective::MaxNtid(nx, ny, nz) => {
                                builder.execution_mode(
                                    fn_id.0,
                                    spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
                                    [nx, ny, nz],
                                );
                            }
                            ast::TuningDirective::ReqNtid(nx, ny, nz) => {
                                builder.execution_mode(
                                    fn_id.0,
                                    spirv_headers::ExecutionMode::LocalSize,
                                    [nx, ny, nz],
                                );
                            }
                            // Too architecture specific
                            ast::TuningDirective::MaxNReg(..)
                            | ast::TuningDirective::MinNCtaPerSm(..) => {}
                        }
                    }
                }
                emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
                emit_function_linkage(builder, id_defs, f, fn_id)?;
                builder.select_block(None)?;
                builder.end_function()?;
            }
        }
    }
    Ok(())
}

/// Emits one variable declaration together with its decorations.
///
/// The state space selects the storage class. A variable with initializer
/// bytes gets the matching constant as its initializer; a `Global` variable
/// without any gets a null constant; every other variable is left
/// uninitialized. An alignment decoration is added when the variable declares
/// one, and the linking decoration is emitted for everything except `Shared`
/// variables declared extern.
///
/// # Panics
/// Panics for state spaces that are not implemented (`Generic`, `Sreg`, the
/// entry/function parameter and cluster/CTA shared spaces).
fn emit_variable<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver<'input>,
    linking: ast::LinkingDirective,
    var: &ast::Variable<SpirvWord>,
) -> Result<(), TranslateError> {
    let (must_init, st_class) = match var.state_space {
        ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
            (false, spirv::StorageClass::Function)
        }
        ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
        ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
        ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
        ast::StateSpace::Generic => todo!(),
        ast::StateSpace::Sreg => todo!(),
        ast::StateSpace::ParamEntry
        | ast::StateSpace::ParamFunc
        | ast::StateSpace::SharedCluster
        | ast::StateSpace::SharedCta => todo!(),
    };
    let initializer = if !var.array_init.is_empty() {
        Some(
            map.get_or_add_constant(
                builder,
                &ast::Type::from(var.v_type.clone()),
                &*var.array_init,
            )?
            .0,
        )
    } else if must_init {
        let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
        Some(builder.constant_null(type_id.0, None))
    } else {
        None
    };
    let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
    builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initializer);
    if let Some(align) = var.align {
        builder.decorate(
            var.name.0,
            spirv::Decoration::Alignment,
            [dr::Operand::LiteralInt32(align)].iter().cloned(),
        );
    }
    if var.state_space != ast::StateSpace::Shared
        || !linking.contains(ast::LinkingDirective::EXTERN)
    {
        emit_linking_decoration(builder, id_defs, None, var.name, linking);
    }
    Ok(())
}

/// Opens a function: registers kernel metadata, declares the entry point for
/// kernels, begins the function and emits its parameters.
///
/// For a kernel, `kernel_info` receives the size and pointer-ness of each
/// input argument plus whether shared memory is used, and the entry-point
/// interface lists every global used by the kernel itself or by the functions
/// the call map reports as its children.
///
/// Returns the id of the function that was begun.
fn emit_function_header<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    defined_globals: &GlobalStringIdResolver<'input>,
    func_decl: &ast::MethodDeclaration<'input, SpirvWord>,
    call_map: &MethodsCallMap<'input>,
    globals_use_map: &HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
    kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<SpirvWord, TranslateError> {
    if let ast::MethodName::Kernel(name) = func_decl.name {
        let args_lens = func_decl
            .input_arguments
            .iter()
            .map(|param| {
                (
                    type_size_of(&param.v_type),
                    matches!(param.v_type, ast::Type::Pointer(..)),
                )
            })
            .collect();
        kernel_info.insert(
            name.to_string(),
            KernelInfo {
                arguments_sizes: args_lens,
                uses_shared_mem: func_decl.shared_mem.is_some(),
            },
        );
    }
    let (ret_type, func_type) = get_function_type(
        builder,
        map,
        effective_input_arguments(func_decl).map(|(_, typ)| typ),
        &func_decl.return_arguments,
    );
    let fn_id = match func_decl.name {
        ast::MethodName::Kernel(name) => {
            let fn_id = defined_globals.get_id(name)?;
            // Globals used directly by the kernel, followed by the globals
            // used by each of its child functions
            let interface = globals_use_map
                .get(&ast::MethodName::Kernel(name))
                .into_iter()
                .flatten()
                .copied()
                .chain(
                    call_map
                        .get_kernel_children(name)
                        .copied()
                        .flat_map(|subfunction| {
                            globals_use_map
                                .get(&ast::MethodName::Func(subfunction))
                                .into_iter()
                                .flatten()
                                .copied()
                        }),
                )
                .map(|word| word.0)
                .collect::<Vec<_>>();
            builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface);
            fn_id
        }
        ast::MethodName::Func(name) => name,
    };
    builder.begin_function(
        ret_type.0,
        Some(fn_id.0),
        spirv::FunctionControl::NONE,
        func_type.0,
    )?;
    for (name, typ) in effective_input_arguments(func_decl) {
        let result_type = map.get_or_add(builder, typ);
        builder.function_parameter(Some(name.0), result_type.0)?;
    }
    Ok(fn_id)
}

/// Size in bytes of a value of the given PTX type.
///
/// Vectors and arrays are the element size multiplied by every dimension.
///
/// NOTE(review): the type argument of `size_of` in the pointer arm was missing
/// from the file; `usize` (host pointer width) is restored here and should be
/// confirmed against version control.
pub fn type_size_of(this: &ast::Type) -> usize {
    match this {
        ast::Type::Scalar(typ) => typ.size_of() as usize,
        ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize),
        ast::Type::Array(_, typ, len) => len
            .iter()
            .fold(typ.size_of() as usize, |x, y| x * (*y as usize)),
        ast::Type::Pointer(..) => mem::size_of::<usize>(),
    }
}

fn emit_function_body_ops<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver<'input>,
    opencl: spirv::Word,
    func: &[ExpandedStatement],
) -> Result<(), TranslateError> {
    for s in func {
        match s {
            Statement::Label(id) => {
                if builder.selected_block().is_some() {
                    builder.branch(id.0)?;
                }
                builder.begin_block(Some(id.0))?;
            }
            _ => {
                if builder.selected_block().is_none() && builder.selected_function().is_some() {
                    builder.begin_block(None)?;
                }
            }
        }
        match s {
            Statement::Label(_) => (),
            Statement::Variable(var) => {
                emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
            }
            Statement::Constant(cnst) => {
                let typ_id = map.get_or_add_scalar(builder, cnst.typ);
                match (cnst.typ, cnst.value) {
                    (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
                    }
                    (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
                    }
                    (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
                    | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
                        builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
                    }
                    (ast::ScalarType::B64,
ast::ImmediateValue::U64(value)) | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); } (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); } (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); } (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); } (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); } (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); } (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); } (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); } (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); } (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); } (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); } (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); } (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); } (ast::ScalarType::F16, 
ast::ImmediateValue::F32(value)) => { builder.constant_f32( typ_id.0, Some(cnst.dst.0), f16::from_f32(value).to_f32(), ); } (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); } (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); } (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { builder.constant_f32( typ_id.0, Some(cnst.dst.0), f16::from_f64(value).to_f32(), ); } (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); } (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); } (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; if value == 0 { builder.constant_false(bool_type, Some(cnst.dst.0)); } else { builder.constant_true(bool_type, Some(cnst.dst.0)); } } (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; if value == 0 { builder.constant_false(bool_type, Some(cnst.dst.0)); } else { builder.constant_true(bool_type, Some(cnst.dst.0)); } } _ => return Err(error_mismatched_type()), } } Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { builder.branch_conditional( bra.predicate.0, bra.if_true.0, bra.if_false.0, iter::empty(), )?; } Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { // TODO: implement properly let zero = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U64), &vec_repr(0u64), )?; let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); builder.copy_object(result_type.0, Some(dst.0), zero.0)?; } Statement::Instruction(inst) => match inst { ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. 
} => todo!(), ast::Instruction::Call { data, arguments } => { let (result_type, result_id) = match (&*data.return_arguments, &*arguments.return_arguments) { ([(type_, space)], [id]) => { if *space != ast::StateSpace::Reg { return Err(error_unreachable()); } ( map.get_or_add(builder, SpirvType::new(type_.clone())).0, Some(id.0), ) } ([], []) => (map.void(), None), _ => todo!(), }; let arg_list = arguments .input_arguments .iter() .map(|id| id.0) .collect::>(); builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; } ast::Instruction::Abs { data, arguments } => { emit_abs(builder, map, opencl, data, arguments)? } // SPIR-V does not support marking jumps as guaranteed-converged ast::Instruction::Bra { arguments, .. } => { builder.branch(arguments.src.0)?; } ast::Instruction::Ld { data, arguments } => { let mem_access = match data.qualifier { ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, _ => return Err(TranslateError::Todo), }; let result_type = map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); builder.load( result_type.0, Some(arguments.dst.0), arguments.src.0, Some(mem_access | spirv::MemoryAccess::ALIGNED), [dr::Operand::LiteralInt32( type_size_of(&ast::Type::from(data.typ.clone())) as u32, )] .iter() .cloned(), )?; } ast::Instruction::St { data, arguments } => { let mem_access = match data.qualifier { ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, _ => return Err(TranslateError::Todo), }; builder.store( arguments.src1.0, arguments.src2.0, Some(mem_access | spirv::MemoryAccess::ALIGNED), [dr::Operand::LiteralInt32( type_size_of(&ast::Type::from(data.typ.clone())) as u32, )] .iter() .cloned(), )?; } // SPIR-V does not support ret as 
guaranteed-converged ast::Instruction::Ret { .. } => builder.ret()?, ast::Instruction::Mov { data, arguments } => { let result_type = map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; } ast::Instruction::Mul { data, arguments } => match data { ast::MulDetails::Integer { type_, control } => { emit_mul_int(builder, map, opencl, *type_, *control, arguments)? } ast::MulDetails::Float(ref ctr) => { emit_mul_float(builder, map, ctr, arguments)? } }, ast::Instruction::Add { data, arguments } => match data { ast::ArithDetails::Integer(desc) => { emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? } ast::ArithDetails::Float(desc) => { emit_add_float(builder, map, desc, arguments)? } }, ast::Instruction::Setp { data, arguments } => { if arguments.dst2.is_some() { todo!() } emit_setp(builder, map, data, arguments)?; } ast::Instruction::Not { data, arguments } => { let result_type = map.get_or_add(builder, SpirvType::from(*data)); let result_id = Some(arguments.dst.0); let operand = arguments.src; match data { ast::ScalarType::Pred => { logical_not(builder, result_type.0, result_id, operand.0) } _ => builder.not(result_type.0, result_id, operand.0), }?; } ast::Instruction::Shl { data, arguments } => { let full_type = ast::Type::Scalar(*data); let size_of = type_size_of(&full_type); let result_type = map.get_or_add(builder, SpirvType::new(full_type)); let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; builder.shift_left_logical( result_type.0, Some(arguments.dst.0), arguments.src1.0, offset_src, )?; } ast::Instruction::Shr { data, arguments } => { let full_type = ast::ScalarType::from(data.type_); let size_of = full_type.size_of(); let result_type = map.get_or_add_scalar(builder, full_type).0; let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; match data.kind { 
ptx_parser::RightShiftKind::Arithmetic => { builder.shift_right_arithmetic( result_type, Some(arguments.dst.0), arguments.src1.0, offset_src, )?; } ptx_parser::RightShiftKind::Logical => { builder.shift_right_logical( result_type, Some(arguments.dst.0), arguments.src1.0, offset_src, )?; } } } ast::Instruction::Cvt { data, arguments } => { emit_cvt(builder, map, opencl, data, arguments)?; } ast::Instruction::Cvta { data, arguments } => { // This would be only meaningful if const/slm/global pointers // had a different format than generic pointers, but they don't pretty much by ptx definition // Honestly, I have no idea why this instruction exists and is emitted by the compiler let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; } ast::Instruction::SetpBool { .. } => todo!(), ast::Instruction::Mad { data, arguments } => match data { ast::MadDetails::Integer { type_, control, saturate, } => { if *saturate { todo!() } if type_.kind() == ast::ScalarKind::Signed { emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? } else { emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? } } ast::MadDetails::Float(desc) => { emit_mad_float(builder, map, opencl, desc, arguments)? } }, ast::Instruction::Fma { data, arguments } => { emit_fma_float(builder, map, opencl, data, arguments)? 
} ast::Instruction::Or { data, arguments } => { let result_type = map.get_or_add_scalar(builder, *data).0; if *data == ast::ScalarType::Pred { builder.logical_or( result_type, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } else { builder.bitwise_or( result_type, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } } ast::Instruction::Sub { data, arguments } => match data { ast::ArithDetails::Integer(desc) => { emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; } ast::ArithDetails::Float(desc) => { emit_sub_float(builder, map, desc, arguments)?; } }, ast::Instruction::Min { data, arguments } => { emit_min(builder, map, opencl, data, arguments)?; } ast::Instruction::Max { data, arguments } => { emit_max(builder, map, opencl, data, arguments)?; } ast::Instruction::Rcp { data, arguments } => { emit_rcp(builder, map, opencl, data, arguments)?; } ast::Instruction::And { data, arguments } => { let result_type = map.get_or_add_scalar(builder, *data); if *data == ast::ScalarType::Pred { builder.logical_and( result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } else { builder.bitwise_and( result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } } ast::Instruction::Selp { data, arguments } => { let result_type = map.get_or_add_scalar(builder, *data); builder.select( result_type.0, Some(arguments.dst.0), arguments.src3.0, arguments.src1.0, arguments.src2.0, )?; } // TODO: implement named barriers ast::Instruction::Bar { data, arguments } => { let workgroup_scope = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(spirv::Scope::Workgroup as u32), )?; let barrier_semantics = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr( spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), )?; builder.control_barrier( 
workgroup_scope.0, workgroup_scope.0, barrier_semantics.0, )?; } ast::Instruction::Atom { data, arguments } => { emit_atom(builder, map, data, arguments)?; } ast::Instruction::AtomCas { data, arguments } => { let result_type = map.get_or_add_scalar(builder, data.type_); let memory_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(scope_to_spirv(data.scope) as u32), )?; let semantics_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(semantics_to_spirv(data.semantics).bits()), )?; builder.atomic_compare_exchange( result_type.0, Some(arguments.dst.0), arguments.src1.0, memory_const.0, semantics_const.0, semantics_const.0, arguments.src3.0, arguments.src2.0, )?; } ast::Instruction::Div { data, arguments } => match data { ast::DivDetails::Unsigned(t) => { let result_type = map.get_or_add_scalar(builder, (*t).into()); builder.u_div( result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } ast::DivDetails::Signed(t) => { let result_type = map.get_or_add_scalar(builder, (*t).into()); builder.s_div( result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } ast::DivDetails::Float(t) => { let result_type = map.get_or_add_scalar(builder, t.type_.into()); builder.f_div( result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; emit_float_div_decoration(builder, arguments.dst, t.kind); } }, ast::Instruction::Sqrt { data, arguments } => { emit_sqrt(builder, map, opencl, data, arguments)?; } ast::Instruction::Rsqrt { data, arguments } => { let result_type = map.get_or_add_scalar(builder, data.type_.into()); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::rsqrt as spirv::Word, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Neg { data, arguments } => { let result_type = map.get_or_add_scalar(builder, data.type_); let negate_func = if data.type_.kind() == 
ast::ScalarKind::Float { dr::Builder::f_negate } else { dr::Builder::s_negate }; negate_func( builder, result_type.0, Some(arguments.dst.0), arguments.src.0, )?; } ast::Instruction::Sin { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::sin as u32, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Cos { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::cos as u32, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Lg2 { arguments, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::log2 as u32, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Ex2 { arguments, .. 
} => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::exp2 as u32, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Clz { data, arguments } => { let result_type = map.get_or_add_scalar(builder, (*data).into()); builder.ext_inst( result_type.0, Some(arguments.dst.0), opencl, spirv::CLOp::clz as u32, [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), )?; } ast::Instruction::Brev { data, arguments } => { let result_type = map.get_or_add_scalar(builder, (*data).into()); builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; } ast::Instruction::Popc { data, arguments } => { let result_type = map.get_or_add_scalar(builder, (*data).into()); builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; } ast::Instruction::Xor { data, arguments } => { let builder_fn: fn( &mut dr::Builder, u32, Option, u32, u32, ) -> Result = match data { ast::ScalarType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; let result_type = map.get_or_add_scalar(builder, (*data).into()); builder_fn( builder, result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } ast::Instruction::Bfe { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. 
} => { // Should have beeen replaced with a funciton call earlier return Err(error_unreachable()); } ast::Instruction::Rem { data, arguments } => { let builder_fn = if data.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod } else { dr::Builder::u_mod }; let result_type = map.get_or_add_scalar(builder, (*data).into()); builder_fn( builder, result_type.0, Some(arguments.dst.0), arguments.src1.0, arguments.src2.0, )?; } ast::Instruction::Prmt { data, arguments } => { let control = *data as u32; let components = [ (control >> 0) & 0b1111, (control >> 4) & 0b1111, (control >> 8) & 0b1111, (control >> 12) & 0b1111, ]; if components.iter().any(|&c| c > 7) { return Err(TranslateError::Todo); } let vec4_b8_type = map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?; let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?; let dst_vector = builder.vector_shuffle( vec4_b8_type.0, None, src1_vector, src2_vector, components, )?; builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?; } ast::Instruction::Membar { data } => { let (scope, semantics) = match data { ast::MemScope::Cta => ( spirv::Scope::Workgroup, spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), ast::MemScope::Gpu => ( spirv::Scope::Device, spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), ast::MemScope::Sys => ( spirv::Scope::CrossDevice, spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), ast::MemScope::Cluster => todo!(), }; let spirv_scope = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(scope as u32), )?; let 
spirv_semantics = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(semantics), )?; builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; } }, Statement::LoadVar(details) => { emit_load_var(builder, map, details)?; } Statement::StoreVar(details) => { let dst_ptr = match details.member_index { Some(index) => { let result_ptr_type = map.get_or_add( builder, SpirvType::pointer_to( details.typ.clone(), spirv::StorageClass::Function, ), ); let index_spirv = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(index as u32), )?; builder.in_bounds_access_chain( result_ptr_type.0, None, details.arg.src1.0, [index_spirv.0].iter().copied(), )? } None => details.arg.src1.0, }; builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; } Statement::RetValue(_, id) => { builder.ret_value(id.0)?; } Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, offset_src, }) => { let u8_pointer = map.get_or_add( builder, SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), ); let result_type = map.get_or_add( builder, SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), ); let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; let temp = builder.in_bounds_ptr_access_chain( u8_pointer.0, None, ptr_src_u8, offset_src.0, iter::empty(), )?; builder.bitcast(result_type.0, Some(dst.0), temp)?; } Statement::RepackVector(repack) => { if repack.is_extract { let scalar_type = map.get_or_add_scalar(builder, repack.typ); for (index, dst_id) in repack.unpacked.iter().enumerate() { builder.composite_extract( scalar_type.0, Some(dst_id.0), repack.packed.0, [index as u32].iter().copied(), )?; } } else { let vector_type = map.get_or_add( builder, SpirvType::Vector( SpirvScalarKey::from(repack.typ), repack.unpacked.len() as u8, ), ); let mut temp_vec = builder.undef(vector_type.0, None); for (index, src_id) in repack.unpacked.iter().enumerate() { 
temp_vec = builder.composite_insert( vector_type.0, None, src_id.0, temp_vec, [index as u32].iter().copied(), )?; } builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; } } } } Ok(()) } fn emit_function_linkage<'input>( builder: &mut dr::Builder, id_defs: &GlobalStringIdResolver<'input>, f: &Function, fn_name: SpirvWord, ) -> Result<(), TranslateError> { if f.linkage == ast::LinkingDirective::NONE { return Ok(()); }; let linking_name = match f.func_decl.borrow().name { // According to SPIR-V rules linkage attributes are invalid on kernels ast::MethodName::Kernel(..) => return Ok(()), ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( || match id_defs.reverse_variables.get(&fn_id) { Some(fn_name) => Ok(fn_name), None => Err(error_unknown_symbol()), }, Result::Ok, )?, }; emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); Ok(()) } fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, spirv_input: impl Iterator, spirv_output: &[ast::Variable], ) -> (SpirvWord, SpirvWord) { map.get_or_add_fn( builder, spirv_input, spirv_output .iter() .map(|var| SpirvType::new(var.v_type.clone())), ) } fn emit_linking_decoration<'input>( builder: &mut dr::Builder, id_defs: &GlobalStringIdResolver<'input>, name_override: Option<&str>, name: SpirvWord, linking: ast::LinkingDirective, ) { if linking == ast::LinkingDirective::NONE { return; } if linking.contains(ast::LinkingDirective::VISIBLE) { let string_name = name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); builder.decorate( name.0, spirv::Decoration::LinkageAttributes, [ dr::Operand::LiteralString(string_name.to_string()), dr::Operand::LinkageType(spirv::LinkageType::Export), ] .iter() .cloned(), ); } else if linking.contains(ast::LinkingDirective::EXTERN) { let string_name = name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); builder.decorate( name.0, spirv::Decoration::LinkageAttributes, [ 
dr::Operand::LiteralString(string_name.to_string()), dr::Operand::LinkageType(spirv::LinkageType::Import), ] .iter() .cloned(), ); } // TODO: handle LinkingDirective::WEAK } fn effective_input_arguments<'a>( this: &'a ast::MethodDeclaration<'a, SpirvWord>, ) -> impl Iterator + 'a { let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); this.input_arguments.iter().map(move |arg| { if !is_kernel && arg.state_space != ast::StateSpace::Reg { let spirv_type = SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); (arg.name, spirv_type) } else { (arg.name, SpirvType::new(arg.v_type.clone())) } }) } fn emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), TranslateError> { let from_parts = to_parts(&cv.from_type); let to_parts = to_parts(&cv.to_type); match (from_parts.kind, to_parts.kind, &cv.kind) { (_, _, &ConversionKind::BitToPtr) => { let dst_type = map.get_or_add( builder, SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), ); builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; } (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); if from_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?; } else { builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?; } } else { // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, SpirvType::new(type_from_parts(TypeParts { scalar_kind: ast::ScalarKind::Bit, ..from_parts })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type.0, None, cv.src.0)?; let wide_bit_type = type_from_parts(TypeParts { 
scalar_kind: ast::ScalarKind::Bit, ..to_parts }); let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); if to_parts.scalar_kind == ast::ScalarKind::Unsigned || to_parts.scalar_kind == ast::ScalarKind::Bit { builder.u_convert( wide_bit_type_spirv.0, Some(cv.dst.0), same_width_bit_value, )?; } else { let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed && to_parts.scalar_kind == ast::ScalarKind::Signed { dr::Builder::s_convert } else { dr::Builder::u_convert }; let wide_bit_value = conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?; emit_implicit_conversion( builder, map, &ImplicitConversion { src: SpirvWord(wide_bit_value), dst: cv.dst, from_type: wide_bit_type, from_space: cv.from_space, to_type: cv.to_type.clone(), to_space: cv.to_space, kind: ConversionKind::Default, }, )?; } } } (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?; } (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; } (_, _, &ConversionKind::PtrToPtr) => { let result_type = map.get_or_add( builder, SpirvType::Pointer( Box::new(SpirvType::new(cv.to_type.clone())), space_to_spirv(cv.to_space), ), ); if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic { let src = if cv.from_type != cv.to_type { let temp_type = map.get_or_add( builder, SpirvType::Pointer( Box::new(SpirvType::new(cv.to_type.clone())), space_to_spirv(cv.from_space), ), ); builder.bitcast(temp_type.0, None, cv.src.0)? 
/// Returns the in-memory (native-endian) byte representation of `t`.
///
/// The returned vector always holds exactly `size_of::<T>()` bytes.
///
/// NOTE(review): the callers visible in this file pass plain integers, floats
/// and bit-flag words; a `T` with padding bytes would expose uninitialized
/// memory — confirm no such caller exists.
fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
    let size = mem::size_of::<T>();
    let mut result = vec![0u8; size];
    // SAFETY: `&t` is valid for reading `size` bytes, `result` owns exactly
    // `size` writable bytes, and a local value cannot overlap a fresh heap
    // buffer. The copy is done byte-wise on purpose: a `Vec<u8>` allocation
    // carries no alignment guarantee for `T`, so writing through a `*mut T`
    // would violate `copy_nonoverlapping`'s alignment requirement.
    unsafe {
        std::ptr::copy_nonoverlapping(&t as *const T as *const u8, result.as_mut_ptr(), size);
    }
    result
}
type_: ast::ScalarType, control: ast::MulIntControl, arg: &ast::MulArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(type_)); match control { ast::MulIntControl::Low => { builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; } ast::MulIntControl::High => { let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { spirv::CLOp::s_mul_hi } else { spirv::CLOp::u_mul_hi }; builder.ext_inst( inst_type.0, Some(arg.dst.0), opencl, opencl_inst as spirv::Word, [ dr::Operand::IdRef(arg.src1.0), dr::Operand::IdRef(arg.src2.0), ] .iter() .cloned(), )?; } ast::MulIntControl::Wide => { let instr_width = type_.size_of(); let instr_kind = type_.kind(); let dst_type = scalar_from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; (src1, src2) } else { let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; (src1, src2) }; builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); } } Ok(()) } fn emit_mul_float( builder: &mut dr::Builder, map: &mut TypeWordMap, ctr: &ast::ArithFloat, arg: &ast::MulArgs, ) -> Result<(), dr::Error> { if ctr.saturate { todo!() } let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; emit_rounding_decoration(builder, arg.dst, ctr.rounding); Ok(()) } fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { match kind { ast::ScalarKind::Float => match width { 2 => ast::ScalarType::F16, 4 => ast::ScalarType::F32, 8 => ast::ScalarType::F64, _ => unreachable!(), }, ast::ScalarKind::Bit => match width { 1 => 
ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, 8 => ast::ScalarType::B64, _ => unreachable!(), }, ast::ScalarKind::Signed => match width { 1 => ast::ScalarType::S8, 2 => ast::ScalarType::S16, 4 => ast::ScalarType::S32, 8 => ast::ScalarType::S64, _ => unreachable!(), }, ast::ScalarKind::Unsigned => match width { 1 => ast::ScalarType::U8, 2 => ast::ScalarType::U16, 4 => ast::ScalarType::U32, 8 => ast::ScalarType::U64, _ => unreachable!(), }, ast::ScalarKind::Pred => ast::ScalarType::Pred, } } fn emit_rounding_decoration( builder: &mut dr::Builder, dst: SpirvWord, rounding: Option, ) { if let Some(rounding) = rounding { builder.decorate( dst.0, spirv::Decoration::FPRoundingMode, [rounding_to_spirv(rounding)].iter().cloned(), ); } } fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { let mode = match this { ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, }; rspirv::dr::Operand::FPRoundingMode(mode) } fn emit_add_int( builder: &mut dr::Builder, map: &mut TypeWordMap, typ: ast::ScalarType, saturate: bool, arg: &ast::AddArgs, ) -> Result<(), dr::Error> { if saturate { todo!() } let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; Ok(()) } fn emit_add_float( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::ArithFloat, arg: &ast::AddArgs, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); Ok(()) } fn emit_setp( builder: &mut dr::Builder, map: &mut TypeWordMap, setp: &ast::SetpData, arg: &ast::SetpArgs, ) -> Result<(), dr::Error> 
{ let result_type = map .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) .0; let result_id = Some(arg.dst1.0); let operand_1 = arg.src1.0; let operand_2 = arg.src2.0; match setp.cmp_op { ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { builder.f_ord_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { builder.s_less_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { builder.s_greater_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } 
ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { builder.f_unord_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { let temp1 = builder.is_nan(result_type, None, operand_1)?; let temp2 = builder.is_nan(result_type, None, operand_2)?; builder.logical_or(result_type, result_id, temp1, temp2) } ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { let temp1 = builder.is_nan(result_type, None, operand_1)?; let temp2 = builder.is_nan(result_type, None, operand_2)?; let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; logical_not(builder, result_type, result_id, any_nan) } _ => todo!(), }?; Ok(()) } // HACK ALERT // Temporary workaround until IGC gets its shit together // Currently IGC carries two copies of SPIRV-LLVM 
translator // a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. // Obviously, old and buggy one is used for compiling L0 SPIRV // https://github.com/intel/intel-graphics-compiler/issues/148 fn logical_not( builder: &mut dr::Builder, result_type: spirv::Word, result_id: Option, operand: spirv::Word, ) -> Result { let const_true = builder.constant_true(result_type, None); let const_false = builder.constant_false(result_type, None); builder.select(result_type, result_id, operand, const_false, const_true) } // HACK ALERT // For some reason IGC fails linking if the value and shift size are of different type fn insert_shift_hack( builder: &mut dr::Builder, map: &mut TypeWordMap, offset_var: spirv::Word, size_of: usize, ) -> Result { let result_type = match size_of { 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), 4 => return Ok(offset_var), _ => return Err(error_unreachable()), }; Ok(builder.u_convert(result_type.0, None, offset_var)?) 
} fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, dets: &ast::CvtDetails, arg: &ast::CvtArgs, ) -> Result<(), TranslateError> { match dets.mode { ptx_parser::CvtMode::SignExtend => { let cv = ImplicitConversion { src: arg.src, dst: arg.dst, from_type: dets.from.into(), from_space: ast::StateSpace::Reg, to_type: dets.to.into(), to_space: ast::StateSpace::Reg, kind: ConversionKind::SignExtend, }; emit_implicit_conversion(builder, map, &cv)?; } ptx_parser::CvtMode::ZeroExtend | ptx_parser::CvtMode::Truncate | ptx_parser::CvtMode::Bitcast => { let cv = ImplicitConversion { src: arg.src, dst: arg.dst, from_type: dets.from.into(), from_space: ast::StateSpace::Reg, to_type: dets.to.into(), to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, }; emit_implicit_conversion(builder, map, &cv)?; } ptx_parser::CvtMode::SaturateUnsignedToSigned => { let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; } ptx_parser::CvtMode::SaturateSignedToUnsigned => { let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; } ptx_parser::CvtMode::FPExtend { flush_to_zero } => { let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; } ptx_parser::CvtMode::FPTruncate { rounding, flush_to_zero, } => { let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); } ptx_parser::CvtMode::FPRound { integer_rounding, flush_to_zero, } => { if flush_to_zero == Some(true) { todo!() } let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); match integer_rounding { Some(ast::RoundingMode::NearestEven) => { builder.ext_inst( result_type.0, Some(arg.dst.0), opencl, 
spirv::CLOp::rint as u32, [dr::Operand::IdRef(arg.src.0)].iter().cloned(), )?; } Some(ast::RoundingMode::Zero) => { builder.ext_inst( result_type.0, Some(arg.dst.0), opencl, spirv::CLOp::trunc as u32, [dr::Operand::IdRef(arg.src.0)].iter().cloned(), )?; } Some(ast::RoundingMode::NegativeInf) => { builder.ext_inst( result_type.0, Some(arg.dst.0), opencl, spirv::CLOp::floor as u32, [dr::Operand::IdRef(arg.src.0)].iter().cloned(), )?; } Some(ast::RoundingMode::PositiveInf) => { builder.ext_inst( result_type.0, Some(arg.dst.0), opencl, spirv::CLOp::ceil as u32, [dr::Operand::IdRef(arg.src.0)].iter().cloned(), )?; } None => { builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?; } } } ptx_parser::CvtMode::SignedFromFP { rounding, flush_to_zero, } => { let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); } ptx_parser::CvtMode::UnsignedFromFP { rounding, flush_to_zero, } => { let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); } ptx_parser::CvtMode::FPFromSigned(rounding) => { let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); } ptx_parser::CvtMode::FPFromUnsigned(rounding) => { let dest_t: ast::ScalarType = dets.to.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; emit_rounding_decoration(builder, arg.dst, Some(rounding)); } } Ok(()) } fn emit_mad_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, 
opencl: spirv::Word, type_: ast::ScalarType, control: ast::MulIntControl, arg: &ast::MadArgs, ) -> Result<(), dr::Error> { let inst_type = map .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_))) .0; match control { ast::MulIntControl::Low => { let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; } ast::MulIntControl::High => { builder.ext_inst( inst_type, Some(arg.dst.0), opencl, spirv::CLOp::u_mad_hi as spirv::Word, [ dr::Operand::IdRef(arg.src1.0), dr::Operand::IdRef(arg.src2.0), dr::Operand::IdRef(arg.src3.0), ] .iter() .cloned(), )?; } ast::MulIntControl::Wide => todo!(), }; Ok(()) }

// NOTE(review): the `ast::*Args` / `ast::*Details` parameter types below are
// written exactly as they appeared in the reviewed text. They may have lost
// generic arguments before review -- confirm against the canonical file.

/// Emits a signed integer multiply-add into `arg.dst`.
///
/// * `Low`  - `i_mul(src1, src2)` followed by `i_add(src3, product)`.
/// * `High` - the OpenCL extended instruction `s_mad_hi(src1, src2, src3)`.
/// * `Wide` - not implemented; reaching it panics through `todo!()`.
///
/// Errors reported by the SPIR-V builder are propagated unchanged.
fn emit_mad_sint(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    type_: ast::ScalarType,
    control: ast::MulIntControl,
    arg: &ast::MadArgs,
) -> Result<(), dr::Error> {
    let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
    match control {
        ast::MulIntControl::Low => {
            let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
            builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
        }
        ast::MulIntControl::High => {
            builder.ext_inst(
                inst_type,
                Some(arg.dst.0),
                opencl,
                spirv::CLOp::s_mad_hi as spirv::Word,
                [
                    dr::Operand::IdRef(arg.src1.0),
                    dr::Operand::IdRef(arg.src2.0),
                    dr::Operand::IdRef(arg.src3.0),
                ]
                .iter()
                .cloned(),
            )?;
        }
        ast::MulIntControl::Wide => todo!(),
    };
    Ok(())
}

/// Emits the OpenCL extended instruction `mad(src1, src2, src3)` into
/// `arg.dst`, typed according to `desc.type_`.
///
/// No rounding decoration is attached here; only `desc.type_` is read.
fn emit_mad_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::ArithFloat,
    arg: &ast::MadArgs,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.ext_inst(
        inst_type,
        Some(arg.dst.0),
        opencl,
        spirv::CLOp::mad as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
            dr::Operand::IdRef(arg.src3.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits the OpenCL extended instruction `fma(src1, src2, src3)` into
/// `arg.dst`, typed according to `desc.type_`.
///
/// No rounding decoration is attached here; only `desc.type_` is read.
fn emit_fma_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::ArithFloat,
    arg: &ast::FmaArgs,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.ext_inst(
        inst_type,
        Some(arg.dst.0),
        opencl,
        spirv::CLOp::fma as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
            dr::Operand::IdRef(arg.src3.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits an integer subtraction `dst = src1 - src2` of type `typ`.
///
/// Saturating subtraction is not implemented: `saturate == true` panics
/// through `todo!()`.
fn emit_sub_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    typ: ast::ScalarType,
    saturate: bool,
    arg: &ast::SubArgs,
) -> Result<(), dr::Error> {
    if saturate {
        todo!()
    }
    // `typ` already is an `ast::ScalarType`, so no intermediate conversion
    // is needed before building the SPIR-V type key.
    let inst_type = map.get_or_add(builder, SpirvType::from(typ)).0;
    builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
    Ok(())
}

/// Emits a floating-point subtraction `dst = src1 - src2` and then hands
/// `desc.rounding` to `emit_rounding_decoration` for the result id.
fn emit_sub_float(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    desc: &ast::ArithFloat,
    arg: &ast::SubArgs,
) -> Result<(), dr::Error> {
    let inst_type = map
        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
        .0;
    builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
    emit_rounding_decoration(builder, arg.dst, desc.rounding);
    Ok(())
}

/// Emits an OpenCL minimum of `src1` and `src2` into `arg.dst`.
///
/// The extended instruction is chosen from the `desc` variant:
/// `Signed` -> `s_min`, `Unsigned` -> `u_min`, `Float` -> `fmin`.
fn emit_min(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::MinArgs,
) -> Result<(), dr::Error> {
    let cl_op = match desc {
        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
    };
    let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
    builder.ext_inst(
        inst_type.0,
        Some(arg.dst.0),
        opencl,
        cl_op as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits an OpenCL maximum of `src1` and `src2` into `arg.dst`.
///
/// The extended instruction is chosen from the `desc` variant:
/// `Signed` -> `s_max`, `Unsigned` -> `u_max`, `Float` -> `fmax`.
fn emit_max(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::MinMaxDetails,
    arg: &ast::MaxArgs,
) -> Result<(), dr::Error> {
    let cl_op = match desc {
        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
    };
    let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
    builder.ext_inst(
        inst_type.0,
        Some(arg.dst.0),
        opencl,
        cl_op as spirv::Word,
        [
            dr::Operand::IdRef(arg.src1.0),
            dr::Operand::IdRef(arg.src2.0),
        ]
        .iter()
        .cloned(),
    )?;
    Ok(())
}

/// Emits a reciprocal of `arg.src` into `arg.dst`.
///
/// * `RcpKind::Approx` - a single OpenCL `native_recip` extended instruction;
///   the function returns right after emitting it.
/// * `RcpKind::Compliant(rounding)` - `1.0 / src` through `f_div`, followed by
///   the rounding decoration and an `FPFastMathMode` decoration carrying
///   `ALLOW_RECIP`.
///
/// The instruction type is `F64` when `desc.type_` is `F64`; every other type
/// is handled as `F32`.
fn emit_rcp(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    desc: &ast::RcpData,
    arg: &ast::RcpArgs,
) -> Result<(), TranslateError> {
    let is_f64 = desc.type_ == ast::ScalarType::F64;
    // The constant `1.0` is prepared in the byte representation matching the
    // chosen instruction type.
    let (instr_type, constant) = if is_f64 {
        (ast::ScalarType::F64, vec_repr(1.0f64))
    } else {
        (ast::ScalarType::F32, vec_repr(1.0f32))
    };
    let result_type = map.get_or_add_scalar(builder, instr_type);
    let rounding = match desc.kind {
        ptx_parser::RcpKind::Approx => {
            builder.ext_inst(
                result_type.0,
                Some(arg.dst.0),
                opencl,
                spirv::CLOp::native_recip as u32,
                [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
            )?;
            return Ok(());
        }
        ptx_parser::RcpKind::Compliant(rounding) => rounding,
    };
    let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
    builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?;
    emit_rounding_decoration(builder, arg.dst, Some(rounding));
    builder.decorate(
        arg.dst.0,
        spirv::Decoration::FPFastMathMode,
        [dr::Operand::FPFastMathMode(
            spirv::FPFastMathMode::ALLOW_RECIP,
        )]
        .iter()
        .cloned(),
    );
    Ok(())
}

/// Emits an atomic read-modify-write instruction for `details.op`.
///
/// The builder method is selected from the operation; scope and memory
/// semantics are materialized as `U32` constants and passed alongside the
/// pointer (`arg.src1`) and the value (`arg.src2`).
///
/// # Errors
/// Returns `error_unreachable()` for `IncrementWrap` / `DecrementWrap`, which
/// this function does not handle. Builder and constant-creation errors are
/// propagated.
///
/// # Panics
/// `FloatMin` and `FloatMax` are not implemented and panic through `todo!()`;
/// so does `MemScope::Cluster` inside `scope_to_spirv`.
fn emit_atom(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    details: &ast::AtomDetails,
    arg: &ast::AtomArgs,
) -> Result<(), TranslateError> {
    let spirv_op = match details.op {
        ptx_parser::AtomicOp::And => dr::Builder::atomic_and,
        ptx_parser::AtomicOp::Or => dr::Builder::atomic_or,
        ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor,
        ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange,
        ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add,
        ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => {
            return Err(error_unreachable())
        }
        ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min,
        ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min,
        ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max,
        ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max,
        ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext,
        ptx_parser::AtomicOp::FloatMin => todo!(),
        ptx_parser::AtomicOp::FloatMax => todo!(),
    };
    let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone()));
    let memory_const = map.get_or_add_constant(
        builder,
        &ast::Type::Scalar(ast::ScalarType::U32),
        &vec_repr(scope_to_spirv(details.scope) as u32),
    )?;
    let semantics_const = map.get_or_add_constant(
        builder,
        &ast::Type::Scalar(ast::ScalarType::U32),
        &vec_repr(semantics_to_spirv(details.semantics).bits()),
    )?;
    spirv_op(
        builder,
        result_type.0,
        Some(arg.dst.0),
        arg.src1.0,
        memory_const.0,
        semantics_const.0,
        arg.src2.0,
    )?;
    Ok(())
}

/// Maps a PTX memory scope onto a SPIR-V scope.
///
/// `Cta` -> `Workgroup`, `Gpu` -> `Device`, `Sys` -> `CrossDevice`.
/// `Cluster` is not implemented and panics through `todo!()`.
fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope {
    match this {
        ast::MemScope::Cta => spirv::Scope::Workgroup,
        ast::MemScope::Gpu => spirv::Scope::Device,
        ast::MemScope::Sys => spirv::Scope::CrossDevice,
        ptx_parser::MemScope::Cluster => todo!(),
    }
}

/// Maps PTX atomic semantics onto the SPIR-V memory-semantics flag of the
/// same name (`Relaxed`, `Acquire`, `Release`, `AcqRel`).
fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics {
    match this {
        ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
        ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
        ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
        ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE,
    }
}

/// Decorates the result `dst` of a floating-point division according to
/// `kind`.
///
/// * `Approx`        - `FPFastMathMode` decoration with `ALLOW_RECIP`.
/// * `Rounding(rnd)` - forwards `rnd` to `emit_rounding_decoration`.
/// * `ApproxFull`    - no decoration is emitted.
fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) {
    match kind {
        ast::DivFloatKind::Approx => {
            builder.decorate(
                dst.0,
                spirv::Decoration::FPFastMathMode,
                [dr::Operand::FPFastMathMode(
                    spirv::FPFastMathMode::ALLOW_RECIP,
                )]
                .iter()
                .cloned(),
            );
        }
        ast::DivFloatKind::Rounding(rnd) => {
            emit_rounding_decoration(builder, dst, Some(rnd));
        }
        ast::DivFloatKind::ApproxFull => {}
    }
}

/// Emits the OpenCL `sqrt` extended instruction for `a.src` into `a.dst`.
///
/// Both the approximate and the compliant variant use the same `sqrt`
/// instruction; they differ only in the rounding passed to
/// `emit_rounding_decoration` (`None` for `Approx`, `Some(rnd)` for
/// `Compliant(rnd)`).
fn emit_sqrt(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    details: &ast::RcpData,
    a: &ast::SqrtArgs,
) -> Result<(), TranslateError> {
    let result_type = map.get_or_add_scalar(builder, details.type_.into());
    let rounding = match details.kind {
        ast::RcpKind::Approx => None,
        ast::RcpKind::Compliant(rnd) => Some(rnd),
    };
    builder.ext_inst(
        result_type.0,
        Some(a.dst.0),
        opencl,
        spirv::CLOp::sqrt as spirv::Word,
        [dr::Operand::IdRef(a.src.0)].iter().cloned(),
    )?;
    emit_rounding_decoration(builder, a.dst, rounding);
    Ok(())
}

// TODO: check what kind of assembly do we emit
/// Emits a logical XOR of `op1` and `op2` as
/// `(op1 OR op2) AND NOT (op1 AND op2)`.
///
/// Only the final `logical_and` receives `result_id`; the intermediate values
/// get builder-assigned ids. Returns the id of the final instruction.
fn emit_logical_xor_spirv(
    builder: &mut dr::Builder,
    result_type: spirv::Word,
    result_id: Option<spirv::Word>,
    op1: spirv::Word,
    op2: spirv::Word,
) -> Result<spirv::Word, dr::Error> {
    let temp_or = builder.logical_or(result_type, None, op1, op2)?;
    let temp_and = builder.logical_and(result_type, None, op1, op2)?;
    let temp_neg = logical_not(builder, result_type, None, temp_and)?;
    builder.logical_and(result_type, result_id, temp_or, temp_neg)
}

/// Emits a load from the variable `details.arg.src` into `details.arg.dst`.
///
/// The shape of the load depends on `details.member_index`:
///
/// * `Some((index, Some(width)))` - `details.typ` must be a scalar; the whole
///   `width`-element vector is loaded and element `index` is extracted.
///   A non-scalar `details.typ` yields `error_mismatched_type()`.
/// * `Some((index, None))` - an in-bounds access chain with a `U32` constant
///   `index` produces a `Function`-storage pointer, which is then loaded.
/// * `None` - a plain load of `details.typ`.
fn emit_load_var(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    details: &LoadVarDetails,
) -> Result<(), TranslateError> {
    let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
    match details.member_index {
        Some((index, Some(width))) => {
            let vector_type = match details.typ {
                ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t),
                _ => return Err(error_mismatched_type()),
            };
            let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
            let vector_temp = builder.load(
                vector_type_spirv.0,
                None,
                details.arg.src.0,
                None,
                iter::empty(),
            )?;
            builder.composite_extract(
                result_type.0,
                Some(details.arg.dst.0),
                vector_temp,
                [index as u32].iter().copied(),
            )?;
        }
        Some((index, None)) => {
            let result_ptr_type = map.get_or_add(
                builder,
                SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
            );
            let index_spirv = map.get_or_add_constant(
                builder,
                &ast::Type::Scalar(ast::ScalarType::U32),
                &vec_repr(index as u32),
            )?;
            let src = builder.in_bounds_access_chain(
                result_ptr_type.0,
                None,
                details.arg.src.0,
                [index_spirv.0].iter().copied(),
            )?;
            builder.load(
                result_type.0,
                Some(details.arg.dst.0),
                src,
                None,
                iter::empty(),
            )?;
        }
        None => {
            builder.load(
                result_type.0,
                Some(details.arg.dst.0),
                details.arg.src.0,
                None,
                iter::empty(),
            )?;
        }
    };
    Ok(())
}

/// Decomposes an `ast::Type` into its flat `TypeParts` description.
///
/// Scalars, vectors and arrays are tagged with `StateSpace::Reg`; pointers
/// keep their own state space. A vector stores its length as the single
/// entry of `components`; an array stores a clone of its dimensions. The
/// first field of `ast::Type::Array` is not recorded.
fn to_parts(this: &ast::Type) -> TypeParts {
    match this {
        ast::Type::Scalar(scalar) => TypeParts {
            kind: TypeKind::Scalar,
            state_space: ast::StateSpace::Reg,
            scalar_kind: scalar.kind(),
            width: scalar.size_of(),
            components: Vec::new(),
        },
        ast::Type::Vector(components, scalar) => TypeParts {
            kind: TypeKind::Vector,
            state_space: ast::StateSpace::Reg,
            scalar_kind: scalar.kind(),
            width: scalar.size_of(),
            components: vec![*components as u32],
        },
        ast::Type::Array(_, scalar, components) => TypeParts {
            kind: TypeKind::Array,
            state_space: ast::StateSpace::Reg,
            scalar_kind: scalar.kind(),
            width: scalar.size_of(),
            components: components.clone(),
        },
        ast::Type::Pointer(scalar, space) => TypeParts {
            kind: TypeKind::Pointer,
            state_space: *space,
            scalar_kind: scalar.kind(),
            width: scalar.size_of(),
            components: Vec::new(),
        },
    }
}

/// Rebuilds an `ast::Type` from a `TypeParts` description.
///
/// Arrays are rebuilt with `None` as their first field.
///
/// # Panics
/// For `TypeKind::Vector`, indexing `t.components[0]` panics when
/// `components` is empty.
fn type_from_parts(t: TypeParts) -> ast::Type {
    match t.kind {
        TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
        TypeKind::Vector => ast::Type::Vector(
            // The vector length travels as a `u32` in `components` and is
            // narrowed back to `u8` here.
            t.components[0] as u8,
            scalar_from_parts(t.width, t.scalar_kind),
        ),
        TypeKind::Array => ast::Type::Array(
            None,
            scalar_from_parts(t.width, t.scalar_kind),
            t.components,
        ),
        TypeKind::Pointer => {
            ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
        }
    }
}

/// Flat, comparable description of an `ast::Type`, produced by `to_parts`
/// and consumed by `type_from_parts`.
#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
    /// Which `ast::Type` variant this describes.
    kind: TypeKind,
    /// Result of `kind()` on the underlying scalar type.
    scalar_kind: ast::ScalarKind,
    /// Result of `size_of()` on the underlying scalar type.
    width: u8,
    /// State space; `Reg` for everything except pointers.
    state_space: ast::StateSpace,
    /// Vector length (single entry) or array dimensions; empty otherwise.
    components: Vec<u32>,
}

/// Discriminant mirroring the variants of `ast::Type` handled by `to_parts`.
#[derive(Eq, PartialEq, Copy, Clone)]
enum TypeKind {
    Scalar,
    Vector,
    Array,
    Pointer,
}