Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enable MKL-DNN FullyConnected backward #17318

Merged
merged 27 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0b33c8c
fix mkldnn fc bwd bug due to data inplace
rongzha1 Nov 22, 2019
2d74cc4
enable mkldnn fc bwd
rongzha1 Nov 25, 2019
f4e6557
Merge commit 'refs/pull/16890/head' of https://github.com/apache/incu…
TaoLv Jan 15, 2020
db88a23
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Jan 16, 2020
6a98aa4
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Jan 17, 2020
e17f70d
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Jan 20, 2020
9c7596b
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 5, 2020
d015609
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 9, 2020
a103baa
fix cpp tests
TaoLv Feb 9, 2020
caceb9b
try: fix random seed
TaoLv Feb 9, 2020
7bbbc7f
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 10, 2020
1bf97fa
fix cpp test
TaoLv Feb 10, 2020
b374889
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 11, 2020
2393889
loose rtol for fc cpp test
TaoLv Feb 11, 2020
8a01fef
improve error message
TaoLv Feb 11, 2020
979df7a
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 12, 2020
5525f71
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 13, 2020
2200653
limit the max value of test tensors
TaoLv Feb 13, 2020
faa1db8
fix lint
TaoLv Feb 13, 2020
2a3ebb4
limit max value for mkldnn tensors
TaoLv Feb 13, 2020
34549f2
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 14, 2020
93eb151
remove fixed random seed
TaoLv Feb 14, 2020
38070d9
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 16, 2020
56d873f
address review comments
TaoLv Feb 16, 2020
43905c3
Revert "address review comments"
TaoLv Feb 17, 2020
edef1c7
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 17, 2020
464204f
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
TaoLv Feb 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
// TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
// Will be fixed when we decide to enable the backward of FC.
bool mkldnn_fc_backward_enable = false;
if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
Expand Down Expand Up @@ -232,12 +229,10 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
Expand Down
36 changes: 18 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,24 +273,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));

CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
};

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
CommitOutput(in_grad[fullc::kData], in_grad_mem);
}
if (req[fullc::kWeight]) {
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
= GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
Expand Down Expand Up @@ -319,6 +301,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
}
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
};

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
CommitOutput(in_grad[fullc::kData], in_grad_mem);
}
MKLDNNStream::Get()->Submit();
}

Expand Down
44 changes: 22 additions & 22 deletions tests/cpp/include/test_mkldnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,24 @@ struct TestArrayShapes {
};

// Init arrays with the default layout.
inline static void InitDefaultArray(NDArray *arr, bool is_rand = false) {
inline static void InitDefaultArray(NDArray *arr, bool is_rand = false, int max = 50) {
const TBlob &blob = arr->data();
mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>();
int size = blob.Size();

for (int i = 0; i < size; i++)
if (is_rand) {
data[i] = (std::rand() % 100) - 50;
data[i] = (std::rand() % (max * 2)) - max;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about changing this to `data[i] = std::rand() * 1.0 / RAND_MAX - 0.5;`? With max = 1, the current code will only generate two values: -1.0 and 0.0.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I don't want to affect other test cases, which still use the default max=50 to generate integers in [-50, 50). But for the FullyConnectedOp, I want to generate relatively small numbers. With the given code, the elements will be -1 and 0. Any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea why the range was set to [-50, 50) previously, and I can't figure out any specific reason to use this range for the tests (any upper bound test?). It would be great if you have any background on it.
But anyway, tensors with only two values (-1 and 0, 50% of which are 0) might not be good candidates for the tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I will change it to generate floating-point numbers in [-max, max) rather than integers. Previously I thought sparsity (say 50% zeros) was also a way to avoid floating-point accumulation error.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If memcmp is used to check the results, then integers are the only choice. Would it be reasonable to check the results with AssertEqual within a small enough threshold like 1e-6, so that we can keep floating-point numbers with a better distribution?
Or we can just increase max to filling more different numbers other than only -1 and 0.
What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have reverted the changes for floating-point numbers. Changing `memcmp` to `AssertEqual` is out of the scope of this PR, so I will keep it as is.

Or we can just increase max to filling more different numbers other than only -1 and 0.

I was thinking about including the number 2 in the generated tensors, but found that with the given shapes there is still a chance of getting an error. That means that in the worst case the intermediate accumulation value will be > 2^24, so a 1 will be lost when another 1 is accumulated onto it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's keep it as is for now.

} else {
data[i] = i % 100 - 50;
data[i] = i % (max * 2) - max;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, how about changing this to something like `data[i] = i * 2.0 / size - 1.0` to generate values in [-1.0, 1.0)?

}
}


// Init arrays with the specified layout.
inline static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::desc &desc,
bool is_rand = false) {
InitDefaultArray(arr, is_rand);
bool is_rand = false, int max = 50) {
InitDefaultArray(arr, is_rand, max);
arr->MKLDNNDataReorderAsync(desc);
arr->WaitToRead();
}
Expand Down Expand Up @@ -330,7 +330,7 @@ inline void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
*/
inline std::vector<NDArrayAttrs> GetTestInputArrays(
int types = ArrayTypes::All, bool rand = false,
std::vector<float> scale = {1}, bool spatial_data_format = false) {
std::vector<float> scale = {1}, bool spatial_data_format = false, int max = 50) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the new parameter for better usability?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestArrayShapes tas = GetTestArrayShapes(spatial_data_format);
std::vector<mxnet::TShape> shapes = tas.shapes;
std::vector<mkldnn::memory::desc> mds = tas.mds;
Expand All @@ -349,14 +349,14 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
// Type 1.
NDArray arr(shape, Context());
if (types & ArrayTypes::Normal) {
InitDefaultArray(&arr, rand);
InitDefaultArray(&arr, rand, max);
in_arrs.emplace_back(arr, "Normal NDArray");
}

