Skip to content

Commit

Permalink
Brianma/linux (#2917)
Browse files Browse the repository at this point in the history
* don't include windows.h in cross-plat header

* add default case for switch statement

* signed/unsigned mismatch fix

Co-authored-by: Brian Martin <[email protected]>
  • Loading branch information
zhangxiang1993 and martinb35 authored Jan 28, 2020
1 parent 390ed0d commit eacb8a7
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 12 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/core/platform/telemetry.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <windows.h>
#include "core/platform/telemetry.h"
#include "core/platform/env.h"

Expand Down Expand Up @@ -67,7 +66,7 @@ void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_la
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
}

void Telemetry::LogExecutionProviderEvent(LUID adapterLuid) const {
void Telemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {
ORT_UNUSED_PARAMETER(adapterLuid);
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/platform/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Telemetry {

virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const;

virtual void LogExecutionProviderEvent(LUID adapterLuid) const;
virtual void LogExecutionProviderEvent(LUID* adapterLuid) const;

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry);
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"));
}

void WindowsTelemetry::LogExecutionProviderEvent(LUID adapterLuid) const {
void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {
if (global_register_count_ == 0 || enabled_ == false)
return;

Expand All @@ -225,8 +225,8 @@ void WindowsTelemetry::LogExecutionProviderEvent(LUID adapterLuid) const {
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
TraceLoggingUInt32(adapterLuid.LowPart, "adapterLuidLowPart"),
TraceLoggingUInt32(adapterLuid.HighPart, "adapterLuidHighPart"));
TraceLoggingUInt32(adapterLuid->LowPart, "adapterLuidLowPart"),
TraceLoggingUInt32(adapterLuid->HighPart, "adapterLuidHighPart"));
}

} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/platform/windows/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class WindowsTelemetry : public Telemetry {

void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override;

void LogExecutionProviderEvent(LUID adapterLuid) const override;
void LogExecutionProviderEvent(LUID* adapterLuid) const override;

private:
static OrtMutex mutex_;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(dml_device->GetParentDevice(IID_PPV_ARGS(&d3d12_device)));
const Env& env = Env::Default();
env.GetTelemetryProvider().LogExecutionProviderEvent(d3d12_device->GetAdapterLuid());
env.GetTelemetryProvider().LogExecutionProviderEvent(&d3d12_device->GetAdapterLuid());

return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue);
}
Expand Down
11 changes: 8 additions & 3 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,19 +752,24 @@ static bool ModelUseFP16Helper(const onnx::TypeProto& type_proto) {
return true;
}
}
} break;
break;
}
case ::onnx::TypeProto::ValueCase::kSequenceType: {
if (type_proto.has_sequence_type()) {
auto& sequence_type = type_proto.sequence_type();
return ModelUseFP16Helper(sequence_type.elem_type());
}
} break;
break;
}
case ::onnx::TypeProto::ValueCase::kMapType: {
if (type_proto.has_map_type()) {
auto& map_type = type_proto.map_type();
return ModelUseFP16Helper(map_type.value_type());
}
} break;
break;
}
default:
break;
}
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion winml/adapter/DmlOrtSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ HRESULT DmlOrtSessionBuilder::CreateSession(
auto session = std::make_unique<onnxruntime::InferenceSession>(options->value);

const onnxruntime::Env& env = onnxruntime::Env::Default();
env.GetTelemetryProvider().LogExecutionProviderEvent(p_d3d_device->GetAdapterLuid());
LUID temp_LUID = p_d3d_device->GetAdapterLuid();
env.GetTelemetryProvider().LogExecutionProviderEvent(&temp_LUID);
// Cache the provider's raw pointer
*pp_provider = gpu_provider.get();

Expand Down

0 comments on commit eacb8a7

Please sign in to comment.