Onnx load model (#5782)
* Fixed ONNX temp model deletion

* Fixed random file path

* Updates from PR

* Changes from PR comments.

* Changed how AutoML caches.

* PR fixes.

* Update src/Microsoft.ML.AutoML/API/ExperimentSettings.cs

Co-authored-by: Eric Erhardt <[email protected]>

* TensorFlow fixes from PR comments

* Fixed file path issues

Co-authored-by: Eric Erhardt <[email protected]>
michaelgsharp and eerhardt authored May 18, 2021
1 parent bf31c94 commit 7fafbf3
Showing 18 changed files with 79 additions and 57 deletions.
9 changes: 5 additions & 4 deletions src/Microsoft.ML.AutoML/API/ExperimentSettings.cs
@@ -42,13 +42,13 @@ public abstract class ExperimentSettings
public CancellationToken CancellationToken { get; set; }

/// <summary>
/// This is a pointer to a directory where all models trained during the AutoML experiment will be saved.
/// This is the name of the directory where all models trained during the AutoML experiment will be saved.
/// If <see langword="null"/>, models will be kept in memory instead of written to disk.
/// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
/// memory could cause a system to run out of memory.)
/// </summary>
/// <value>The default value is the directory named "Microsoft.ML.AutoML" in the current user's temporary folder.</value>
public DirectoryInfo CacheDirectory { get; set; }
/// <value>The default value is the directory named "Microsoft.ML.AutoML" in the location specified by the <see cref="MLContext.TempFilePath"/>.</value>
public string CacheDirectoryName { get; set; }

/// <summary>
/// Whether AutoML should cache before ML.NET trainers.
@@ -66,10 +66,11 @@ public ExperimentSettings()
{
MaxExperimentTimeInSeconds = 24 * 60 * 60;
CancellationToken = default;
CacheDirectory = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.AutoML"));
CacheDirectoryName = "Microsoft.ML.AutoML";
CacheBeforeTrainer = CacheBeforeTrainer.Auto;
MaxModels = int.MaxValue;
}

}

/// <summary>
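For orientation, a hedged usage sketch of the renamed setting. The experiment type and the Auto() catalog call are the standard Microsoft.ML.AutoML entry points, but none of the code below appears in this diff and the values are illustrative:

using Microsoft.ML;
using Microsoft.ML.AutoML;

// Illustrative sketch only. CacheDirectoryName is now a plain directory name that is
// combined with MLContext.TempFilePath (see the MLContext changes below), instead of a
// full DirectoryInfo pointing at the system temp folder.
var mlContext = new MLContext(seed: 0);
var settings = new RegressionExperimentSettings
{
    MaxExperimentTimeInSeconds = 60,
    CacheDirectoryName = "Microsoft.ML.AutoML"   // set to null to keep candidate models in memory
};
var experiment = mlContext.Auto().CreateRegressionExperiment(settings);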
3 changes: 3 additions & 0 deletions src/Microsoft.ML.AutoML/API/RankingExperiment.cs
@@ -35,6 +35,9 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// </value>
public uint OptimizationMetricTruncationLevel { get; set; }

/// <summary>
/// Initializes a new instance of <see cref="RankingExperimentSettings"/>.
/// </summary>
public RankingExperimentSettings()
{
OptimizingMetric = RankingMetric.Ndcg;
3 changes: 3 additions & 0 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
@@ -31,6 +31,9 @@ public sealed class RegressionExperimentSettings : ExperimentSettings
/// </value>
public ICollection<RegressionTrainer> Trainers { get; }

/// <summary>
/// Initializes a new instance of <see cref="RegressionExperimentSettings"/>.
/// </summary>
public RegressionExperimentSettings()
{
OptimizingMetric = RegressionMetric.RSquared;
10 changes: 5 additions & 5 deletions src/Microsoft.ML.AutoML/Experiment/Experiment.cs
@@ -56,7 +56,7 @@ public Experiment(MLContext context,
_experimentSettings = experimentSettings;
_metricsAgent = metricsAgent;
_trainerAllowList = trainerAllowList;
_modelDirectory = GetModelDirectory(_experimentSettings.CacheDirectory);
_modelDirectory = GetModelDirectory(_context.TempFilePath, _experimentSettings.CacheDirectoryName);
_datasetColumnInfo = datasetColumnInfo;
_runner = runner;
_logger = logger;
@@ -140,7 +140,7 @@ public IList<TRunDetail> Execute()

// Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to
// maintain variability between training iterations.
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed;
int? mainContextSeed = ((IHostEnvironmentInternal)_context.Model.GetEnvironment()).Seed;
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null;

do
@@ -220,14 +220,14 @@ public IList<TRunDetail> Execute()
return iterationResults;
}

private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir)
private static DirectoryInfo GetModelDirectory(string tempDirectory, string cacheDirectoryName)
{
if (rootDir == null)
if (cacheDirectoryName == null)
{
return null;
}

var experimentDirFullPath = Path.Combine(rootDir.FullName, $"experiment_{Path.GetRandomFileName()}");
var experimentDirFullPath = Path.Combine(tempDirectory, cacheDirectoryName, $"experiment_{Path.GetRandomFileName()}");
var experimentDirInfo = new DirectoryInfo(experimentDirFullPath);
if (!experimentDirInfo.Exists)
{
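As a rough illustration (not part of this diff), the experiment cache directory now resolves as follows under default settings; the random suffix is whatever Path.GetRandomFileName happens to return:

using System.IO;

// Sketch of the path composition GetModelDirectory performs above, with assumed defaults.
string tempFilePath = Path.GetTempPath();              // default MLContext.TempFilePath
string cacheDirectoryName = "Microsoft.ML.AutoML";     // default ExperimentSettings.CacheDirectoryName
string experimentDir = Path.Combine(
    tempFilePath, cacheDirectoryName, $"experiment_{Path.GetRandomFileName()}");
// e.g. <temp>/Microsoft.ML.AutoML/experiment_lcqjfnzd.3gp
Directory.CreateDirectory(experimentDir);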
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -83,12 +83,17 @@ internal interface ICancelable
}

[BestFriend]
internal interface ISeededEnvironment : IHostEnvironment
internal interface IHostEnvironmentInternal : IHostEnvironment
{
/// <summary>
/// The seed property that, if assigned, makes components requiring randomness behave deterministically.
/// </summary>
int? Seed { get; }

/// <summary>
/// The location for the temp files created by ML.NET
/// </summary>
string TempFilePath { get; set; }
}

/// <summary>
6 changes: 5 additions & 1 deletion src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -93,7 +93,7 @@ internal interface IMessageSource
/// query progress.
/// </summary>
[BestFriend]
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, ISeededEnvironment, IChannelProvider, ICancelable
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable
where TEnv : HostEnvironmentBase<TEnv>
{
void ICancelable.CancelExecution()
@@ -326,6 +326,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
}
}

#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value in case the user doesn't set it.
public string TempFilePath { get; set; } = System.IO.Path.GetTempPath();
#pragma warning restore MSML_NoInstanceInitializers

protected readonly TEnv Root;
// This is non-null iff this environment was a fork of another. Disposing a fork
// doesn't free temp files. That is handled when the master is disposed.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -188,7 +188,7 @@ private void RunCore(IChannel ch, string cmd)

ILegacyDataLoader testPipe;
bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
var tempFilePath = hasOutfile ? null : Path.Combine(((IHostEnvironmentInternal)Host).TempFilePath, Path.GetRandomFileName());

using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile))
{
@@ -534,7 +534,7 @@ internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView dat
}
else if(fallbackInEnvSeed)
{
ISeededEnvironment seededEnv = (ISeededEnvironment)env;
IHostEnvironmentInternal seededEnv = (IHostEnvironmentInternal)env;
seedToUse = seededEnv.Seed;
}
else
13 changes: 11 additions & 2 deletions src/Microsoft.ML.Data/MLContext.cs
@@ -14,7 +14,7 @@ namespace Microsoft.ML
/// create components for data preparation, feature engineering, training, prediction, model evaluation.
/// It also allows logging, execution control, and the ability to set repeatable random numbers.
/// </summary>
public sealed class MLContext : ISeededEnvironment
public sealed class MLContext : IHostEnvironmentInternal
{
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
private readonly LocalEnvironment _env;
Expand Down Expand Up @@ -79,6 +79,15 @@ public sealed class MLContext : ISeededEnvironment
/// </summary>
public ComponentCatalog ComponentCatalog => _env.ComponentCatalog;

/// <summary>
/// Gets or sets the location for the temp files created by ML.NET.
/// </summary>
public string TempFilePath
{
get { return _env.TempFilePath; }
set { _env.TempFilePath = value; }
}

/// <summary>
/// Create the ML context.
/// </summary>
@@ -140,7 +149,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
int? ISeededEnvironment.Seed => _env.Seed;
int? IHostEnvironmentInternal.Seed => _env.Seed;

[BestFriend]
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();
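A short, hedged example of the new public property; the directory below is made up and any writable path would do:

using Microsoft.ML;

// Illustrative only: point every temp file ML.NET creates (ONNX/TensorFlow model copies,
// AutoML experiment caches, command-line temp models) at an application-controlled folder.
var mlContext = new MLContext();
mlContext.TempFilePath = @"D:\ml-scratch";   // defaults to Path.GetTempPath()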
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -227,7 +227,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
{
// Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
// to create a temporary file to store it and then call ONNXRuntime's API to load that file.
Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary);
Model = OnnxModel.CreateFromBytes(modelBytes, env, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary);
}
}
catch (OnnxRuntimeException e)
@@ -304,7 +304,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelFile)); });
ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelStream.Name)); });

Host.CheckNonEmpty(Inputs, nameof(Inputs));
ctx.Writer.Write(Inputs.Length);
47 changes: 23 additions & 24 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -143,14 +143,9 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
/// </summary>
private readonly InferenceSession _session;
/// <summary>
/// Indicates if <see cref="ModelFile"/> is a temporal file created by <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>
/// or <see cref="CreateFromBytes(byte[])"/>. If <see langword="true"/>, <see cref="Dispose(bool)"/> should delete <see cref="ModelFile"/>.
/// The FileStream holding onto the loaded ONNX model.
/// </summary>
private bool _ownModelFile;
/// <summary>
/// The location where the used ONNX model loaded from.
/// </summary>
internal string ModelFile { get; }
internal FileStream ModelStream { get; }
/// <summary>
/// The ONNX model's information from ONNXRuntime's perspective. ML.NET can change the input and output of that model in some ways.
/// For example, ML.NET can shuffle the inputs so that the i-th ONNX input becomes the j-th input column of <see cref="OnnxTransformer"/>.
@@ -172,9 +167,7 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
bool ownModelFile=false, IDictionary<string, int[]> shapeDictionary = null)
{
ModelFile = modelFile;
// If we don't own the model file, _disposed should be false to prevent deleting user's file.
_ownModelFile = ownModelFile;
_disposed = false;

if (gpuDeviceId != null)
Expand Down Expand Up @@ -202,9 +195,15 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
{
// Load ONNX model file and parse its input and output schema. The reason for doing so is that ONNXRuntime
// doesn't expose full type information via its C# APIs.
ModelFile = modelFile;
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(modelFile))
// If we own the model file, set the DeleteOnClose flag so it is always deleted.
if (ownModelFile)
ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
else
ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);

// The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

@@ -322,33 +321,35 @@ private static bool CheckOnnxShapeCompatibility(IEnumerable<int> left, IEnumerab

/// <summary>
/// Create an OnnxModel from a byte[]. Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
/// With <see cref="CreateFromBytes(byte[])"/> and <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>,
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
/// it's possible to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
/// </summary>
/// <param name="modelBytes">Bytes of the serialized model</param>
public static OnnxModel CreateFromBytes(byte[] modelBytes)
/// <param name="env">IHostEnvironment</param>
public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env)
{
return CreateFromBytes(modelBytes, null, false);
return CreateFromBytes(modelBytes, env, null, false);
}

/// <summary>
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
/// Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
/// With <see cref="CreateFromBytes(byte[])"/> and
/// <see cref="CreateFromBytes(byte[], int?, bool, IDictionary{string, int[]})"/>,
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and
/// <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
/// it's possible to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
/// </summary>
/// <param name="modelBytes">Bytes of the serialized model.</param>
/// <param name="env">IHostEnvironment</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="shapeDictionary">User-provided shapes. If the key "myTensorName" is associated
/// with the value [1, 3, 5], the shape of "myTensorName" will be set to [1, 3, 5].
/// The shape loaded from <paramref name="modelBytes"/> would be overwritten.</param>
/// <returns>An <see cref="OnnxModel"/></returns>
public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = null, bool fallbackToCpu = false,
public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env, int? gpuDeviceId = null, bool fallbackToCpu = false,
IDictionary<string, int[]> shapeDictionary = null)
{
var tempModelFile = Path.GetTempFileName();
var tempModelFile = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
File.WriteAllBytes(tempModelFile, modelBytes);
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
ownModelFile: true, shapeDictionary: shapeDictionary);
Expand All @@ -366,7 +367,7 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(List<NamedOnn
}

/// <summary>
/// Flag used to indicate if the unmanaged resources (aka the model file <see cref="ModelFile"/>
/// Flag used to indicate if the unmanaged resources (aka the model file handle <see cref="ModelStream"/>
/// and <see cref="_session"/>) have been deleted.
/// </summary>
private bool _disposed;
@@ -378,8 +379,7 @@ public void Dispose()
}

/// <summary>
/// There are two unmanaged resources we can dispose, <see cref="_session"/> and <see cref="ModelFile"/>
/// if <see cref="_ownModelFile"/> is <see langword="true"/>.
/// There are two unmanaged resources we can dispose, <see cref="_session"/> and <see cref="ModelStream"/>
/// </summary>
/// <param name="disposing"></param>
private void Dispose(bool disposing)
@@ -391,9 +391,8 @@ private void Dispose(bool disposing)
{
// First, we release the resource token by ONNXRuntime.
_session.Dispose();
// Second, we delete the model file if that file is not created by the user.
if (_ownModelFile && File.Exists(ModelFile))
File.Delete(ModelFile);
// Second, dispose of the model file stream.
ModelStream.Dispose();
}
_disposed = true;
}
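For readers unfamiliar with the FileOptions.DeleteOnClose technique this commit adopts for owned model files, here is a minimal standalone sketch; the file contents are a stand-in, not a real ONNX model:

using System;
using System.IO;

// Illustrative only: a file opened with FileOptions.DeleteOnClose is removed by the OS
// when the last handle on it is closed, so the temporary model copy cannot be left
// behind the way an explicit File.Delete inside Dispose could be if the process dies early.
string tempModelFile = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
File.WriteAllBytes(tempModelFile, new byte[] { 0x08, 0x07 });   // stand-in for model bytes

using (var stream = new FileStream(tempModelFile, FileMode.Open, FileAccess.Read,
    FileShare.Read, 4096, FileOptions.DeleteOnClose))
{
    // ... hand the file to ONNX Runtime / read it here ...
}

Console.WriteLine(File.Exists(tempModelFile));   // False: deleted when the stream closed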
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -143,7 +143,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
}

var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
@@ -630,9 +630,9 @@ public void Dispose()
}
}

internal static string GetTemporaryDirectory()
internal static string GetTemporaryDirectory(IHostEnvironment env)
{
string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
string tempDirectory = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
Directory.CreateDirectory(tempDirectory);
return tempDirectory;
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Vision/DnnRetrainTransform.cs
@@ -116,7 +116,7 @@ private static DnnRetrainTransformer Create(IHostEnvironment env, ModelLoadConte
null, false, addBatchDimensionInput, 1);
}

var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
var tempDirPath = Path.GetFullPath(Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, nameof(DnnRetrainTransformer) + "_" + Guid.NewGuid()));
CreateFolderWithAclIfNotExists(env, tempDirPath);
try
{