Skip to content

Commit

Permalink
refactor: EOM is instantiated when the trajectory is computed. no mor…
Browse files Browse the repository at this point in the history
…e post_init
  • Loading branch information
tmcclintock committed Nov 26, 2022
1 parent cf2f1ae commit 4a236cf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
30 changes: 14 additions & 16 deletions frispy/disc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Disc class."""

from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple, Type

import numpy as np
from scipy.integrate import solve_ivp
Expand Down Expand Up @@ -64,20 +64,7 @@ class Disc:
air_density: float = 1.225 # kg / m ^ 3
g: float = 9.81 # m / s ^ 2
model: Model = Model()
eom: Optional[EOM] = None

def __post_init__(self) -> None:
"""Sets the ``eom`` (equations of motion) attribute."""
self.eom = self.eom or EOM(
model=self.model,
area=self.area,
I_xx=self.I_xx,
I_zz=self.I_zz,
mass=self.mass,
air_density=self.air_density,
g=self.g,
)
return
eom_class: Type = EOM

def compute_trajectory(
self,
Expand Down Expand Up @@ -118,9 +105,20 @@ def compute_trajectory(
"t_eval", np.linspace(t_span[0], t_span[1], n_times)
)

# Instantiate the equations of motion
eom = self.eom_class(
model=self.model,
area=self.area,
I_xx=self.I_xx,
I_zz=self.I_zz,
mass=self.mass,
air_density=self.air_density,
g=self.g,
)

# Call the solver
result = solve_ivp(
fun=self.eom.compute_derivatives,
fun=eom.compute_derivatives,
t_span=t_span,
y0=[
self.x,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_disc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ def test_smoke():
def test_disc_has_properties():
d = Disc()
assert hasattr(d, "model")
assert hasattr(d, "eom")
assert hasattr(d, "eom_class")


def test_physical_attribute_kwarg():
d = Disc(mass=12345, area=0.1234)
assert d.mass == 12345
assert d.area == 0.1234
assert d.eom.diameter == 2 * np.sqrt(d.area / np.pi)


def test_compute_trajectory_basics():
Expand Down

0 comments on commit 4a236cf

Please sign in to comment.