Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into snnn/update_test_data
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jul 29, 2019
2 parents 32ad050 + cf73f63 commit d14a5db
Show file tree
Hide file tree
Showing 17 changed files with 56 additions and 40 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Tensor final {
/**
The number of bytes of data.
*/
size_t Size() const {
size_t SizeInBytes() const {
size_t ret;
int64_t l = shape_.Size();
if (l >= static_cast<int64_t>(std::numeric_limits<ptrdiff_t>::max())) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ common::Status CPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int /
return Status::OK();
}
// Copying only happens between two same size tensors.
ORT_ENFORCE(src.Size() == dst.Size());
memcpy(dst_data, src_data, src.Size());
ORT_ENFORCE(src.SizeInBytes() == dst.SizeInBytes());
memcpy(dst_data, src_data, src.SizeInBytes());
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const std:

tensor_proto.set_data_type(tensor_proto_type.tensor_type().elem_type());

tensor_proto.set_raw_data(tensor.DataRaw(), tensor.Size());
tensor_proto.set_raw_data(tensor.DataRaw(), tensor.SizeInBytes());

return tensor_proto;
}
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
auto matmul_input_defs = matmul_node.MutableInputDefs();
auto add_input_defs = add_node.MutableInputDefs();

// Gemm only support float, so the inputs of MatMul
// Gemm requires that inputs be the same data type and both floating point (float32/float16).
auto matmul_type = matmul_input_defs[0]->Type();
auto add_type = add_input_defs[0]->Type();
if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
if ((*matmul_type) != (*add_type)) {
continue;
}
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") {
continue;
}

Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/providers/cpu/controlflow/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ void LoopImpl::SaveOutputsAndUpdateFeeds(const std::vector<OrtValue>& last_outpu

Status LoopImpl::ConcatenateLoopOutput(std::vector<OrtValue>& per_iteration_output, int output_index) {
const auto& first_output = per_iteration_output.front().Get<Tensor>();
size_t bytes_per_iteration = first_output.Size();
size_t bytes_per_iteration = first_output.SizeInBytes();
const auto& per_iteration_shape = first_output.Shape();
const auto& per_iteration_dims = per_iteration_shape.GetDims();

Expand All @@ -317,19 +317,19 @@ Status LoopImpl::ConcatenateLoopOutput(std::vector<OrtValue>& per_iteration_outp
// we can't easily use a C++ template for the tensor element type,
// so use a span for some protection but work in bytes
gsl::span<gsl::byte> output_span = gsl::make_span<gsl::byte>(static_cast<gsl::byte*>(output->MutableDataRaw()),
output->Size());
output->SizeInBytes());

for (int64_t i = 0; i < num_iterations; ++i) {
auto& ort_value = per_iteration_output[i];
auto& iteration_data = ort_value.Get<Tensor>();

// sanity check
if (bytes_per_iteration != iteration_data.Size()) {
if (bytes_per_iteration != iteration_data.SizeInBytes()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Inconsistent shape in loop output for output ", output_index,
" Expected:", per_iteration_shape, " Got:", iteration_data.Shape());
}

auto num_bytes = iteration_data.Size();
auto num_bytes = iteration_data.SizeInBytes();
auto src = gsl::make_span<const gsl::byte>(static_cast<const gsl::byte*>(iteration_data.DataRaw()), num_bytes);
auto dst = output_span.subspan(i * bytes_per_iteration, bytes_per_iteration);
gsl::copy(src, dst);
Expand Down Expand Up @@ -382,8 +382,8 @@ Status LoopImpl::Execute(FeedsFetchesManager* ffm, const FeedsFetchesManager* ca
auto copy_tensor_from_mlvalue_to_output = [this](const OrtValue& input, int output_idx) {
auto& data = input.Get<Tensor>();
Tensor* output = context_.Output(output_idx, data.Shape());
auto src = gsl::make_span<const gsl::byte>(static_cast<const gsl::byte*>(data.DataRaw()), data.Size());
auto dst = gsl::make_span<gsl::byte>(static_cast<gsl::byte*>(output->MutableDataRaw()), output->Size());
auto src = gsl::make_span<const gsl::byte>(static_cast<const gsl::byte*>(data.DataRaw()), data.SizeInBytes());
auto dst = gsl::make_span<gsl::byte>(static_cast<gsl::byte*>(output->MutableDataRaw()), output->SizeInBytes());
gsl::copy(src, dst);
};

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/controlflow/scan_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class OutputIterator {
// set the output for the current iteration to zeros. used for short sequence lengths
void ZeroOutCurrent() {
auto* tensor = (**this).GetMutable<Tensor>();
memset(tensor->MutableDataRaw(), 0, tensor->Size());
memset(tensor->MutableDataRaw(), 0, tensor->SizeInBytes());
}

const OrtValue& GetOutput() const {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/tensor/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, co
}

const auto input_elements = input_data_shape.Size();
const auto total_input_bytes = data_input->Size();
const auto total_input_bytes = data_input->SizeInBytes();

const auto* src_base = static_cast<const Tdata*>(data_input->DataRaw());
auto* dst_base = static_cast<Tdata*>(data_output->MutableDataRaw());
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/tensor/size.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Status Size::Compute(OpKernelContext* ctx) const {
TensorShape scalar_shape;
Tensor* p_output_tensor = ctx->Output(0, scalar_shape);
auto* p_output_scalar = p_output_tensor->template MutableData<int64_t>();
assert(p_output_tensor->Size() == sizeof(int64_t));
assert(p_output_tensor->SizeInBytes() == sizeof(int64_t));

*p_output_scalar = input_tensor->Shape().Size();

Expand Down
19 changes: 11 additions & 8 deletions onnxruntime/core/providers/cpu/tensor/squeeze.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ class SqueezeBase {
protected:
explicit SqueezeBase(const OpKernelInfo& info) {
std::vector<int64_t> axes;
Status status = info.GetAttrs<int64_t>("axes", axes);
ORT_ENFORCE(status.IsOK(), "Attribute axes is not set.");
// Parse attribute 'axes'
Status status = info.GetAttrs<int64_t>("axes", axes);

// Handle out of order and repeating dims.
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
axes_ = axes;
// Handle out of order and repeating dims when 'axes' exists.
if (status.IsOK()) {
std::sort(axes.begin(), axes.end());
axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
axes_ = axes;
}
}

static std::vector<int64_t> ComputeOutputShape(
Expand All @@ -28,7 +30,8 @@ class SqueezeBase {
size_t j = 0;
std::vector<int64_t> output_shape;
for (size_t i = 0; i < input_shape.NumDimensions(); ++i) {
if (j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) {
if ((j < axes.NumDimensions() && axes[j] == static_cast<int64_t>(i)) ||
(axes.NumDimensions() == 0 && input_shape[i] == 1)) {
ORT_ENFORCE(input_shape[i] == 1, "Dimension of input ", i, " must be 1 instead of ", input_shape[i],
". shape=", input_shape);
++j;
Expand Down Expand Up @@ -59,4 +62,4 @@ class Squeeze final : public OpKernel, public SqueezeBase {
}
};

} // namespace onnxruntime
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa
Tensor* Y = context->Output(0, Y_dims);

if (no_scale) {
memcpy(Y->MutableDataRaw(), X->DataRaw(), Y->Size());
memcpy(Y->MutableDataRaw(), X->DataRaw(), Y->SizeInBytes());
return Status::OK();
}

Expand Down
7 changes: 0 additions & 7 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kCudaExecutionProvider}, device_id_(info.device_id) {
CUDA_CALL_THROW(cudaSetDevice(device_id_));
// create streams, default is nullptr
streams_[kCudaStreamDefault] = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking));
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));

DeviceAllocatorRegistrationInfo default_allocator_info(
{OrtMemTypeDefault, [](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
Expand All @@ -93,9 +89,6 @@ CUDAExecutionProvider::~CUDAExecutionProvider() {
CUDA_CALL_THROW(cudaEventDestroy(e));
it = deferred_release_cpu_ptr_.erase(it);
}
CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyIn]));
CUDA_CALL_THROW(cudaStreamDestroy(streams_[kCudaStreamCopyOut]));

ReleasePerThreadStuffs();
}

Expand Down
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
return GetPerThreadContext().CudnnHandle();
}

cudaStream_t GetStream(int queue_id) const {
ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalCudaStreams);
return streams_[queue_id];
}

template <typename T>
const T* GetConstOnes(size_t count) {
return GetPerThreadContext().template GetConstOnes<T>(count);
Expand All @@ -69,7 +64,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
int GetDeviceId() const { return device_id_; }

private:
cudaStream_t streams_[kTotalCudaStreams];
int device_id_;

struct DeferredReleaseCPUPtrs {
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/gpu_data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@ GPUDataTransfer::GPUDataTransfer() {
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));
}

// Releases the CUDA streams owned by this data-transfer object.
// The copy-out stream is created in the constructor with cudaStreamCreateWithFlags;
// the copy-in stream is presumably created there as well (that line is not shown here — confirm).
// NOTE(review): the constructor uses CUDA_CALL_THROW while this destructor uses CUDA_CALL,
// presumably so that a failing cudaStreamDestroy does not throw from a destructor —
// confirm against the macro definitions.
GPUDataTransfer::~GPUDataTransfer() {
// Destroy the stream indexed by kCudaStreamCopyIn.
CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyIn]));
// Destroy the stream indexed by kCudaStreamCopyOut.
CUDA_CALL(cudaStreamDestroy(streams_[kCudaStreamCopyOut]));
}

bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED
|| dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED;
}

common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
size_t bytes = src.Size();
size_t bytes = src.SizeInBytes();
const void* src_data = src.DataRaw();
void* dst_data = dst.MutableDataRaw();

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/gpu_data_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enum CUDAStreamType : int {
class GPUDataTransfer : public IDataTransfer {
public:
GPUDataTransfer();
~GPUDataTransfer();

bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/parallel_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct TestOp {
// success
Tensor* Y = ctx->Output(0, action_tensor.Shape());
void* target = Y->MutableData<int64_t>();
memcpy(target, action, action_tensor.Size());
memcpy(target, action, action_tensor.SizeInBytes());
break;
}
case 1: {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/sparse_kernels_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ This operator applies the Abs op element-wise to the input sparse-tensor.
// So, we copy indices/shape from input to output.
// TODO: Extend allocation-planner to enable such sharing.
const auto& input_indices = input->Indices();
memcpy(output->MutableIndices().MutableData<int64_t>(), input_indices.Data<int64_t>(), input_indices.Size());
memcpy(output->MutableIndices().MutableData<int64_t>(), input_indices.Data<int64_t>(), input_indices.SizeInBytes());
return Status::OK();
}
};
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/squeeze_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ TEST(SqueezeOpTest, Squeeze_1) {
test.Run();
}

// Squeeze with no 'axes' attribute: every size-1 dimension of the {1, 1, 4, 1}
// input is expected to be dropped, leaving a tensor of shape {4}.
TEST(SqueezeOpTest, Squeeze_Empty_Axes_1) {
  OpTester squeeze_tester("Squeeze");
  // Deliberately no AddAttribute("axes", ...) call here.
  squeeze_tester.AddInput<float>("data", {1, 1, 4, 1}, std::vector<float>(4, 1.0f));
  squeeze_tester.AddOutput<float>("squeezed", {4}, std::vector<float>(4, 1.0f));
  // TensorRT doesn't seem to support missing 'axes'
  // (the provider is passed in the last argument of Run — presumably an exclusion list; confirm).
  squeeze_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

// Squeeze with no 'axes' attribute on a {2, 4} input: there is no size-1
// dimension, so the output is expected to keep the shape {2, 4}.
TEST(SqueezeOpTest, Squeeze_Empty_Axes_2) {
  OpTester squeeze_tester("Squeeze");
  // nothing to "squeeze" out in the input shape
  squeeze_tester.AddInput<float>("data", {2, 4}, std::vector<float>(8, 1.0f));
  squeeze_tester.AddOutput<float>("squeezed", {2, 4}, std::vector<float>(8, 1.0f));
  // TensorRT doesn't seem to support missing 'axes'
  // (the provider is passed in the last argument of Run — presumably an exclusion list; confirm).
  squeeze_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(SqueezeOpTest, Squeeze_1_int32) {
OpTester test("Squeeze");
test.AddAttribute("axes", std::vector<int64_t>{0});
Expand Down

0 comments on commit d14a5db

Please sign in to comment.