From 82a63cc14d8763ecc6cbf588f9ab335ba6b4699d Mon Sep 17 00:00:00 2001 From: alsadovn Date: Thu, 8 Aug 2019 17:19:00 -0700 Subject: [PATCH] Fix for CPU random ops seed narrowing conversion. --- .../core/providers/cpu/generator/random.h | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 6ef8d2c460553..639341d1a29cc 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -20,11 +20,14 @@ class RandomNormal final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; ORT_ENFORCE(info.GetAttr("dtype", &dtype).IsOK()); @@ -60,11 +63,14 @@ class RandomNormalLike final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; if (info.GetAttr("dtype", &dtype).IsOK()) { @@ -94,11 +100,14 @@ class RandomUniform final : public OpKernel { // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; ORT_ENFORCE(info.GetAttr("dtype", &dtype).IsOK()); @@ -131,11 +140,14 @@ class RandomUniformLike final : public OpKernel { ORT_ENFORCE(info.GetAttr("low", &low_).IsOK()); // read optional seed attribute and generate if not provided float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t dtype; if (info.GetAttr("dtype", &dtype).IsOK()) { @@ -163,11 +175,14 @@ class Multinomial final : public OpKernel { ORT_ENFORCE(info.GetAttr("sample_size", &num_samples_).IsOK()); float seed = 0.f; - if (!info.GetAttr("seed", &seed).IsOK()) { - seed = gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; + } + else { + generator_ = std::default_random_engine{ + gsl::narrow_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()) + }; } - - generator_ = std::default_random_engine{gsl::narrow_cast(seed)}; int64_t output_dtype_tmp; if (!info.GetAttr("dtype", &output_dtype_tmp).IsOK()) {