Commit

Specify sensitivity order in prepared amici functions (#22)
dilpath authored May 2, 2023
1 parent: f7fd881 · commit: e9e7ba3
Showing 2 changed files with 31 additions and 92 deletions.
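For context, the behavior this commit makes explicit hinges on AMICI's `SensitivityOrder` enum: a solver only computes first-order sensitivities when its sensitivity order says so. Below is a minimal sketch of that AMICI API, assuming a compiled AMICI model whose module is importable; the `model_module` name is a placeholder, not part of this commit.

```python
import amici

# Placeholder: a compiled AMICI model module, imported in practice from the
# model's build/output directory.
model = model_module.getModel()
solver = model.getSolver()

# Plain simulation: the returned ReturnData contains no sensitivities.
solver.setSensitivityOrder(amici.SensitivityOrder.none)
rdata = amici.runAmiciSimulation(model, solver)

# First-order sensitivities (e.g. rdata.sx, rdata.sy) additionally require
# a sensitivity method and first order on the solver.
solver.setSensitivityMethod(amici.SensitivityMethod.forward)
solver.setSensitivityOrder(amici.SensitivityOrder.first)
rdata = amici.runAmiciSimulation(model, solver)
```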
115 changes: 27 additions & 88 deletions fiddy/extensions/amici/amici.py
@@ -127,7 +127,6 @@ def run_amici_simulation_to_cached_functions(
     parameter_ids: List[str] = None,
     amici_solver: amici.AmiciModel = None,
     amici_edata: amici.AmiciExpData = None,
-    #run_amici_simulation: Callable[[Any], amici.AmiciReturnData] = None,
     derivative_variables: List[str] = None,
     **kwargs,
 ):
@@ -182,27 +181,24 @@ def unravel_derivatives(raveled_derivatives: Type.DERIVATIVE, derivative_shapes=
         for k in derivative_variables
     }

-    def run_amici_simulation(point: Type.POINT):
+    def run_amici_simulation(point: Type.POINT, order: amici.SensitivityOrder):
         problem_parameters = dict(zip(parameter_ids, point))
         amici_model.setParameterById(problem_parameters)
+        amici_solver.setSensitivityOrder(order)
         rdata = amici.runAmiciSimulation(model=amici_model, solver=amici_solver, edata=amici_edata)
         return rdata

     if cache:
         run_amici_simulation = CachedFunction(run_amici_simulation)

     def function(point: Type.POINT):
-        rdata = run_amici_simulation(point=point)
+        rdata = run_amici_simulation(point=point, order=amici.SensitivityOrder.none)
         outputs = {
             variable: fiddy_array(getattr(rdata, variable))
             for variable in chosen_derivatives
         }
         rdata_flat = np.concatenate([output.flat for output in outputs.values()])
-        # np.concatenate([rdata.x, rdata.y])
         return rdata_flat

     def derivative(point: Type.POINT, return_dict: bool = False):
-        rdata = run_amici_simulation(point=point)
+        rdata = run_amici_simulation(point=point, order=amici.SensitivityOrder.first)
         outputs = {
             variable: rdata_array_transpose(array=fiddy_array(getattr(rdata, derivative_variable)), variable=derivative_variable)
             for variable, derivative_variable in chosen_derivatives.items()
@@ -218,44 +214,13 @@ def derivative(point: Type.POINT, return_dict: bool = False):
             return outputs
         return rdata_flat

-    #def function(point: Type.POINT):
-    #    output = simulate_petab_full_cached(point)
-    #    result = output[LLH]
-    #    return np.array(result)
-
-    #def derivative(point: Type.POINT) -> Type.POINT:
-    #    result = simulate_petab_full_cached(point)
-    #    #sllh = np.array(
-    #    #    [
-    #    #        gradient_transformations[parameter_index](
-    #    #            gradient_value=result[SLLH][parameter_id],
-    #    #            parameter_value=point[parameter_index],
-    #    #        )
-    #    #        for parameter_index, parameter_id in enumerate(parameter_ids)
-    #    #    ]
-    #    #)
-    #    sllh = np.array([result[SLLH][parameter_id] for parameter_id in parameter_ids])
-    #    return sllh
     if cache:
         function = CachedFunction(function)
         derivative = CachedFunction(derivative)

     # Get structure
     dummy_point = fiddy_array(amici_model.getParameters())
-    dummy_rdata = run_amici_simulation(point=dummy_point)
-    #dummy_function_output = function(dummy_point)
-    #dummy_derivative_output = derivative(dummy_point)
-    #structure_function = {
-    #    variable: (
-    #        fiddy_array(getattr(dummy_rdata, variable)).size,
-    #        fiddy_array(getattr(dummy_rdata, variable)).shape,
-    #    )
-    #    for variable in chosen_derivatives
-    #}
-    #structure_derivative = {
-    #    variable: (
-    #        fiddy_array(getattr(dummy_rdata, derivative_variable)).size,
-    #        fiddy_array(getattr(dummy_rdata, derivative_variable)).shape,
-    #    )
-    #    for variable, derivative_variable in chosen_derivatives.items()
-    #}
+    dummy_rdata = run_amici_simulation(point=dummy_point, order=amici.SensitivityOrder.first)

     structures = {
         'function': {variable: None for variable in chosen_derivatives},
@@ -293,12 +258,11 @@ def reshape(array: Type.ARRAY, structure: TYPE_STRUCTURE) -> Dict[str, Type.ARRAY]:

 def simulate_petab_to_cached_functions(
     petab_problem: petab.Problem,
-    *args,
+    amici_model: amici.Model,
     parameter_ids: List[str] = None,
     cache: bool = True,
     precreate_edatas: bool = True,
     precreate_parameter_mapping: bool = True,
-    scaled_gradients: bool = False,
     simulate_petab: Callable[[Any], Dict[str, Any]] = None,
     **kwargs,
 ) -> Tuple[Type.FUNCTION, Type.FUNCTION]:
@@ -324,9 +288,7 @@ def simulate_petab_to_cached_functions(
     precreate_parameter_mapping:
         Whether to create the AMICI parameter mapping object in advance, to
         save time.
-    scaled_gradients:
-        Whether to return gradients on the scale of the parameters.
-    \*args, \*\*kwargs:
+    \*\*kwargs:
         Passed to `simulate_petab`.

     Returns:
@@ -339,35 +301,18 @@

     if simulate_petab is None:
         simulate_petab = amici.petab_objective.simulate_petab
-    #if scaled_gradients:
-    #    gradient_transformations = [
-    #        transforms[
-    #            petab_problem.parameter_df.loc[parameter_id, PARAMETER_SCALE]
-    #        ]
-    #        for parameter_id in parameter_ids
-    #    ]
-    #else:
-    #    gradient_transformations = [transforms[LIN] for _ in parameter_ids]

     edatas = None
     if precreate_edatas:
-        if 'amici_model' not in kwargs:
-            raise ValueError(
-                'Please supply the AMICI model to precreate ExpData.'
-            )
         edatas = create_edatas(
-            amici_model=kwargs['amici_model'],
+            amici_model=amici_model,
             petab_problem=petab_problem,
             simulation_conditions=\
                 petab_problem.get_simulation_conditions_from_measurement_df(),
         )

     parameter_mapping = None
     if precreate_parameter_mapping:
-        if 'amici_model' not in kwargs:
-            raise ValueError(
-                'Please supply the AMICI model to precreate ExpData.'
-            )
         parameter_mapping = create_parameter_mapping(
             petab_problem=petab_problem,
             simulation_conditions=\
@@ -380,7 +325,7 @@
                 .default
             ),
         ),
-        amici_model=kwargs['amici_model'],
+        amici_model=amici_model,
     )

     precreated_kwargs = {
@@ -394,46 +339,40 @@
         if v is not None
     }

+    amici_solver = kwargs.pop('solver', amici_model.getSolver())
+
     simulate_petab_partial = partial(
         simulate_petab,
-        *args,
-        scaled_parameters=scaled_gradients,
+        amici_model=amici_model,
         **precreated_kwargs,
         **kwargs,
     )

-    def simulate_petab_full(point: Type.POINT):
+    def simulate_petab_full(point: Type.POINT, order: amici.SensitivityOrder):
         problem_parameters = dict(zip(parameter_ids, point))
-        result = simulate_petab_partial(problem_parameters=problem_parameters)
+        amici_solver.setSensitivityOrder(order)
+        result = simulate_petab_partial(
+            problem_parameters=problem_parameters,
+            solver=amici_solver,
+        )
         return result

-    simulate_petab_full_cached = simulate_petab_full
-    if cache:
-        simulate_petab_full_cached = CachedFunction(simulate_petab_full)
-
     def function(point: Type.POINT):
-        output = simulate_petab_full_cached(point)
+        output = simulate_petab_full(point, order=amici.SensitivityOrder.none)
         result = output[LLH]
         return np.array(result)

     def derivative(point: Type.POINT) -> Type.POINT:
-        result = simulate_petab_full_cached(point)
-
-        #sllh = np.array(
-        #    [
-        #        gradient_transformations[parameter_index](
-        #            gradient_value=result[SLLH][parameter_id],
-        #            parameter_value=point[parameter_index],
-        #        )
-        #        for parameter_index, parameter_id in enumerate(parameter_ids)
-        #    ]
-        #)
+        result = simulate_petab_full(point, order=amici.SensitivityOrder.first)
         sllh = np.array([result[SLLH][parameter_id] for parameter_id in parameter_ids])
         return sllh

-    return function, derivative
+    if cache:
+        function = CachedFunction(function)
+        derivative = CachedFunction(derivative)

-
-#class SplitAmiciReturnData(Analysis):
-#    """Split AMICI output into multiple outputs (e.g. state variables `x` and observable variables `y`)."""
-#    pass
+    return function, derivative
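After this change, the prepared pair simulates without sensitivities for plain function values and with first-order sensitivities for gradients. A hypothetical usage sketch follows; it assumes `petab_problem` is a loaded `petab.Problem` and `amici_model` a compiled AMICI model, and the nominal-point attributes are taken from the petab library, not from this commit.

```python
import numpy as np

# Hypothetical usage; petab_problem and amici_model must already exist.
function, derivative = simulate_petab_to_cached_functions(
    petab_problem=petab_problem,
    amici_model=amici_model,
    parameter_ids=list(petab_problem.x_free_ids),
)

point = np.asarray(petab_problem.x_nominal_free_scaled, dtype=float)
llh = function(point)     # simulated with amici.SensitivityOrder.none
sllh = derivative(point)  # simulated with amici.SensitivityOrder.first
```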
8 changes: 4 additions & 4 deletions tests/extensions/amici/test_amici.py
@@ -95,20 +95,20 @@ def test_run_amici_simulation_to_functions(problem_generator):
         #analysis_classes=[
         #    lambda: TransformByDirectionScale(scales=parameter_scales),
         #],
-        success_checker=Consistency(),
+        success_checker=Consistency(atol=1e-2),
     )
     test_derivative = derivative.value

     # The test derivative is close to the expected derivative.
-    assert np.isclose(test_derivative, expected_derivative, rtol=1e-1, equal_nan=True).all()
+    assert np.isclose(test_derivative, expected_derivative, rtol=1e-1, atol=1e-1, equal_nan=True).all()

     # Same as above assert.
     check = NumpyIsCloseDerivativeCheck(
         derivative=derivative,
         expectation=expected_derivative,
         point=point,
     )
-    result = check(rtol=1e-1, equal_nan=True)
+    result = check(rtol=1e-1, atol=1e-1, equal_nan=True)
     assert result.success


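For reference, the added absolute tolerances follow NumPy's combined test: `np.isclose(a, b, rtol=..., atol=...)` checks `|a - b| <= atol + rtol * |b|`, so a nonzero `atol` is what keeps near-zero derivative entries from failing a purely relative comparison. A small illustration:

```python
import numpy as np

# NumPy's test: |a - b| <= atol + rtol * |b|
a, b = 1e-3, 0.0  # small derivative estimate vs. exactly-zero expectation
print(np.isclose(a, b, rtol=1e-1))             # False: rtol * |b| is zero
print(np.isclose(a, b, rtol=1e-1, atol=1e-1))  # True: atol absorbs the gap
```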
@@ -135,7 +135,7 @@ def test_simulate_petab_to_functions(problem_generator):
     derivative = get_derivative(
         function=amici_function,
         point=point,
-        sizes=[1e-10, 1e-5],
+        sizes=[1e-10, 1e-5, 1e-3, 1e-1],
         direction_ids=parameter_ids,
         method_ids=[MethodId.FORWARD, MethodId.BACKWARD, MethodId.CENTRAL],
         analysis_classes=[

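The widened `sizes` list gives the consistency check more step sizes to compare: a finite-difference estimate is trusted only where different sizes (and methods) agree. A generic sketch of that idea, assuming nothing beyond NumPy — this is not fiddy's implementation:

```python
import numpy as np

def central_difference(f, point, direction, size):
    """One central-difference estimate of the directional derivative."""
    return (f(point + size * direction) - f(point - size * direction)) / (2 * size)

# Toy example: estimates across step sizes should agree within a tolerance,
# analogous to what Consistency(atol=1e-2) checks for the test derivative.
f = lambda x: float(np.sin(x).sum())
point, direction = np.array([0.3]), np.array([1.0])
estimates = [
    central_difference(f, point, direction, size)
    for size in [1e-10, 1e-5, 1e-3, 1e-1]
]
print(np.ptp(estimates) < 1e-2)  # True: spread across sizes is small
```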