diff --git a/frispy/disc.py b/frispy/disc.py index 23c9569..71f35d9 100644 --- a/frispy/disc.py +++ b/frispy/disc.py @@ -1,7 +1,7 @@ """Disc class.""" from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple, Type import numpy as np from scipy.integrate import solve_ivp @@ -64,20 +64,7 @@ class Disc: air_density: float = 1.225 # kg / m ^ 3 g: float = 9.81 # m / s ^ 2 model: Model = Model() - eom: Optional[EOM] = None - - def __post_init__(self) -> None: - """Sets the ``eom`` (equations of motion) attribute.""" - self.eom = self.eom or EOM( - model=self.model, - area=self.area, - I_xx=self.I_xx, - I_zz=self.I_zz, - mass=self.mass, - air_density=self.air_density, - g=self.g, - ) - return + eom_class: Type = EOM def compute_trajectory( self, @@ -118,9 +105,20 @@ def compute_trajectory( "t_eval", np.linspace(t_span[0], t_span[1], n_times) ) + # Instantiate the equations of motion + eom = self.eom_class( + model=self.model, + area=self.area, + I_xx=self.I_xx, + I_zz=self.I_zz, + mass=self.mass, + air_density=self.air_density, + g=self.g, + ) + # Call the solver result = solve_ivp( - fun=self.eom.compute_derivatives, + fun=eom.compute_derivatives, t_span=t_span, y0=[ self.x, diff --git a/tests/test_disc.py b/tests/test_disc.py index 2ffd49b..d1bcf45 100644 --- a/tests/test_disc.py +++ b/tests/test_disc.py @@ -11,14 +11,13 @@ def test_smoke(): def test_disc_has_properties(): d = Disc() assert hasattr(d, "model") - assert hasattr(d, "eom") + assert hasattr(d, "eom_class") def test_physical_attribute_kwarg(): d = Disc(mass=12345, area=0.1234) assert d.mass == 12345 assert d.area == 0.1234 - assert d.eom.diameter == 2 * np.sqrt(d.area / np.pi) def test_compute_trajectory_basics():