Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move more code over to producer/consumer model #73331

Merged
merged 17 commits into from
May 3, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

Expand Down Expand Up @@ -76,7 +77,6 @@ internal abstract class AbstractSuppressionBatchFixAllProvider : FixAllProvider
{
var cancellationToken = fixAllContext.CancellationToken;
var fixAllState = fixAllContext.State;
var fixesBag = new ConcurrentBag<(Diagnostic diagnostic, CodeAction action)>();

using (Logger.LogBlock(
FunctionId.CodeFixes_FixAllOccurrencesComputation_Document_Fixes,
Expand All @@ -86,41 +86,41 @@ internal abstract class AbstractSuppressionBatchFixAllProvider : FixAllProvider
cancellationToken.ThrowIfCancellationRequested();
var progressTracker = fixAllContext.Progress;

using var _1 = ArrayBuilder<Task>.GetInstance(out var tasks);
using var _2 = ArrayBuilder<Document>.GetInstance(out var documentsToFix);

// Determine the set of documents to actually fix. We can also use this to update the progress bar with
// the amount of remaining work to perform. We'll update the progress bar as we compute each fix in
// AddDocumentFixesAsync.
foreach (var (document, diagnosticsToFix) in documentsAndDiagnosticsToFixMap)
{
if (!diagnosticsToFix.IsDefaultOrEmpty)
documentsToFix.Add(document);
}

progressTracker.AddItems(documentsToFix.Count);
var source = documentsAndDiagnosticsToFixMap.Where(kvp => !kvp.Value.IsDefaultOrEmpty).ToImmutableArray();
progressTracker.AddItems(source.Length);

foreach (var document in documentsToFix)
{
var diagnosticsToFix = documentsAndDiagnosticsToFixMap[document];
tasks.Add(AddDocumentFixesAsync(
document, diagnosticsToFix, fixesBag, fixAllState, progressTracker, cancellationToken));
}
using var _ = ArrayBuilder<(Diagnostic diagnostic, CodeAction action)>.GetInstance(out var results);
await ProducerConsumer<(Diagnostic diagnostic, CodeAction action)>.RunParallelAsync(
source,
produceItems: static async (tuple, callback, args, cancellationToken) =>
{
var (document, diagnosticsToFix) = tuple;
await [email protected](
document, diagnosticsToFix, callback, args.fixAllState, args.progressTracker, cancellationToken).ConfigureAwait(false);
},
consumeItems: static async (stream, args, cancellationToken) =>
{
await foreach (var tuple in stream)
args.results.Add(tuple);
},
args: (@this: this, fixAllState, progressTracker, results),
cancellationToken).ConfigureAwait(false);

await Task.WhenAll(tasks).ConfigureAwait(false);
return results.ToImmutableAndClear();
}

return [.. fixesBag];
}

private async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, IProgress<CodeAnalysisProgress> progressTracker, CancellationToken cancellationToken)
{
try
{
await this.AddDocumentFixesAsync(document, diagnostics, fixes, fixAllState, cancellationToken).ConfigureAwait(false);
await this.AddDocumentFixesAsync(document, diagnostics, onItemFound, fixAllState, cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -130,29 +130,25 @@ private async Task AddDocumentFixesAsync(

protected virtual async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
Debug.Assert(!diagnostics.IsDefault);
cancellationToken.ThrowIfCancellationRequested();

var registerCodeFix = GetRegisterCodeFixAction(fixAllState, fixes);

var fixerTasks = new List<Task>();
foreach (var diagnostic in diagnostics)
{
cancellationToken.ThrowIfCancellationRequested();
fixerTasks.Add(Task.Run(() =>
var registerCodeFix = GetRegisterCodeFixAction(fixAllState, onItemFound);
await RoslynParallel.ForEachAsync(
source: diagnostics,
cancellationToken,
async (diagnostic, cancellationToken) =>
{
var context = new CodeFixContext(document, diagnostic, registerCodeFix, cancellationToken);

// TODO: Wrap call to ComputeFixesAsync() below in IExtensionManager.PerformFunctionAsync() so that
// a buggy extension that throws can't bring down the host?
return fixAllState.Provider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
}, cancellationToken));
}

await Task.WhenAll(fixerTasks).ConfigureAwait(false);
var task = fixAllState.Provider.RegisterCodeFixesAsync(context) ?? Task.CompletedTask;
await task.ConfigureAwait(false);
}).ConfigureAwait(false);
}

private async Task<CodeAction?> GetFixAsync(
Expand Down Expand Up @@ -198,7 +194,7 @@ protected virtual async Task AddDocumentFixesAsync(

private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFixAction(
FixAllState fixAllState,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> result)
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound)
{
return (action, diagnostics) =>
{
Expand All @@ -209,7 +205,7 @@ private static Action<CodeAction, ImmutableArray<Diagnostic>> GetRegisterCodeFix
if (currentAction is { EquivalenceKey: var equivalenceKey }
&& equivalenceKey == fixAllState.CodeActionEquivalenceKey)
{
result.Add((diagnostics.First(), currentAction));
onItemFound((diagnostics.First(), currentAction));
}

foreach (var nestedAction in currentAction.NestedActions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#nullable disable

using System;
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Linq;
Expand All @@ -26,7 +27,7 @@ internal sealed class PragmaWarningBatchFixAllProvider(AbstractSuppressionCodeFi

protected override async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
var pragmaActionsBuilder = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance();
Expand Down Expand Up @@ -59,7 +60,7 @@ protected override async Task AddDocumentFixesAsync(
pragmaDiagnosticsBuilder.ToImmutableAndFree(),
fixAllState, cancellationToken);

fixes.Add((diagnostic: null, pragmaBatchFix));
onItemFound((diagnostic: null, pragmaBatchFix));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ private sealed class RemoveSuppressionBatchFixAllProvider(AbstractSuppressionCod

protected override async Task AddDocumentFixesAsync(
Document document, ImmutableArray<Diagnostic> diagnostics,
ConcurrentBag<(Diagnostic diagnostic, CodeAction action)> fixes,
Action<(Diagnostic diagnostic, CodeAction action)> onItemFound,
FixAllState fixAllState, CancellationToken cancellationToken)
{
// Batch all the pragma remove suppression fixes by executing them sequentially for the document.
var pragmaActionsBuilder = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance();
var pragmaDiagnosticsBuilder = ArrayBuilder<Diagnostic>.GetInstance();
using var _1 = ArrayBuilder<IPragmaBasedCodeAction>.GetInstance(out var pragmaActionsBuilder);
using var _2 = ArrayBuilder<Diagnostic>.GetInstance(out var pragmaDiagnosticsBuilder);

foreach (var diagnostic in diagnostics.Where(d => d.Location.IsInSource && d.IsSuppressed))
{
Expand All @@ -62,7 +62,7 @@ protected override async Task AddDocumentFixesAsync(
}
else
{
fixes.Add((diagnostic, codeAction));
onItemFound((diagnostic, codeAction));
}
}
}
Expand All @@ -73,11 +73,11 @@ protected override async Task AddDocumentFixesAsync(
{
var pragmaBatchFix = PragmaBatchFixHelpers.CreateBatchPragmaFix(
_suppressionFixProvider, document,
pragmaActionsBuilder.ToImmutableAndFree(),
pragmaDiagnosticsBuilder.ToImmutableAndFree(),
pragmaActionsBuilder.ToImmutableAndClear(),
pragmaDiagnosticsBuilder.ToImmutableAndClear(),
fixAllState, cancellationToken);

fixes.Add((diagnostic: null, pragmaBatchFix));
onItemFound((diagnostic: null, pragmaBatchFix));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ private async Task<ImmutableArray<CompletionItem>> GetItemsAsync(
CompletionOptions options,
CancellationToken cancellationToken)
{
var relatedDocumentIds = document.GetLinkedDocumentIds();
var relatedDocumentIds = document.Project.Solution.GetRelatedDocumentIds(document.Id);

if (relatedDocumentIds.IsEmpty)
if (relatedDocumentIds.Length == 1)
{
var itemsForCurrentDocument = await GetSymbolsAsync(completionContext, syntaxContext, position, options, cancellationToken).ConfigureAwait(false);
return CreateItems(completionContext, itemsForCurrentDocument, _ => syntaxContext, invalidProjectMap: null, totalProjects: null);
}

var contextAndSymbolLists = await GetPerContextSymbolsAsync(completionContext, document, options, new[] { document.Id }.Concat(relatedDocumentIds), cancellationToken).ConfigureAwait(false);
var contextAndSymbolLists = await GetPerContextSymbolsAsync(completionContext, document, options, relatedDocumentIds, cancellationToken).ConfigureAwait(false);
var symbolToContextMap = UnionSymbols(contextAndSymbolLists);
var missingSymbolsMap = FindSymbolsMissingInLinkedContexts(symbolToContextMap, contextAndSymbolLists);
var totalProjects = contextAndSymbolLists.Select(t => t.documentId.ProjectId).ToList();
Expand All @@ -297,45 +297,41 @@ private static Dictionary<SymbolAndSelectionInfo, TSyntaxContext> UnionSymbols(
// We need to use the SemanticModel any particular symbol came from in order to generate its description correctly.
// Therefore, when we add a symbol to set of union symbols, add a mapping from it to its SyntaxContext.
foreach (var symbol in symbols.GroupBy(s => new { s.Symbol.Name, s.Symbol.Kind }).Select(g => g.First()))
{
if (!result.ContainsKey(symbol))
result.Add(symbol, syntaxContext);
}
result.TryAdd(symbol, syntaxContext);
}

return result;
}

private async Task<ImmutableArray<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>> GetPerContextSymbolsAsync(
CompletionContext completionContext, Document document, CompletionOptions options, IEnumerable<DocumentId> relatedDocuments, CancellationToken cancellationToken)
CompletionContext completionContext, Document document, CompletionOptions options, ImmutableArray<DocumentId> relatedDocuments, CancellationToken cancellationToken)
{
var solution = document.Project.Solution;

using var _1 = ArrayBuilder<Task<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>>.GetInstance(out var tasks);
using var _2 = ArrayBuilder<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.GetInstance(out var perContextSymbols);
using var _ = ArrayBuilder<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.GetInstance(out var perContextSymbols);

foreach (var relatedDocumentId in relatedDocuments)
{
tasks.Add(Task.Run(async () =>
await ProducerConsumer<(DocumentId documentId, TSyntaxContext syntaxContext, ImmutableArray<SymbolAndSelectionInfo> symbols)>.RunParallelAsync(
source: relatedDocuments,
produceItems: static async (relatedDocumentId, callback, args, cancellationToken) =>
{
var relatedDocument = solution.GetRequiredDocument(relatedDocumentId);
var syntaxContext = await completionContext.GetSyntaxContextWithExistingSpeculativeModelAsync(relatedDocument, cancellationToken).ConfigureAwait(false) as TSyntaxContext;
var relatedDocument = args.solution.GetRequiredDocument(relatedDocumentId);
var syntaxContext = await args.completionContext.GetSyntaxContextWithExistingSpeculativeModelAsync(
relatedDocument, cancellationToken).ConfigureAwait(false) as TSyntaxContext;

Contract.ThrowIfNull(syntaxContext);
var symbols = await TryGetSymbolsForContextAsync(completionContext, syntaxContext, options, cancellationToken).ConfigureAwait(false);

return (relatedDocument.Id, syntaxContext, symbols);
}, cancellationToken));
}
var symbols = await [email protected](
args.completionContext, syntaxContext, args.options, cancellationToken).ConfigureAwait(false);

await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var task in tasks)
{
var (relatedDocumentId, syntaxContext, symbols) = await task.ConfigureAwait(false);
if (!symbols.IsDefault)
perContextSymbols.Add((relatedDocumentId, syntaxContext, symbols));
}
if (!symbols.IsDefault)
callback((relatedDocument.Id, syntaxContext, symbols));
},
consumeItems: static async (results, args, cancellationToken) =>
{
await foreach (var tuple in results)
args.perContextSymbols.Add(tuple);
},
args: (@this: this, solution, completionContext, options, perContextSymbols),
cancellationToken).ConfigureAwait(false);

return perContextSymbols.ToImmutableAndClear();
}
Expand Down
5 changes: 4 additions & 1 deletion src/Features/Lsif/Generator/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ public async Task GenerateForProjectAsync(
}
};

var documents = (await project.GetAllRegularAndSourceGeneratedDocumentsAsync(cancellationToken)).ToList();
var documents = new List<Document>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var documents = new List();

Was something wrong here? Was the loop below intended to be changed to use the new parallel methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no. i just needed to update this code because GetAllRegularAndSourceGeneratedDocumentsAsync now returns an IAsyncEnumerable.

await foreach (var document in project.GetAllRegularAndSourceGeneratedDocumentsAsync(cancellationToken))
documents.Add(document);

var tasks = new List<Task>();
foreach (var document in documents)
{
Expand Down
Loading
Loading