// Type 4
arr = NDArray(shape, Context());
if (types & ArrayTypes::NormalReshaped) {
InitDefaultArray(&arr, rand);
InitDefaultArray(&arr, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount),
"Reshaped Normal NDArray");
}
Expand All @@ -379,19 +379,19 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
if (shape.ndim() == md.data.ndims && IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNN) {
desc_str = "MKLDNN NDArray";
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
} else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNDiffShape) {
desc_str = "MKLDNN NDArray with different shape";
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
} else if (shape.ndim() != md.data.ndims && types & ArrayTypes::MKLDNNDiffDim) {
std::stringstream ss;
ss << "MKLDNN NDArray with different dim " <<
shape.ndim() << "/" << md.data.ndims;
desc_str = ss.str();
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
}

Expand All @@ -401,20 +401,20 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
if (shape.ndim() == md.data.ndims && IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNReshaped) {
desc_str = "Reshaped MKLDNN NDArray";
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
} else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNReshapedDiffShape) {
desc_str = "Reshaped MKLDNN NDArray with different shape";
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
} else if (shape.ndim() != md.data.ndims
&& types & ArrayTypes::MKLDNNReshapedDiffDim) {
std::stringstream ss;
ss << "MKLDNN NDArray with different dim " <<
shape.ndim() << "/" << md.data.ndims;
desc_str = ss.str();
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
}
}
Expand Down Expand Up @@ -445,7 +445,7 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
inline std::vector<NDArrayAttrs> GetTestOutputArrays(
const mxnet::TShape &shp,
const std::vector<mkldnn::memory::desc> &mds,
std::vector<float>scale = {1}, bool rand = true, int types = ArrayTypes::All) {
std::vector<float>scale = {1}, bool rand = true, int types = ArrayTypes::All, int max = 50) {
mxnet::TShape shape = shp;

for (int dim = 0; dim < scale.size(); dim++)
Expand All @@ -458,15 +458,15 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(

if (types & ArrayTypes::Normal) {
in_arrs.emplace_back(arr, "Normal NDArray");
InitDefaultArray(&in_arrs.back().arr, rand);
InitDefaultArray(&in_arrs.back().arr, rand, max);
}

mxnet::TShape tmp_shape = shape;
if (types & ArrayTypes::NormalReshaped) {
// Type 4.
tmp_shape[0] = shape[0] * 2;
NDArray arr0(tmp_shape, Context());
InitDefaultArray(&arr0, rand);
InitDefaultArray(&arr0, rand, max);
in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1), "Reshaped NDArray");
}

