Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move more code over to producer/consumer model #73331

Merged
merged 17 commits into from
May 3, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

Expand Down Expand Up @@ -76,7 +77,6 @@ internal abstract class AbstractSuppressionBatchFixAllProvider : FixAllProvider
{
var cancellationToken = fixAllContext.CancellationToken;
var fixAllState = fixAllContext.State;
var fixesBag = new ConcurrentBag<(Diagnostic diagnostic, CodeAction action)>();

using (Logger.LogBlock(
FunctionId.CodeFixes_FixAllOccurrencesComputation_Document_Fixes,
Expand All @@ -86,73 +86,61 @@ internal abstract class AbstractSuppressionBatchFixAllProvider : FixAllProvider
cancellationToken.ThrowIfCancellationRequested();
var progressTracker = fixAllContext.Progress;

using var _1 = ArrayBuilder<Task>.GetInstance(out var tasks);
using var _2 = ArrayBuilder<Document>.GetInstance(out var documentsToFix);

// Determine the set of documents to actually fix. We can also use this to update the progress bar with
// the amount of remaining work to perform. We'll update the progress bar as we compute each fix in
// AddDocumentFixesAsync.
foreach (var (document, diagnosticsToFix) in documentsAndDiagnosticsToFixMap)
{
if (!diagnosticsToFix.IsDefaultOrEmpty)
documentsToFix.Add(document);
}

progressTracker.AddItems(documentsToFix.Count);
var source = documentsAndDiagnosticsToFixMap.Where(kvp => !kvp.Value.IsDefaultOrEmpty).ToImmutableArray();
progressTracker.AddItems(source.Length);

foreach (var document in documentsToFix)
{
var diagnosticsToFix = documentsAndDiagnosticsToFixMap[document];
tasks.Add(AddDocumentFixesAsync(
document, diagnosticsToFix, fixesBag, fixAllState, progressTracker, cancellationToken));
}
using var _ = ArrayBuilder<(Diagnostic diagnostic, CodeAction action)>.GetInstance(out var results);
await ProducerConsumer<(Diagnostic diagnostic, CodeAction action)>.RunParallelAsync(
source,
produceItems: static async (tuple, callback, args, cancellationToken) =>
{
var (document, diagnosticsToFix) = tuple;
try
{
await [email protected](
document, diagnosticsToFix, callback, args.fixAllState, cancellationToken).ConfigureAwait(false);
}
finally
{
args.progressTracker.ItemCompleted();
}
},
consumeItems: static async (stream, args, cancellationToken) =>
{
await foreach (var tuple in stream)
args.results.Add(tuple);
},
args: (@this: this, fixAllState, progressTracker, results),
cancellationToken).ConfigureAwait(false);

await Task.WhenAll(tasks).ConfigureAwait(false);
}

return [.. fixesBag];
}

private async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
FixAllState fixAllState, IProgress<CodeAnalysisProgress> progressTracker, CancellationToken cancellationToken)
{
try
{
await this.AddDocumentFixesAsync(document, diagnostics, fixes, fixAllState, cancellationToken).ConfigureAwait(false);
}
finally
{
progressTracker.ItemCompleted();
return results.ToImmutableAndClear();
}
}

protected virtual async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
Debug.Assert(!diagnostics.IsDefault);
cancellationToken.ThrowIfCancellationRequested();

var registerCodeFix = GetRegisterCodeFixAction(fixAllState, fixes);

var fixerTasks = new List<Task>();
foreach (var diagnostic in diagnostics)
{
cancellationToken.ThrowIfCancellationRequested();
fixerTasks.Add(Task.Run(() =>
var registerCodeFix = GetRegisterCodeFixAction(fixAllState, onItemFound);
await RoslynParallel.ForEachAsync(
source: diagnostics,
cancellationToken,
async (diagnostic, cancellationToken) =>
{
var context = new CodeFixContext(document, diagnostic, registerCodeFix, cancellationToken);

// TODO: Wrap call to ComputeFixesAsync() below in IExtensionManager.PerformFunctionAsync() so that
// a buggy extension that throws can't bring down the host?
return fixAllState.Provider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
}, cancellationToken));
}

await Task.WhenAll(fixerTasks).ConfigureAwait(false);
var task = fixAllState.Provider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
await task.ConfigureAwait(false);
}).ConfigureAwait(false);
}

