Enable provider unit tests for TensorRT #802

Merged
merged 39 commits into master on Apr 18, 2019
Conversation

stevenlix (Contributor):
No description provided.

stevenlix requested a review from jywu-msft on April 9, 2019 at 23:27
stevenlix requested a review from a team as a code owner on April 9, 2019 at 23:27
@jywu-msft (Member):

I would have expected more changes than just enabling the unit tests for TensorRT; some are sure to fail in the current state.
And sure enough, it looks like one of the tests is causing a segfault.

@jywu-msft (Member):

It would be nice to have some comments about why specific tests are being disabled.
Ideally, we'd define categories of issues and add a comment indicating which category a disabled test falls under.
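
For illustration, the kind of annotation being requested might look like this next to each exclusion; the category label and wording are hypothetical, not a convention this PR defines:

// Category: TRT parser gap (weights as inputs unsupported).
// Re-enable once the parser accepts initializers as inputs.
excluded_providers.insert(kTensorrtExecutionProvider);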

-static const int kMaxBatchSize = 1;
-static const int kMaxWorkSpaceSize = 1 << 30;
+static const int kMaxBatchSize = 13;
+static const int kMaxWorkSpaceSize = 1 << 24;
jywu-msft (Member):

why is this changing?

stevenlix (Author):

It's better to add an env variable for MaxWorkSpaceSize too.

@@ -50,6 +51,9 @@ void TestConvOp(const ConvOpAttributes& attributes,
   if (!is_mkldnn_supported) {
     excluded_providers.insert(kMklDnnExecutionProvider);
   }
+  if (!is_tensorrt_supported) {
+    excluded_providers.insert(kTensorrtExecutionProvider);
jywu-msft (Member):

None of the conv tests pass for TensorRT, right?
In that case, maybe it's better not to have an "is_tensorrt_supported" option at all (our expectation is that they will all pass once weights-as-inputs is supported?).
We can remove the is_mkldnn_supported option as well; I think that's also no longer needed.

stevenlix (Author):

The reason is that even though all the tests in the same file fail, the causes of failure may differ. We may need to re-enable a subset of the tests once some issues are fixed while others remain open.
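
To make that concrete, here is a self-contained sketch of the pattern the flag enables; the names are illustrative, not the PR's exact helper:

#include <string>
#include <unordered_set>

// Each provider is excluded for its own reason, so one can be re-enabled
// independently once its underlying issue is fixed.
static std::unordered_set<std::string> BuildExcludedProviders(
    bool is_tensorrt_supported, bool is_mkldnn_supported) {
  std::unordered_set<std::string> excluded_providers;
  if (!is_tensorrt_supported) {
    // e.g. fails until weights-as-inputs is supported
    excluded_providers.insert("TensorrtExecutionProvider");
  }
  if (!is_mkldnn_supported) {
    excluded_providers.insert("MklDnnExecutionProvider");
  }
  return excluded_providers;
}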

jywu-msft (Member):

I think it's better to aim for fewer code changes, to improve readability and maintainability.
In the current state, all the tests fail for Conv, ConvTranspose, and others; in general it's better not to add code which is not used.

@@ -45,7 +46,11 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
     test.AddInput<float>(szNames[i], input_shapes[i], inputs[i]);
   }
   test.AddOutput<float>("Y", expected_output_shape, expected_output);
-  test.Run(expect_result, err_str);
+  std::unordered_set<std::string> excluded_providers;
+  if (!is_tensorrt_supported) {
jywu-msft (Member):

Similar comment to Conv above.
It may be easier to just exclude TensorRT unconditionally, rather than keep an option for the case where all the tests need to be disabled.

@@ -50,6 +50,7 @@ void TestConvOp(const ConvOpAttributes& attributes,
   if (!is_mkldnn_supported) {
     excluded_providers.insert(kMklDnnExecutionProvider);
   }
+  excluded_providers.insert(kTensorrtExecutionProvider);
jywu-msft (Member):

Please add a comment explaining why.

@@ -45,7 +45,7 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
     test.AddInput<float>(szNames[i], input_shapes[i], inputs[i]);
   }
   test.AddOutput<float>("Y", expected_output_shape, expected_output);
-  test.Run(expect_result, err_str);
+  test.Run(expect_result, err_str, {kTensorrtExecutionProvider});
jywu-msft (Member):

Ditto. Please try to add comments broadly explaining why certain tests are disabled.

@@ -480,6 +480,8 @@ def setup_tensorrt_vars(args):
         "tensorrt_home='{}' valid={}."
         .format(tensorrt_home, tensorrt_home_valid))

+    os.environ["TENSORRT_MAX_BATCH_SIZE"] = "13"
jywu-msft (Member):

Add a comment? By the way, why 13? Is that the maximum batch size we test in the unit tests?
If someone adds a test which requires a larger batch, let's make it easy for them to figure out why it fails.

private:
int max_batch_size_ = 1;
const int max_workspace_size_ = 1 << 30;
jywu-msft (Member):

wouldn't we want to make workspace_size configurable too?

stevenlix (Author):

Per the TensorRT docs, max_workspace_size_ should be set as large as possible, so there is no need to make it configurable. Memory is actually allocated only as needed, when an IExecutionContext is created.

jywu-msft (Member):

1 << 30 doesn't strike me as meaning "as large as possible".
How did you choose that value? If it can't be guaranteed to work for all cases, it would be safer to make it configurable.

jywu-msft (Member):

Actually, we should also be considering the reverse scenario: on devices that don't have as much memory (e.g. Jetson Nano), one may want to limit the workspace to a small size.
It's better to make this configurable to cover both ends of the spectrum.
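
A minimal sketch of what the suggested override might look like, assuming a hypothetical ORT_TENSORRT_MAX_WORKSPACE_SIZE variable; this is not the code the PR actually merged:

#include <cstddef>
#include <cstdlib>
#include <string>

// Workspace limit for the TensorRT builder, overridable from the
// environment so small-memory devices can lower it.
static std::size_t GetTrtMaxWorkspaceSize() {
  std::size_t workspace_size = 1 << 30;  // default: 1 GiB
  if (const char* env = std::getenv("ORT_TENSORRT_MAX_WORKSPACE_SIZE")) {
    workspace_size = static_cast<std::size_t>(std::stoull(env));
  }
  return workspace_size;
}

// usage: trt_builder->setMaxWorkspaceSize(GetTrtMaxWorkspaceSize());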

trt_builder->setMaxBatchSize(kMaxBatchSize);
trt_builder->setMaxWorkspaceSize(kMaxWorkSpaceSize);

const char* batch_env = getenv("TENSORRT_MAX_BATCH_SIZE");
jywu-msft (Member):

It would be better if the env variable started with ORT_, to ensure we don't potentially collide with real TensorRT env variables.
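
For instance, a sketch of the prefixed lookup being suggested; the variable name is illustrative:

#include <cstdlib>

// Batch-size override; the ORT_ prefix avoids shadowing any variable
// TensorRT itself might read.
static int GetTrtMaxBatchSize(int default_value) {
  const char* env = std::getenv("ORT_TENSORRT_MAX_BATCH_SIZE");
  return env != nullptr ? std::atoi(env) : default_value;
}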

@jywu-msft (Member) left a review:

added some more comments.

stevenlix merged commit f2694ab into master on Apr 18, 2019
raymondxyang deleted the stevenlix/trttest branch on April 26, 2019 at 07:43