diff --git a/compiler/rustc_target/src/callconv/nvptx64.rs b/compiler/rustc_target/src/callconv/nvptx64.rs index 2e8b16d3a938e..c64164372a11d 100644 --- a/compiler/rustc_target/src/callconv/nvptx64.rs +++ b/compiler/rustc_target/src/callconv/nvptx64.rs @@ -1,5 +1,5 @@ use super::{ArgAttribute, ArgAttributes, ArgExtension, CastTarget}; -use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform}; +use crate::abi::call::{ArgAbi, FnAbi, Reg, Size, Uniform}; use crate::abi::{HasDataLayout, TyAbiInterface}; fn classify_ret(ret: &mut ArgAbi<'_, Ty>) { @@ -53,21 +53,37 @@ where Ty: TyAbiInterface<'a, C> + Copy, C: HasDataLayout, { - if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) { - let align_bytes = arg.layout.align.abi.bytes(); + match arg.mode { + super::PassMode::Ignore | super::PassMode::Direct(_) => return, + super::PassMode::Pair(_, _) => {} + super::PassMode::Cast { .. } => unreachable!(), + super::PassMode::Indirect { .. } => {} + } + + // FIXME only allow structs and wide pointers here + // panic!( + // "`extern \"ptx-kernel\"` doesn't allow passing types other than primitives and structs" + // ); + + let align_bytes = arg.layout.align.abi.bytes(); - let unit = match align_bytes { - 1 => Reg::i8(), - 2 => Reg::i16(), - 4 => Reg::i32(), - 8 => Reg::i64(), - 16 => Reg::i128(), - _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), - }; - arg.cast_to(Uniform::new(unit, Size::from_bytes(2 * align_bytes))); + let unit = match align_bytes { + 1 => Reg::i8(), + 2 => Reg::i16(), + 4 => Reg::i32(), + 8 => Reg::i64(), + 16 => Reg::i128(), + _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), + }; + if arg.layout.size.bytes() / align_bytes == 1 { + // Make sure we pass the struct as array at the LLVM IR level and not as a single integer. + arg.cast_to(CastTarget { + prefix: [Some(unit), None, None, None, None, None, None, None], + rest: Uniform::new(unit, Size::ZERO), + attrs: ArgAttributes::new(), + }); } else { - // FIXME: find a better way to do this. See https://github.com/rust-lang/rust/issues/117271. - arg.make_direct_deprecated(); + arg.cast_to(Uniform::new(unit, arg.layout.size)); } } diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index c528179ae0e7a..169f3a78c26a7 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -489,21 +489,16 @@ fn fn_abi_sanity_check<'tcx>( // have to allow it -- but we absolutely shouldn't let any more targets do // that. (Also see .) // - // The unstable abi `PtxKernel` also uses Direct for now. - // It needs to switch to something else before stabilization can happen. - // (See issue: https://github.com/rust-lang/rust/issues/117271) - // - // And finally the unadjusted ABI is ill specified and uses Direct for all - // args, but unfortunately we need it for calling certain LLVM intrinsics. + // The unadjusted ABI also uses Direct for all args and is ill-specified, + // but unfortunately we need it for calling certain LLVM intrinsics. match spec_abi { ExternAbi::Unadjusted => {} - ExternAbi::PtxKernel => {} ExternAbi::C { unwind: _ } if matches!(&*tcx.sess.target.arch, "wasm32" | "wasm64") => {} _ => { panic!( - "`PassMode::Direct` for aggregates only allowed for \"unadjusted\" and \"ptx-kernel\" functions and on wasm\n\ + "`PassMode::Direct` for aggregates only allowed for \"unadjusted\" functions and on wasm\n\ Problematic type: {:#?}", arg.layout, ); diff --git a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs index fb3a325a41f81..b3bfc66a5a570 100644 --- a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs +++ b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs @@ -242,22 +242,6 @@ pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {} //pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {} // CHECK: .visible .entry f_u32_slice_arg( -// CHECK: .param .u64 f_u32_slice_arg_param_0 -// CHECK: .param .u64 f_u32_slice_arg_param_1 +// CHECK: .param .align 8 .b8 f_u32_slice_arg_param_0[16] #[no_mangle] pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {} - -// CHECK: .visible .entry f_tuple_u8_u8_arg( -// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {} - -// CHECK: .visible .entry f_tuple_u32_u32_arg( -// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {} - -// CHECK: .visible .entry f_tuple_u8_u8_u32_arg( -// CHECK: .param .align 4 .b8 f_tuple_u8_u8_u32_arg_param_0[8] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {}