diff --git a/samples/Azure.Management.Storage/Generated/CodeModel.yaml b/samples/Azure.Management.Storage/Generated/CodeModel.yaml
index cd2b712b4a6..44db8d108aa 100644
--- a/samples/Azure.Management.Storage/Generated/CodeModel.yaml
+++ b/samples/Azure.Management.Storage/Generated/CodeModel.yaml
@@ -25496,6 +25496,7 @@ operationGroups:
status: InProgress
x-ms-long-running-operation: true
x-ms-long-running-operation-options:
+ enable-interim-state: true
final-state-via: location
language: !Languages
default:
diff --git a/samples/Azure.Management.Storage/Generated/LongRunningOperation/StorageAccountRestoreBlobRangesOperation.cs b/samples/Azure.Management.Storage/Generated/LongRunningOperation/StorageAccountRestoreBlobRangesOperation.cs
new file mode 100644
index 00000000000..a7245e5b319
--- /dev/null
+++ b/samples/Azure.Management.Storage/Generated/LongRunningOperation/StorageAccountRestoreBlobRangesOperation.cs
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//
+
+#nullable disable
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure;
+using Azure.Core;
+using Azure.Core.Pipeline;
+using Azure.Management.Storage.Models;
+using Azure.ResourceManager;
+
+namespace Azure.Management.Storage
+{
+ /// A class representing the specific long-running operation StorageAccountRestoreBlobRangesOperation.
+ public class StorageAccountRestoreBlobRangesOperation : ArmOperation
+ {
+ private readonly StorageArmOperation _operation;
+
+ private readonly IOperationSource _operationSource;
+
+ private readonly AsyncLockWithValue _stateLock;
+
+ private readonly Response _interimResponse;
+
+ /// Initializes a new instance of StorageAccountRestoreBlobRangesOperation for mocking.
+ protected StorageAccountRestoreBlobRangesOperation()
+ {
+ }
+
+ internal StorageAccountRestoreBlobRangesOperation(IOperationSource source, ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, Request request, Response response, OperationFinalStateVia finalStateVia)
+ {
+ _operation = new StorageArmOperation(source, clientDiagnostics, pipeline, request, response, finalStateVia);
+ _operationSource = source;
+ _stateLock = new AsyncLockWithValue();
+ _interimResponse = response;
+ }
+
+ ///
+#pragma warning disable CA1822
+ public override string Id => throw new NotImplementedException();
+#pragma warning restore CA1822
+
+ ///
+ public override BlobRestoreStatus Value => _operation.Value;
+
+ ///
+ public override bool HasValue => _operation.HasValue;
+
+ ///
+ public override bool HasCompleted => _operation.HasCompleted;
+
+ ///
+ public override Response GetRawResponse() => _operation.GetRawResponse();
+
+ ///
+ public override Response UpdateStatus(CancellationToken cancellationToken = default) => _operation.UpdateStatus(cancellationToken);
+
+ ///
+ public override ValueTask UpdateStatusAsync(CancellationToken cancellationToken = default) => _operation.UpdateStatusAsync(cancellationToken);
+
+ ///
+ public override Response WaitForCompletion(CancellationToken cancellationToken = default) => _operation.WaitForCompletion(cancellationToken);
+
+ ///
+ public override Response WaitForCompletion(TimeSpan pollingInterval, CancellationToken cancellationToken = default) => _operation.WaitForCompletion(pollingInterval, cancellationToken);
+
+ ///
+ public override ValueTask> WaitForCompletionAsync(CancellationToken cancellationToken = default) => _operation.WaitForCompletionAsync(cancellationToken);
+
+ ///
+ public override ValueTask> WaitForCompletionAsync(TimeSpan pollingInterval, CancellationToken cancellationToken = default) => _operation.WaitForCompletionAsync(pollingInterval, cancellationToken);
+
+ /// Gets interim status of the long-running operation.
+ /// The cancellation token to use.
+ /// The interim status of the long-running operation.
+ public virtual async ValueTask GetCurrentStatusAsync(CancellationToken cancellationToken = default) => await GetCurrentState(true, cancellationToken).ConfigureAwait(false);
+
+ /// Gets interim status of the long-running operation.
+ /// The cancellation token to use.
+ /// The interim status of the long-running operation.
+ public virtual BlobRestoreStatus GetCurrentStatus(CancellationToken cancellationToken = default) => GetCurrentState(false, cancellationToken).EnsureCompleted();
+
+ private async ValueTask GetCurrentState(bool async, CancellationToken cancellationToken)
+ {
+ using var asyncLock = await _stateLock.GetLockOrValueAsync(async, cancellationToken).ConfigureAwait(false);
+ if (asyncLock.HasValue)
+ {
+ return asyncLock.Value;
+ }
+ var val = async ? await _operationSource.CreateResultAsync(_interimResponse, cancellationToken).ConfigureAwait(false)
+ : _operationSource.CreateResult(_interimResponse, cancellationToken);
+ asyncLock.SetValue(val);
+ return val;
+ }
+ }
+}
diff --git a/samples/Azure.Management.Storage/Generated/StorageAccountResource.cs b/samples/Azure.Management.Storage/Generated/StorageAccountResource.cs
index 12ce20b8f40..39586640dfa 100644
--- a/samples/Azure.Management.Storage/Generated/StorageAccountResource.cs
+++ b/samples/Azure.Management.Storage/Generated/StorageAccountResource.cs
@@ -787,7 +787,7 @@ public virtual ArmOperation AbortHierarchicalNamespaceMigration(WaitUntil waitUn
/// The parameters to provide for restore blob ranges.
/// The cancellation token to use.
/// is null.
- public virtual async Task> RestoreBlobRangesAsync(WaitUntil waitUntil, BlobRestoreContent content, CancellationToken cancellationToken = default)
+ public virtual async Task RestoreBlobRangesAsync(WaitUntil waitUntil, BlobRestoreContent content, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(content, nameof(content));
@@ -796,7 +796,7 @@ public virtual async Task> RestoreBlobRangesAsyn
try
{
var response = await _storageAccountRestClient.RestoreBlobRangesAsync(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content, cancellationToken).ConfigureAwait(false);
- var operation = new StorageArmOperation(new BlobRestoreStatusOperationSource(), _storageAccountClientDiagnostics, Pipeline, _storageAccountRestClient.CreateRestoreBlobRangesRequest(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content).Request, response, OperationFinalStateVia.Location);
+ var operation = new StorageAccountRestoreBlobRangesOperation(new BlobRestoreStatusOperationSource(), _storageAccountClientDiagnostics, Pipeline, _storageAccountRestClient.CreateRestoreBlobRangesRequest(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content).Request, response, OperationFinalStateVia.Location);
if (waitUntil == WaitUntil.Completed)
await operation.WaitForCompletionAsync(cancellationToken).ConfigureAwait(false);
return operation;
@@ -817,7 +817,7 @@ public virtual async Task> RestoreBlobRangesAsyn
/// The parameters to provide for restore blob ranges.
/// The cancellation token to use.
/// is null.
- public virtual ArmOperation RestoreBlobRanges(WaitUntil waitUntil, BlobRestoreContent content, CancellationToken cancellationToken = default)
+ public virtual StorageAccountRestoreBlobRangesOperation RestoreBlobRanges(WaitUntil waitUntil, BlobRestoreContent content, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(content, nameof(content));
@@ -826,7 +826,7 @@ public virtual ArmOperation RestoreBlobRanges(WaitUntil waitU
try
{
var response = _storageAccountRestClient.RestoreBlobRanges(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content, cancellationToken);
- var operation = new StorageArmOperation(new BlobRestoreStatusOperationSource(), _storageAccountClientDiagnostics, Pipeline, _storageAccountRestClient.CreateRestoreBlobRangesRequest(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content).Request, response, OperationFinalStateVia.Location);
+ var operation = new StorageAccountRestoreBlobRangesOperation(new BlobRestoreStatusOperationSource(), _storageAccountClientDiagnostics, Pipeline, _storageAccountRestClient.CreateRestoreBlobRangesRequest(Id.SubscriptionId, Id.ResourceGroupName, Id.Name, content).Request, response, OperationFinalStateVia.Location);
if (waitUntil == WaitUntil.Completed)
operation.WaitForCompletion(cancellationToken);
return operation;
diff --git a/samples/Azure.Management.Storage/readme.md b/samples/Azure.Management.Storage/readme.md
index db1741b578b..4bf638a0d2b 100644
--- a/samples/Azure.Management.Storage/readme.md
+++ b/samples/Azure.Management.Storage/readme.md
@@ -57,4 +57,7 @@ directive:
- from: swagger-document
where: $.paths["/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Storage/storageAccounts/{accountName}/fileServices/default/shares"].get.parameters[4].type
transform: return "integer"
+ - from: swagger-document
+ where: $.paths["/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Storage/storageAccounts/{accountName}/restoreBlobRanges"].post
+ transform: $["x-ms-long-running-operation-options"]["enable-interim-state"] = true
```
diff --git a/src/AutoRest.CSharp/Common/Input/CodeModelPartials.cs b/src/AutoRest.CSharp/Common/Input/CodeModelPartials.cs
index 11aa6a5cccc..855a1141bea 100644
--- a/src/AutoRest.CSharp/Common/Input/CodeModelPartials.cs
+++ b/src/AutoRest.CSharp/Common/Input/CodeModelPartials.cs
@@ -35,6 +35,20 @@ public OperationFinalStateVia LongRunningFinalStateVia
}
}
+ // This is a new extension introduced by generator to control whether interim state returns are supported in lro.
+ public bool IsInterimLongRunningStateEnabled
+ {
+ get
+ {
+ var isInterimStatusEnabled = Extensions.GetValue>("x-ms-long-running-operation-options")?.GetValue("enable-interim-state");
+ return isInterimStatusEnabled switch
+ {
+ "true" => true,
+ _ => false,
+ };
+ }
+ }
+
public string? Accessibility => Extensions.GetValue("x-accessibility");
public ServiceResponse LongRunningInitialResponse
diff --git a/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtOutputLibrary.cs b/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtOutputLibrary.cs
index 1186dd371d9..41ca938f07a 100644
--- a/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtOutputLibrary.cs
+++ b/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtOutputLibrary.cs
@@ -130,6 +130,8 @@ private static void ApplyGlobalConfigurations()
public Dictionary CSharpTypeToOperationSource { get; } = new Dictionary();
public IEnumerable OperationSources => CSharpTypeToOperationSource.Values;
+ public ICollection InterimOperations { get; } = new List();
+
private IEnumerable UpdateBodyParameters()
{
Dictionary usageCounts = new Dictionary();
diff --git a/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtTarget.cs b/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtTarget.cs
index 0c92bdb9dcd..c6dd57aadd9 100644
--- a/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtTarget.cs
+++ b/src/AutoRest.CSharp/Mgmt/AutoRest/MgmtTarget.cs
@@ -146,6 +146,13 @@ public static async Task ExecuteAsync(GeneratedCodeWorkspace project, CodeModel
lroWriter.Write();
AddGeneratedFile(project, lroWriter.Filename, lroWriter.ToString());
+ foreach (var interimOperation in MgmtContext.Library.InterimOperations.Distinct(LongRunningInterimOperation.LongRunningInterimOperationComparer))
+ {
+ var writer = new MgmtLongRunningInterimOperationWriter(interimOperation);
+ writer.Write();
+ AddGeneratedFile(project, $"LongRunningOperation/{interimOperation.TypeName}.cs", writer.ToString());
+ }
+
foreach (var operationSource in MgmtContext.Library.OperationSources)
{
var writer = new OperationSourceWriter(operationSource);
diff --git a/src/AutoRest.CSharp/Mgmt/Generation/MgmtClientBaseWriter.cs b/src/AutoRest.CSharp/Mgmt/Generation/MgmtClientBaseWriter.cs
index cf20a3d2054..03b0f513de1 100644
--- a/src/AutoRest.CSharp/Mgmt/Generation/MgmtClientBaseWriter.cs
+++ b/src/AutoRest.CSharp/Mgmt/Generation/MgmtClientBaseWriter.cs
@@ -835,10 +835,17 @@ protected virtual void WriteLROMethodBranch(MgmtRestOperation operation, IEnumer
protected virtual void WriteLROResponse(string diagnosticsVariableName, string pipelineVariableName, MgmtRestOperation operation, IEnumerable parameterMapping, bool async)
{
- _writer.Append($"var operation = new {LibraryArmOperation}");
- if (operation.ReturnType.IsGenericType)
+ if (operation.InterimOperation is not null)
{
- _writer.Append($"<{operation.MgmtReturnType}>");
+ _writer.Append($"var operation = new {operation.InterimOperation.TypeName}");
+ }
+ else
+ {
+ _writer.Append($"var operation = new {LibraryArmOperation}");
+ if (operation.ReturnType.IsGenericType)
+ {
+ _writer.Append($"<{operation.MgmtReturnType}>");
+ }
}
_writer.Append($"(");
if (operation.IsFakeLongRunningOperation)
diff --git a/src/AutoRest.CSharp/Mgmt/Generation/MgmtLongRunningInterimOperationWriter.cs b/src/AutoRest.CSharp/Mgmt/Generation/MgmtLongRunningInterimOperationWriter.cs
new file mode 100644
index 00000000000..53c9ee7a3f6
--- /dev/null
+++ b/src/AutoRest.CSharp/Mgmt/Generation/MgmtLongRunningInterimOperationWriter.cs
@@ -0,0 +1,147 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+using System;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoRest.CSharp.Generation.Writers;
+using AutoRest.CSharp.Mgmt.AutoRest;
+using AutoRest.CSharp.Mgmt.Output;
+using AutoRest.CSharp.Output.Models.Shared;
+using AutoRest.CSharp.Output.Models.Types;
+using Azure;
+using Azure.Core;
+using Azure.Core.Pipeline;
+using Azure.ResourceManager;
+using Request = Azure.Core.Request;
+
+namespace AutoRest.CSharp.Mgmt.Generation
+{
+ internal class MgmtLongRunningInterimOperationWriter
+ {
+ private readonly CodeWriter _writer;
+ private readonly LongRunningInterimOperation _interimOperation;
+
+ public MgmtLongRunningInterimOperationWriter(LongRunningInterimOperation interimOperation)
+ {
+ _writer = new CodeWriter();
+ _interimOperation = interimOperation;
+ }
+
+ public void Write()
+ {
+ using (_writer.Namespace(MgmtContext.Context.DefaultNamespace))
+ {
+ _writer.WriteXmlDocumentationSummary($"A class representing the specific long-running operation {_interimOperation.TypeName}.");
+ _writer.Line($"public class {_interimOperation.TypeName} : {_interimOperation.BaseClassType}");
+ using (_writer.Scope())
+ {
+ _writer.Line($"private readonly {_interimOperation.OperationType} _operation;");
+ _writer.Line();
+
+ _writer.Line($"private readonly {_interimOperation.IOperationSourceType} _operationSource;");
+ _writer.Line();
+
+ _writer.Line($"private readonly {_interimOperation.StateLockType} _stateLock;");
+ _writer.Line();
+
+ _writer.Line($"private readonly {typeof(Response)} _interimResponse;");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationSummary($"Initializes a new instance of {_interimOperation.TypeName} for mocking.");
+ using (_writer.Scope($"protected {_interimOperation.TypeName}()"))
+ {
+ }
+ _writer.Line();
+
+ using (_writer.Scope($"internal {_interimOperation.TypeName}({_interimOperation.IOperationSourceType} source, {typeof(ClientDiagnostics)} clientDiagnostics, {typeof(HttpPipeline)} pipeline, {typeof(Request)} request, {typeof(Response)} response, {typeof(OperationFinalStateVia)} finalStateVia)"))
+ {
+ _writer.Line($"_operation = new {_interimOperation.OperationType}(source, clientDiagnostics, pipeline, request, response, finalStateVia);");
+ _writer.Line($"_operationSource = source;");
+ _writer.Line($"_stateLock = new {_interimOperation.StateLockType}();");
+ _writer.Line($"_interimResponse = response;");
+ }
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer
+ .LineRaw("#pragma warning disable CA1822")
+ .LineRaw("public override string Id => throw new NotImplementedException();")
+ .LineRaw("#pragma warning restore CA1822")
+ .Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {_interimOperation.ReturnType} Value => _operation.Value;");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override bool HasValue => _operation.HasValue;");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override bool HasCompleted => _operation.HasCompleted;");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {typeof(Response)} GetRawResponse() => _operation.GetRawResponse();");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {typeof(Response)} UpdateStatus({typeof(CancellationToken)} cancellationToken = default) => _operation.UpdateStatus(cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {typeof(ValueTask<>).MakeGenericType(typeof(Response))} UpdateStatusAsync({typeof(CancellationToken)} cancellationToken = default) => _operation.UpdateStatusAsync(cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {_interimOperation.ResponseType} WaitForCompletion({typeof(CancellationToken)} cancellationToken = default) => _operation.WaitForCompletion(cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {_interimOperation.ResponseType} WaitForCompletion({typeof(TimeSpan)} pollingInterval, {typeof(CancellationToken)} cancellationToken = default) => _operation.WaitForCompletion(pollingInterval, cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {typeof(ValueTask)}<{_interimOperation.ResponseType}> WaitForCompletionAsync({typeof(CancellationToken)} cancellationToken = default) => _operation.WaitForCompletionAsync(cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationInheritDoc();
+ _writer.Line($"public override {typeof(ValueTask)}<{_interimOperation.ResponseType}> WaitForCompletionAsync({typeof(TimeSpan)} pollingInterval, {typeof(CancellationToken)} cancellationToken = default) => _operation.WaitForCompletionAsync(pollingInterval, cancellationToken);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationSummary($"Gets interim status of the long-running operation.");
+ _writer.WriteXmlDocumentationParameter(KnownParameters.CancellationTokenParameter);
+ _writer.WriteXmlDocumentationReturns($"The interim status of the long-running operation.");
+ _writer.Line($"public virtual async {_interimOperation.ValueTaskType} GetCurrentStatusAsync({typeof(CancellationToken)} cancellationToken = default) => await GetCurrentState(true, cancellationToken).ConfigureAwait(false);");
+ _writer.Line();
+
+ _writer.WriteXmlDocumentationSummary($"Gets interim status of the long-running operation.");
+ _writer.WriteXmlDocumentationParameter(KnownParameters.CancellationTokenParameter);
+ _writer.WriteXmlDocumentationReturns($"The interim status of the long-running operation.");
+ _writer.Line($"public virtual {_interimOperation.ReturnType} GetCurrentStatus({typeof(CancellationToken)} cancellationToken = default) => GetCurrentState(false, cancellationToken).EnsureCompleted();");
+ _writer.Line();
+
+ using (_writer.Scope($"private async {_interimOperation.ValueTaskType} GetCurrentState({typeof(bool)} async, {typeof(CancellationToken)} cancellationToken)"))
+ {
+ _writer.Line($"using var asyncLock = await _stateLock.GetLockOrValueAsync(async, cancellationToken).ConfigureAwait(false);");
+ using (_writer.Scope($"if (asyncLock.HasValue)"))
+ {
+ _writer.Line($"return asyncLock.Value;");
+ }
+ _writer.Line($"var val = async ? await _operationSource.CreateResultAsync(_interimResponse, cancellationToken).ConfigureAwait(false)");
+ _writer.Line($"\t\t: _operationSource.CreateResult(_interimResponse, cancellationToken);");
+ _writer.Line($"asyncLock.SetValue(val);");
+ _writer.Line($"return val;");
+ }
+ }
+ }
+ }
+
+ public override string ToString()
+ {
+ return _writer.ToString();
+ }
+ }
+}
diff --git a/src/AutoRest.CSharp/Mgmt/Models/MgmtRestOperation.cs b/src/AutoRest.CSharp/Mgmt/Models/MgmtRestOperation.cs
index 59d555c7438..a8cca10944c 100644
--- a/src/AutoRest.CSharp/Mgmt/Models/MgmtRestOperation.cs
+++ b/src/AutoRest.CSharp/Mgmt/Models/MgmtRestOperation.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using AutoRest.CSharp.Generation.Types;
@@ -47,6 +48,8 @@ internal record MgmtRestOperation
public OperationSource? OperationSource { get; }
+ public LongRunningInterimOperation? InterimOperation { get; }
+
private Func? _returnsDescription;
public Func? ReturnsDescription => IsPagingOperation ? _returnsDescription ??= EnsureReturnsDescription() : null;
@@ -116,6 +119,7 @@ public MgmtRestOperation(Operation operation, RequestPath requestPath, RequestPa
FinalStateVia = operation.IsLongRunning ? operation.LongRunningFinalStateVia : null;
OriginalReturnType = operation.IsLongRunning ? GetFinalResponse() : Method.ReturnType;
OperationSource = GetOperationSource();
+ InterimOperation = GetInterimOperation();
}
public MgmtRestOperation(MgmtRestOperation other, string nameOverride, CSharpType? overrideReturnType, string overrideDescription, params Parameter[] overrideParameters)
@@ -137,6 +141,7 @@ public MgmtRestOperation(MgmtRestOperation other, string nameOverride, CSharpTyp
FinalStateVia = other.FinalStateVia;
OriginalReturnType = other.OriginalReturnType;
OperationSource = other.OperationSource;
+ InterimOperation = other.InterimOperation;
//modify some of the values
Name = nameOverride;
@@ -165,6 +170,25 @@ public MgmtRestOperation(MgmtRestOperation other, string nameOverride, CSharpTyp
return operationSource;
}
+ private LongRunningInterimOperation? GetInterimOperation()
+ {
+ if (!IsLongRunningOperation || IsFakeLongRunningOperation)
+ return null;
+
+ if (Operation.IsInterimLongRunningStateEnabled)
+ {
+ IEnumerable allSchemas = Operation.Responses.Select(r => r.ResponseSchema);
+ ImmutableHashSet schemas = allSchemas.ToImmutableHashSet();
+ if (MgmtReturnType is null || allSchemas.Count() != Operation.Responses.Count() || schemas.Count() != 1)
+ throw new NotSupportedException($"The interim state feature is only supported when all responses of the long running operation {Name} have the same shcema.");
+
+ var interimOperation = new LongRunningInterimOperation(MgmtReturnType, Resource, Name);
+ MgmtContext.Library.InterimOperations.Add(interimOperation);
+ return interimOperation;
+ }
+ return null;
+ }
+
private CSharpType? GetFinalResponse()
{
var finalSchema = Operation.LongRunningFinalResponse.ResponseSchema;
@@ -339,6 +363,10 @@ private CSharpType GetWrappedMgmtReturnType(CSharpType? originalType)
if (IsPagingOperation)
return originalType;
+ if (InterimOperation is not null)
+ return InterimOperation.InterimType;
+
+
return IsLongRunningOperation ? originalType.WrapOperation(false) : originalType.WrapResponse(false);
}
diff --git a/src/AutoRest.CSharp/Mgmt/Output/LongRunningInterimOperation.cs b/src/AutoRest.CSharp/Mgmt/Output/LongRunningInterimOperation.cs
new file mode 100644
index 00000000000..2b332bcb5f5
--- /dev/null
+++ b/src/AutoRest.CSharp/Mgmt/Output/LongRunningInterimOperation.cs
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Linq;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using AutoRest.CSharp.Input;
+using AutoRest.CSharp.Generation.Types;
+using AutoRest.CSharp.Mgmt.AutoRest;
+using Azure;
+using Azure.Core;
+using Azure.ResourceManager;
+
+namespace AutoRest.CSharp.Mgmt.Output
+{
+ internal class LongRunningInterimOperation
+ {
+ public LongRunningInterimOperation(CSharpType returnType, Resource? resource, string methodName)
+ {
+ ReturnType = returnType;
+ BaseClassType = new CSharpType(typeof(ArmOperation<>), returnType);
+ IOperationSourceType = new CSharpType(typeof(IOperationSource<>), returnType);
+ StateLockType = new CSharpType(typeof(AsyncLockWithValue<>), returnType);
+ ValueTaskType = new CSharpType(typeof(ValueTask<>), returnType);
+ ResponseType = new CSharpType(typeof(Response<>), returnType);
+ var trimmedNamespace = MgmtContext.Context.DefaultNamespace.Split('.').Last();
+ OperationType = $"{trimmedNamespace}ArmOperation<{returnType.Name}>";
+ var resourceName = resource != null ? resource.ResourceName : $"{trimmedNamespace}Extensions";
+ TypeName = $"{resourceName}{methodName}Operation";
+ var targetSchema = new ObjectSchema()
+ {
+ Language = new Languages()
+ {
+ Default = new Language()
+ {
+ Name = TypeName,
+ Namespace = MgmtContext.Context.DefaultNamespace
+ }
+ }
+ };
+ InterimType = new CSharpType(new MgmtObjectType(targetSchema), MgmtContext.Context.DefaultNamespace, TypeName);
+ }
+
+ public CSharpType ReturnType { get; }
+
+ public CSharpType BaseClassType { get; }
+
+ public CSharpType IOperationSourceType { get; }
+
+ public CSharpType StateLockType { get; }
+
+ public CSharpType ValueTaskType { get; }
+
+ public CSharpType ResponseType { get; }
+
+ public CSharpType InterimType { get; }
+
+ public string TypeName { get; }
+
+ public string OperationType { get; }
+
+ public static IEqualityComparer LongRunningInterimOperationComparer { get; } = new LongRunningInterimOperationComparerImplementation();
+
+ private class LongRunningInterimOperationComparerImplementation : IEqualityComparer
+ {
+ public bool Equals(LongRunningInterimOperation? x, LongRunningInterimOperation? y)
+ {
+ if (x is null || y is null)
+ {
+ return ReferenceEquals(x, y);
+ }
+
+ return x.TypeName == y.TypeName;
+ }
+
+ public int GetHashCode(LongRunningInterimOperation obj)
+ {
+
+ var hashCode = new HashCode();
+ hashCode.Add(obj.TypeName);
+
+ return hashCode.ToHashCode();
+ }
+ }
+ }
+}