diff --git a/crates/bevy_ecs/macros/src/lib.rs b/crates/bevy_ecs/macros/src/lib.rs index f3de68ccf544e..5179c0adfcff3 100644 --- a/crates/bevy_ecs/macros/src/lib.rs +++ b/crates/bevy_ecs/macros/src/lib.rs @@ -20,10 +20,12 @@ use syn::{ enum BundleFieldKind { Component, Ignore, + Require, } const BUNDLE_ATTRIBUTE_NAME: &str = "bundle"; const BUNDLE_ATTRIBUTE_IGNORE_NAME: &str = "ignore"; +const BUNDLE_ATTRIBUTE_REQUIRE_NAME: &str = "require"; #[proc_macro_derive(Bundle, attributes(bundle))] pub fn derive_bundle(input: TokenStream) -> TokenStream { @@ -45,6 +47,9 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { if path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) { field_kind.push(BundleFieldKind::Ignore); continue 'field_loop; + } else if path.is_ident(BUNDLE_ATTRIBUTE_REQUIRE_NAME) { + field_kind.push(BundleFieldKind::Require); + continue 'field_loop; } return syn::Error::new( @@ -77,6 +82,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { let mut field_component_ids = Vec::new(); let mut field_get_components = Vec::new(); let mut field_from_components = Vec::new(); + let mut field_requires = Vec::new(); for ((field_type, field_kind), field) in field_type.iter().zip(field_kind.iter()).zip(field.iter()) { @@ -98,6 +104,15 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { #field: ::std::default::Default::default(), }); } + + BundleFieldKind::Require => { + field_requires.push(quote! { + <#field_type as #ecs_path::bundle::Bundle>::requires(components, storages, &mut *ids); + }); + field_from_components.push(quote! { + #field: #ecs_path::bundle::Require::default(), + }); + } } } let generics = ast.generics; @@ -127,6 +142,13 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { #(#field_from_components)* } } + + fn requires( + components: &mut #ecs_path::component::Components, + storages: &mut #ecs_path::storage::Storages, + ids: &mut impl FnMut(#ecs_path::component::ComponentId)) { + #(#field_requires)* + } } impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause { diff --git a/crates/bevy_ecs/src/bundle.rs b/crates/bevy_ecs/src/bundle.rs index 8ddf6e6ab9065..20200503e0284 100644 --- a/crates/bevy_ecs/src/bundle.rs +++ b/crates/bevy_ecs/src/bundle.rs @@ -18,7 +18,7 @@ use crate::{ }; use bevy_ptr::OwningPtr; use bevy_utils::all_tuples; -use std::any::TypeId; +use std::{any::TypeId, marker::PhantomData}; /// The `Bundle` trait enables insertion and removal of [`Component`]s from an entity. /// @@ -160,6 +160,14 @@ pub unsafe trait Bundle: DynamicBundle + Send + Sync + 'static { // Ensure that the `OwningPtr` is used correctly F: for<'a> FnMut(&'a mut T) -> OwningPtr<'a>, Self: Sized; + + #[doc(hidden)] + fn requires( + _components: &mut Components, + _storages: &mut Storages, + _ids: &mut impl FnMut(ComponentId), + ) { + } } /// The parts from [`Bundle`] that don't require statically knowing the components of the bundle. @@ -269,6 +277,49 @@ impl SparseSetIndex for BundleId { } } +#[allow(unused)] +pub struct Require(PhantomData); + +impl Default for Require { + fn default() -> Self { + Self(PhantomData) + } +} + +// SAFETY: +// - `Bundle::component_ids` does nothing since `Require` doesn't have any components. +// - `Bundle::get_components` does nothing since `Require` doesn't have any components. +// - `Bundle::from_components` just returns `Require::default()` since `Require` doesn't have any components. +unsafe impl Bundle for Require { + fn component_ids( + _components: &mut Components, + _storages: &mut Storages, + _ids: &mut impl FnMut(ComponentId), + ) { + } + + unsafe fn from_components(_ctx: &mut S, _func: &mut F) -> Self + where + // Ensure that the `OwningPtr` is used correctly + F: for<'a> FnMut(&'a mut S) -> OwningPtr<'a>, + Self: Sized, + { + Self::default() + } + + fn requires( + components: &mut Components, + storages: &mut Storages, + ids: &mut impl FnMut(ComponentId), + ) { + ids(components.init_component::(storages)); + } +} + +impl DynamicBundle for Require { + fn get_components(self, _func: &mut impl FnMut(StorageType, OwningPtr<'_>)) {} +} + pub struct BundleInfo { id: BundleId, // SAFETY: Every ID in this list must be valid within the World that owns the BundleInfo, @@ -289,6 +340,7 @@ impl BundleInfo { bundle_type_name: &'static str, components: &Components, component_ids: Vec, + requires: Vec, id: BundleId, ) -> BundleInfo { let mut deduped = component_ids.clone(); @@ -317,6 +369,15 @@ impl BundleInfo { panic!("Bundle {bundle_type_name} has duplicate components: {names}"); } + for id in requires { + if !component_ids.contains(&id) { + panic!( + "Bundle {bundle_type_name} requires missing component: {}", + components.get_info_unchecked(id).name(), + ); + } + } + // SAFETY: The caller ensures that component_ids: // - is valid for the associated world // - has had its storage initialized @@ -813,14 +874,16 @@ impl Bundles { let bundle_infos = &mut self.bundle_infos; let id = self.bundle_ids.entry(TypeId::of::()).or_insert_with(|| { let mut component_ids = Vec::new(); + let mut requires = Vec::new(); T::component_ids(components, storages, &mut |id| component_ids.push(id)); + T::requires(components, storages,&mut |id| requires.push(id)); let id = BundleId(bundle_infos.len()); let bundle_info = // SAFETY: T::component_id ensures its: // - info was created // - appropriate storage for it has been initialized. // - was created in the same order as the components in T - unsafe { BundleInfo::new(std::any::type_name::(), components, component_ids, id) }; + unsafe { BundleInfo::new(std::any::type_name::(), components, component_ids, requires, id) }; bundle_infos.push(bundle_info); id }); @@ -906,7 +969,7 @@ fn initialize_dynamic_bundle( let id = BundleId(bundle_infos.len()); let bundle_info = // SAFETY: `component_ids` are valid as they were just checked - unsafe { BundleInfo::new("", components, component_ids, id) }; + unsafe { BundleInfo::new("", components, component_ids, vec![], id) }; bundle_infos.push(bundle_info); (id, storage_types) diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 67a65ee64e7c0..c22a30bfdef16 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -1708,4 +1708,43 @@ mod tests { "new entity was spawned and received C component" ); } + + #[test] + #[should_panic] + fn bundle_with_missing_require() { + use super::bundle::Require; + #[derive(Component)] + struct A; + + #[derive(Bundle, Default)] + struct B { + #[bundle(require)] + #[allow(unused)] + a: Require, + } + + let mut world = World::default(); + + world.spawn(B::default()); + } + + #[test] + fn bundle_with_require() { + use super::bundle::Require; + #[derive(Component)] + struct A; + + #[derive(Bundle, Default)] + struct B { + #[bundle(require)] + #[allow(unused)] + a: Require, + } + + let mut world = World::default(); + + let entity = world.spawn((A, B::default())); + + assert!(entity.contains::()); + } }