Skip to content

Commit

Permalink
feat: support some differential transformations
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Feb 20, 2024
1 parent ac4213e commit e809368
Showing 1 changed file with 114 additions and 5 deletions.
119 changes: 114 additions & 5 deletions src/vector/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,130 @@

__all__ = ["represent_as"]

from dataclasses import fields
from typing import Any
from warnings import warn

import array_api_jax_compat as xp
import astropy.units as u
import jax
from jax_quantity import Quantity
from plum import dispatch

from ._d1.builtin import Cartesian1DVector, RadialVector
from ._d2.base import Abstract2DVector # pylint: disable=cyclic-import
from ._d2.builtin import Cartesian2DVector, PolarVector
from ._d3.base import Abstract3DVector
from ._d3.builtin import Cartesian3DVector, CylindricalVector, SphericalVector
from ._base import AbstractVector, AbstractVectorDifferential
from ._d1.base import Abstract1DVectorDifferential
from ._d1.builtin import (
Cartesian1DVector,
CartesianDifferential1D,
RadialDifferential,
RadialVector,
)
from ._d2.base import Abstract2DVector, Abstract2DVectorDifferential
from ._d2.builtin import (
Cartesian2DVector,
CartesianDifferential2D,
PolarDifferential,
PolarVector,
)
from ._d3.base import Abstract3DVector, Abstract3DVectorDifferential
from ._d3.builtin import (
Cartesian3DVector,
CartesianDifferential3D,
CylindricalDifferential,
CylindricalVector,
SphericalDifferential,
SphericalVector,
)
from ._exceptions import IrreversibleDimensionChange


# TODO: implement for cross-representations
@dispatch.multi(
# N-D -> N-D
(
Abstract1DVectorDifferential,
type[Abstract1DVectorDifferential],
AbstractVector,
),
(
Abstract2DVectorDifferential,
type[Abstract2DVectorDifferential],
AbstractVector,
),
(
Abstract3DVectorDifferential,
type[Abstract3DVectorDifferential],
AbstractVector,
),
)
def represent_as(
current: AbstractVectorDifferential,
target: type[AbstractVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> AbstractVectorDifferential:
"""Abstract3DVectorDifferential -> Cartesian -> Abstract3DVectorDifferential.
This is the base case for the transformation of 1D vector differentials.
"""
# Start by transforming the position to the type required by the
# differential to construct the Jacobian.
current_position = represent_as(position, current.vector_cls, **kwargs)

# Takes the Jacobian through the representation transformation function. This
# returns a representation of the target type, where the value of each field the
# corresponding row of the Jacobian. The value of the field is a Quantity with
# the correct numerator unit (of the Jacobian row). The value is a Vector of the
# original type, with fields that are the columns of that row, but with only the
# denomicator's units.
jac_nested_vecs = jax.jacfwd(represent_as)(current_position, target.vector_cls)

# This changes the Jacobian to be a dictionary of each row, with the value being
# that row's column as a Vector of the original type, now with the correct units
# for each element.
jac_rows = {
f"d_{f.name}": (v := getattr(jac_nested_vecs, f.name)).unit / v.value
for f in fields(jac_nested_vecs)
}

# Now we can use the Jacobian to transform the differential.
return target(
**{ # Each field is the dot product of the row of the J and the diff.
k: xp.sum( # Doing the dot product.
xp.asarray(
[
getattr(j_q, f.name) * getattr(current, f"d_{f.name}")
for f in fields(j_q)
]
),
)
for k, j_q in jac_rows.items()
}
)


# Self transform
@dispatch.multi(
(CartesianDifferential1D, type[CartesianDifferential1D], AbstractVector),
(RadialDifferential, type[RadialDifferential], AbstractVector),
(CartesianDifferential2D, type[CartesianDifferential2D], AbstractVector),
(PolarDifferential, type[PolarDifferential], AbstractVector),
(CartesianDifferential3D, type[CartesianDifferential3D], AbstractVector),
(SphericalDifferential, type[SphericalDifferential], AbstractVector),
(CylindricalDifferential, type[CylindricalDifferential], AbstractVector),
)
def represent_as(
current: AbstractVectorDifferential,
target: type[AbstractVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> AbstractVectorDifferential:
"""Self transform."""
return current


###############################################################################
# 1D

Expand Down

0 comments on commit e809368

Please sign in to comment.