Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of Concept for Required Components: Require<T> #8557

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions crates/bevy_ecs/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use syn::{
enum BundleFieldKind {
Component,
Ignore,
Require,
}

const BUNDLE_ATTRIBUTE_NAME: &str = "bundle";
const BUNDLE_ATTRIBUTE_IGNORE_NAME: &str = "ignore";
const BUNDLE_ATTRIBUTE_REQUIRE_NAME: &str = "require";

#[proc_macro_derive(Bundle, attributes(bundle))]
pub fn derive_bundle(input: TokenStream) -> TokenStream {
Expand All @@ -45,6 +47,9 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
if path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
field_kind.push(BundleFieldKind::Ignore);
continue 'field_loop;
} else if path.is_ident(BUNDLE_ATTRIBUTE_REQUIRE_NAME) {
field_kind.push(BundleFieldKind::Require);
continue 'field_loop;
}

return syn::Error::new(
Expand Down Expand Up @@ -77,6 +82,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
let mut field_component_ids = Vec::new();
let mut field_get_components = Vec::new();
let mut field_from_components = Vec::new();
let mut field_requires = Vec::new();
for ((field_type, field_kind), field) in
field_type.iter().zip(field_kind.iter()).zip(field.iter())
{
Expand All @@ -98,6 +104,15 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
#field: ::std::default::Default::default(),
});
}

BundleFieldKind::Require => {
field_requires.push(quote! {
<#field_type as #ecs_path::bundle::Bundle>::requires(components, storages, &mut *ids);
});
field_from_components.push(quote! {
#field: #ecs_path::bundle::Require::default(),
});
}
}
}
let generics = ast.generics;
Expand Down Expand Up @@ -127,6 +142,13 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
#(#field_from_components)*
}
}

fn requires(
components: &mut #ecs_path::component::Components,
storages: &mut #ecs_path::storage::Storages,
ids: &mut impl FnMut(#ecs_path::component::ComponentId)) {
#(#field_requires)*
}
}

impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause {
Expand Down
69 changes: 66 additions & 3 deletions crates/bevy_ecs/src/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
};
use bevy_ptr::OwningPtr;
use bevy_utils::all_tuples;
use std::any::TypeId;
use std::{any::TypeId, marker::PhantomData};

/// The `Bundle` trait enables insertion and removal of [`Component`]s from an entity.
///
Expand Down Expand Up @@ -160,6 +160,14 @@ pub unsafe trait Bundle: DynamicBundle + Send + Sync + 'static {
// Ensure that the `OwningPtr` is used correctly
F: for<'a> FnMut(&'a mut T) -> OwningPtr<'a>,
Self: Sized;

#[doc(hidden)]
fn requires(
_components: &mut Components,
_storages: &mut Storages,
_ids: &mut impl FnMut(ComponentId),
) {
}
}

/// The parts from [`Bundle`] that don't require statically knowing the components of the bundle.
Expand Down Expand Up @@ -269,6 +277,49 @@ impl SparseSetIndex for BundleId {
}
}

#[allow(unused)]
pub struct Require<T: Component>(PhantomData<T>);

impl<T: Component> Default for Require<T> {
fn default() -> Self {
Self(PhantomData)
}
}

// SAFETY:
// - `Bundle::component_ids` does nothing since `Require` doesn't have any components.
// - `Bundle::get_components` does nothing since `Require` doesn't have any components.
// - `Bundle::from_components` just returns `Require::default()` since `Require` doesn't have any components.
unsafe impl<T: Component> Bundle for Require<T> {
fn component_ids(
_components: &mut Components,
_storages: &mut Storages,
_ids: &mut impl FnMut(ComponentId),
) {
}

unsafe fn from_components<S, F>(_ctx: &mut S, _func: &mut F) -> Self
where
// Ensure that the `OwningPtr` is used correctly
F: for<'a> FnMut(&'a mut S) -> OwningPtr<'a>,
Self: Sized,
{
Self::default()
}

fn requires(
components: &mut Components,
storages: &mut Storages,
ids: &mut impl FnMut(ComponentId),
) {
ids(components.init_component::<T>(storages));
}
}

impl<T: Component> DynamicBundle for Require<T> {
fn get_components(self, _func: &mut impl FnMut(StorageType, OwningPtr<'_>)) {}
}

pub struct BundleInfo {
id: BundleId,
// SAFETY: Every ID in this list must be valid within the World that owns the BundleInfo,
Expand All @@ -289,6 +340,7 @@ impl BundleInfo {
bundle_type_name: &'static str,
components: &Components,
component_ids: Vec<ComponentId>,
requires: Vec<ComponentId>,
id: BundleId,
) -> BundleInfo {
let mut deduped = component_ids.clone();
Expand Down Expand Up @@ -317,6 +369,15 @@ impl BundleInfo {
panic!("Bundle {bundle_type_name} has duplicate components: {names}");
}

for id in requires {
if !component_ids.contains(&id) {
panic!(
"Bundle {bundle_type_name} requires missing component: {}",
components.get_info_unchecked(id).name(),
);
}
}

// SAFETY: The caller ensures that component_ids:
// - is valid for the associated world
// - has had its storage initialized
Expand Down Expand Up @@ -813,14 +874,16 @@ impl Bundles {
let bundle_infos = &mut self.bundle_infos;
let id = self.bundle_ids.entry(TypeId::of::<T>()).or_insert_with(|| {
let mut component_ids = Vec::new();
let mut requires = Vec::new();
T::component_ids(components, storages, &mut |id| component_ids.push(id));
T::requires(components, storages,&mut |id| requires.push(id));
let id = BundleId(bundle_infos.len());
let bundle_info =
// SAFETY: T::component_id ensures its:
// - info was created
// - appropriate storage for it has been initialized.
// - was created in the same order as the components in T
unsafe { BundleInfo::new(std::any::type_name::<T>(), components, component_ids, id) };
unsafe { BundleInfo::new(std::any::type_name::<T>(), components, component_ids, requires, id) };
bundle_infos.push(bundle_info);
id
});
Expand Down Expand Up @@ -906,7 +969,7 @@ fn initialize_dynamic_bundle(
let id = BundleId(bundle_infos.len());
let bundle_info =
// SAFETY: `component_ids` are valid as they were just checked
unsafe { BundleInfo::new("<dynamic bundle>", components, component_ids, id) };
unsafe { BundleInfo::new("<dynamic bundle>", components, component_ids, vec![], id) };
bundle_infos.push(bundle_info);

(id, storage_types)
Expand Down
39 changes: 39 additions & 0 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1708,4 +1708,43 @@ mod tests {
"new entity was spawned and received C component"
);
}

#[test]
#[should_panic]
fn bundle_with_missing_require() {
use super::bundle::Require;
#[derive(Component)]
struct A;

#[derive(Bundle, Default)]
struct B {
#[bundle(require)]
#[allow(unused)]
a: Require<A>,
}

let mut world = World::default();

world.spawn(B::default());
}

#[test]
fn bundle_with_require() {
use super::bundle::Require;
#[derive(Component)]
struct A;

#[derive(Bundle, Default)]
struct B {
#[bundle(require)]
#[allow(unused)]
a: Require<A>,
}

let mut world = World::default();

let entity = world.spawn((A, B::default()));

assert!(entity.contains::<A>());
}
}