Expand All @@ -477,7 +477,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size();
NDArray arr1(s, Context());
arr1 = arr1.AsArray(shape, arr1.dtype());
InitDefaultArray(&arr1, rand);
InitDefaultArray(&arr1, rand, max);
in_arrs.emplace_back(arr1, "Reused NDArray");
}

Expand All @@ -486,7 +486,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag);
NDArray arr2(s, Context(), true, mshadow::kUint8);
arr2 = arr2.AsArray(shape, mshadow::default_type_flag);
InitDefaultArray(&arr2, rand);
InitDefaultArray(&arr2, rand, max);
in_arrs.emplace_back(arr2, "Reused NDArray with diff data type");
}

Expand All @@ -496,7 +496,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
NDArray arr3(s, Context(), true, mshadow::kUint8);
tmp_shape[0] = shape[0] * 2;
arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag);
InitDefaultArray(&arr3, rand);
InitDefaultArray(&arr3, rand, max);
in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1), "Reused+Reshaped NDArray");
}

Expand All @@ -523,7 +523,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
if ((types & ArrayTypes::MKLDNN && shape.ndim() == md.data.ndims) ||
(types & ArrayTypes::MKLDNNDiffDim && shape.ndim() != md.data.ndims)) {
in_arrs.emplace_back(arr, desc_str);
InitMKLDNNArray(&in_arrs.back().arr, md, rand);
InitMKLDNNArray(&in_arrs.back().arr, md, rand, max);
}

// Type 8, 9.
Expand All @@ -532,7 +532,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size();
NDArray arr = NDArray(s, Context());
arr = arr.AsArray(shape, arr.dtype());
InitMKLDNNArray(&arr, md, rand);
InitMKLDNNArray(&arr, md, rand, max);
desc_str = "Reused MKLDNN NDArray";
if (shape.ndim() != md.data.ndims) {
std::stringstream ss;
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/include/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,8 @@ static void AssertEqual(const std::vector<NDArray *> &in_arrs,
static_cast<mshadow::default_real_t *>(blob2.dptr_);
for (int i = 0; i < tmp1.shape().Size(); i++) {
float abs_err = fabs((d1[i]) - (d2[i]));
ASSERT_LE(abs_err, (atol + rtol * fabs(d2[i])));
ASSERT_LE(abs_err, (atol + rtol * fabs(d2[i])))
<< "index: " << i << ", " << d1[i] << " vs " << d2[i];
}
}
}
Expand Down
17 changes: 9 additions & 8 deletions tests/cpp/operator/mkldnn_operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {

for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types);
GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, false, forward_attrs.output_types);
ex_out_arrs[i] =
GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types);
GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, false, forward_attrs.output_types);
}

for (int i = 0; i < forward_attrs.num_inputs; i++)
Expand Down Expand Up @@ -897,7 +897,8 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::desc> mds = tas.mds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
std::vector<NDArrayAttrs> in_arrs =
GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1);
std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

Expand Down Expand Up @@ -932,9 +933,9 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards

for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types);
GetTestOutputArrays(out_shape, mds, {1}, false, forward_attrs.output_types, 1);
ex_out_arrs[i] =
GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types);
GetTestOutputArrays(out_shape, mds, {1}, false, forward_attrs.output_types, 1);
}

for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
Expand All @@ -960,14 +961,14 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
backwards_input[1] = inputs[0]; // input
backwards_input[2] = inputs[1]; // weights

auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1)[i1];
NDArray back_weights(wt_shape, Context());
NDArray back_bias(bias_shape, Context());
backwards_outputs[0] = &tmp_output.arr;
backwards_outputs[1] = &back_weights;
backwards_outputs[2] = &back_bias;

auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1)[i1];
NDArray back_ex_weights(wt_shape, Context());
NDArray back_ex_bias(bias_shape, Context());
backwards_ex_outputs[0] = &tmp_output2.arr;
Expand All @@ -986,7 +987,7 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
AssertEqual(backwards_outputs, backwards_ex_outputs);
AssertEqual(backwards_outputs, backwards_ex_outputs, 1e-6, 1e-6);
}
}
}
Expand Down