private async Task<CodeAction?> GetFixAsync(
Expand Down Expand Up @@ -198,7 +186,7 @@ protected virtual async Task AddDocumentFixesAsync(

private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFixAction(
FixAllState fixAllState,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> result)
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound)
{
return (action, diagnostics) =>
{
Expand All @@ -209,7 +197,7 @@ private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFix
if (currentAction is { EquivalenceKey: var equivalenceKey }
&& equivalenceKey == fixAllState.CodeActionEquivalenceKey)
{
result.Add((diagnostics.First(), currentAction));
onItemFound((diagnostics.First(), currentAction));
}

foreach (var nestedAction in currentAction.NestedActions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#nullable disable

using System;
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Linq;
Expand All @@ -26,7 +27,7 @@ internal sealed class PragmaWarningBatchFixAllProvider(AbstractSuppressionCodeFi

protected override async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
var pragmaActionsBuilder = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance();
Expand Down Expand Up @@ -59,7 +60,7 @@ protected override async Task AddDocumentFixesAsync(
pragmaDiagnosticsBuilder.ToImmutableAndFree(),
fixAllState, cancellationToken);

fixes.Add((diagnostic: null, pragmaBatchFix));
onItemFound((diagnostic: null, pragmaBatchFix));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ private sealed class RemoveSuppressionBatchFixAllProvider(AbstractSuppressionCod

protected override async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
// Batch all the pragma remove suppression fixes by executing them sequentially for the document.
var pragmaActionsBuilder = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance();
var pragmaDiagnosticsBuilder = ArrayBuilder<Diagnostic>.GetInstance();
using var _1 = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance(out var pragmaActionsBuilder);
using var _2 = ArrayBuilder<Diagnostic>.GetInstance(out var pragmaDiagnosticsBuilder);

foreach (var diagnostic in diagnostics.Where(d => d.Location.IsInSource && d.IsSuppressed))
{
Expand All @@ -62,7 +62,7 @@ protected override async Task AddDocumentFixesAsync(
}
else
{
fixes.Add((diagnostic, codeAction));
onItemFound((diagnostic, codeAction));
}
}
}
Expand All @@ -73,11 +73,11 @@ protected override async Task AddDocumentFixesAsync(
{
var pragmaBatchFix = PragmaBatchFixHelpers.CreateBatchPragmaFix(
_suppressionFixProvider, document,
pragmaActionsBuilder.ToImmutableAndFree(),
pragmaDiagnosticsBuilder.ToImmutableAndFree(),
pragmaActionsBuilder.ToImmutableAndClear(),
pragmaDiagnosticsBuilder.ToImmutableAndClear(),
fixAllState, cancellationToken);

fixes.Add((diagnostic: null, pragmaBatchFix));
onItemFound((diagnostic: null, pragmaBatchFix));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Microsoft.CodeAnalysis.Options;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

Expand Down Expand Up @@ -235,16 +236,26 @@ private static async Task<ImmutableArray<CompletionContext>> ComputeNonEmptyComp
SharedSyntaxContextsWithSpeculativeModel sharedContext,
CancellationToken cancellationToken)
{
var completionContextTasks = new List<Task<CompletionContext>>();
foreach (var provider in providers)
{
completionContextTasks.Add(GetContextAsync(
provider, document, caretPosition, trigger,
options, completionListSpan, sharedContext, cancellationToken));
}
using var _ = ArrayBuilder<CompletionContext>.GetInstance(out var results);

await ProducerConsumer<CompletionContext>.RunParallelAsync(
source: providers,
produceItems: static async (provider, callback, args, cancellationToken) =>
{
var context = await GetContextAsync(
provider, args.document, args.caretPosition, args.trigger, args.options, args.completionListSpan, args.sharedContext, cancellationToken).ConfigureAwait(false);
if (HasAnyItems(context))
callback(context);
},
consumeItems: static async (stream, args, cancellationToken) =>
{
await foreach (var result in stream)
args.results.Add(result);
},
args: (document, caretPosition, trigger, options, completionListSpan, sharedContext, results),
cancellationToken).ConfigureAwait(false);

var completionContexts = await Task.WhenAll(completionContextTasks).ConfigureAwait(false);
return completionContexts.Where(HasAnyItems).ToImmutableArray();
return results.ToImmutableAndClear();
}

private CompletionList MergeAndPruneCompletionLists(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.CodeAnalysis.Internal.Log;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Extensions.ContextQuery;
using Microsoft.CodeAnalysis.Shared.Utilities;
Expand Down Expand Up @@ -167,7 +168,11 @@ protected static bool TryFindFirstSymbolMatchesTargetTypes(
return index < symbolList.Length;
}

private static SupportedPlatformData? ComputeSupportedPlatformData(CompletionContext completionContext, ImmutableArray<SymbolAndSelectionInfo> symbols, Dictionary<ISymbol, List<ProjectId>>? invalidProjectMap, List<ProjectId>? totalProjects)
private static SupportedPlatformData? ComputeSupportedPlatformData(
CompletionContext completionContext,
ImmutableArray<SymbolAndSelectionInfo> symbols,
Dictionary<ISymbol, List<ProjectId>>? invalidProjectMap,
List<ProjectId>? totalProjects)
{
SupportedPlatformData? supportedPlatformData = null;
if (invalidProjectMap != null)
Expand Down Expand Up @@ -271,7 +276,19 @@ private async Task<ImmutableArray<CompletionItem>> GetItemsAsync(
return CreateItems(completionContext, itemsForCurrentDocument, _ => syntaxContext, invalidProjectMap: null, totalProjects: null);
}

var contextAndSymbolLists = await GetPerContextSymbolsAsync(completionContext, document, options, new[] { document.Id }.Concat(relatedDocumentIds), cancellationToken).ConfigureAwait(false);
using var _ = PooledDictionary<DocumentId, int>.GetInstance(out var documentIdToIndex);
documentIdToIndex.Add(document.Id, 0);
foreach (var documentId in relatedDocumentIds)
documentIdToIndex.Add(documentId, documentIdToIndex.Count);

var contextAndSymbolLists = await GetPerContextSymbolsAsync(completionContext, document, options, documentIdToIndex.Keys, cancellationToken).ConfigureAwait(false);

// We want the resultant contexts ordered in the same order the related documents came in. Importantly, the
// context for *our* starting document should be placed first.
contextAndSymbolLists = contextAndSymbolLists
.OrderBy((tuple1, tuple2) => documentIdToIndex[tuple1.documentId] - documentIdToIndex[tuple2.documentId])
.ToImmutableArray();

var symbolToContextMap = UnionSymbols(contextAndSymbolLists);
var missingSymbolsMap = FindSymbolsMissingInLinkedContexts(symbolToContextMap, contextAndSymbolLists);
var totalProjects = contextAndSymbolLists.Select(t => t.documentId.ProjectId).ToList();
Expand All @@ -297,10 +314,7 @@ private static Dictionary<SymbolAndSelectionInfo, TSyntaxContext> UnionSymbols(
// We need to use the SemanticModel any particular symbol came from in order to generate its description correctly.
// Therefore, when we add a symbol to set of union symbols, add a mapping from it to its SyntaxContext.
foreach (var symbol in symbols.GroupBy(s => new { s.Symbol.Name, s.Symbol.Kind }).Select(g => g.First()))
{
if (!result.ContainsKey(symbol))
result.Add(symbol, syntaxContext);
}
result.TryAdd(symbol, syntaxContext);
}

return result;
Expand All @@ -311,31 +325,30 @@ private static Dictionary<SymbolAndSelectionInfo, TSyntaxContext> UnionSymbols(
{
var solution = document.Project.Solution;

using var _1 = ArrayBuilder<Task<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>>.GetInstance(out var tasks);
using var _2 = ArrayBuilder<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.GetInstance(out var perContextSymbols);
using var _1 = ArrayBuilder<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.GetInstance(out var perContextSymbols);

foreach (var relatedDocumentId in relatedDocuments)
{
tasks.Add(Task.Run(async () =>
await ProducerConsumer<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.RunParallelAsync(
source: relatedDocuments,
produceItems: static async (relatedDocumentId, callback, args, cancellationToken) =>
{
var relatedDocument = solution.GetRequiredDocument(relatedDocumentId);
var syntaxContext = await completionContext.GetSyntaxContextWithExistingSpeculativeModelAsync(relatedDocument, cancellationToken).ConfigureAwait(false) as TSyntaxContext;
var relatedDocument = args.solution.GetRequiredDocument(relatedDocumentId);
var syntaxContext = await args.completionContext.GetSyntaxContextWithExistingSpeculativeModelAsync(
relatedDocument, cancellationToken).ConfigureAwait(false) as TSyntaxContext;

Contract.ThrowIfNull(syntaxContext);
var symbols = await TryGetSymbolsForContextAsync(completionContext, syntaxContext, options, cancellationToken).ConfigureAwait(false);

return (relatedDocument.Id, syntaxContext, symbols);
}, cancellationToken));
}
var symbols = await [email protected](
args.completionContext, syntaxContext, args.options, cancellationToken).ConfigureAwait(false);

await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var task in tasks)
{
var (relatedDocumentId, syntaxContext, symbols) = await task.ConfigureAwait(false);
if (!symbols.IsDefault)
perContextSymbols.Add((relatedDocumentId, syntaxContext, symbols));
}
if (!symbols.IsDefault)
callback((relatedDocument.Id, syntaxContext, symbols));
},
consumeItems: static async (results, args, cancellationToken) =>
{
await foreach (var tuple in results)
args.perContextSymbols.Add(tuple);
},
args: (@this: this, solution, completionContext, options, perContextSymbols),
cancellationToken).ConfigureAwait(false);

return perContextSymbols.ToImmutableAndClear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ public IWorkspaceService CreateService(HostWorkspaceServices workspaceServices)
}

var workQueue = new AsyncBatchingWorkQueue<Project>(
TimeSpan.FromSeconds(1),
_processBatchAsync,
_listenerProvider.GetListener(FeatureAttribute.CompletionSet),
_disposalToken);
TimeSpan.FromSeconds(1),
_processBatchAsync,
_listenerProvider.GetListener(FeatureAttribute.CompletionSet),
_disposalToken);

return new ImportCompletionCacheService(
_peItemsCache, _projectItemsCache, workQueue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ namespace Microsoft.CodeAnalysis.Completion.Providers;
[method: ImportingConstructor]
[method: Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
internal sealed class DefaultExtensionMethodImportCompletionCacheServiceFactory(IAsynchronousOperationListenerProvider listenerProvider)
: AbstractImportCompletionCacheServiceFactory<ExtensionMethodImportCompletionCacheEntry, object>(listenerProvider, ExtensionMethodImportCompletionHelper.BatchUpdateCacheAsync, CancellationToken.None)
: AbstractImportCompletionCacheServiceFactory<ExtensionMethodImportCompletionCacheEntry, object>(listenerProvider, ExtensionMethodImportCompletionHelper.BatchUpdateCacheAsync, CancellationToken.None)
{
}
Loading
Loading