Skip to content

Commit

Permalink
Merge pull request #73324 from CyrusNajmabadi/taskWhenAll
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored May 3, 2024
2 parents 27a77e2 + 418382c commit 6128e20
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ public async Task SearchCachedDocumentsAsync(

Debug.Assert(priorityDocuments.All(d => projects.Contains(d.Project)));

var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound, cancellationToken);
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound);

var documentKeys = projects.SelectManyAsArray(p => p.Documents.Select(DocumentKey.ToDocumentKey));
var priorityDocumentKeys = priorityDocuments.SelectAsArray(DocumentKey.ToDocumentKey);

var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false);
if (client != null)
{
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted);
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted, cancellationToken);
await client.TryInvokeAsync<IRemoteNavigateToSearchService>(
(service, callbackId, cancellationToken) =>
service.SearchCachedDocumentsAsync(documentKeys, priorityDocumentKeys, searchPattern, [.. kinds], callbackId, cancellationToken),
Expand All @@ -101,7 +101,7 @@ public static async Task SearchCachedDocumentsInCurrentProcessAsync(
ImmutableArray<DocumentKey> priorityDocumentKeys,
string searchPattern,
IImmutableSet<string> kinds,
Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound,
Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> onItemsFound,
Func<Task> onProjectCompleted,
CancellationToken cancellationToken)
{
Expand All @@ -120,14 +120,15 @@ public static async Task SearchCachedDocumentsInCurrentProcessAsync(

// Sort the groups into a high pri group (projects that contain a high-pri doc), and low pri groups (those that
// don't), and process in that order.
await PerformParallelSearchAsync(
await ProducerConsumer<RoslynNavigateToItem>.RunParallelAsync(
Prioritize(groups, g => g.Any(priorityDocumentKeysSet.Contains)),
ProcessSingleProjectGroupAsync, onItemsFound, cancellationToken).ConfigureAwait(false);
ProcessSingleProjectGroupAsync, onItemsFound, args: default, cancellationToken).ConfigureAwait(false);
return;

async ValueTask ProcessSingleProjectGroupAsync(
async Task ProcessSingleProjectGroupAsync(
IGrouping<ProjectKey, DocumentKey> group,
Action<RoslynNavigateToItem> onItemFound,
VoidResult _,
CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ public async Task SearchGeneratedDocumentsAsync(
Contract.ThrowIfTrue(projects.IsEmpty);
Contract.ThrowIfTrue(projects.Select(p => p.Language).Distinct().Count() != 1);

var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound, cancellationToken);
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound);

var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false);
if (client != null)
{
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted);
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted, cancellationToken);

await client.TryInvokeAsync<IRemoteNavigateToSearchService>(
// Sync and search the full solution snapshot. While this function is called serially per project,
Expand All @@ -60,18 +60,19 @@ public static async Task SearchGeneratedDocumentsInCurrentProcessAsync(
ImmutableArray<Project> projects,
string pattern,
IImmutableSet<string> kinds,
Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound,
Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> onItemsFound,
Func<Task> onProjectCompleted,
CancellationToken cancellationToken)
{
var (patternName, patternContainerOpt) = PatternMatcher.GetNameAndContainer(pattern);
var declaredSymbolInfoKindsSet = new DeclaredSymbolInfoKindSet(kinds);

await PerformParallelSearchAsync(projects, ProcessSingleProjectAsync, onItemsFound, cancellationToken).ConfigureAwait(false);
await ProducerConsumer<RoslynNavigateToItem>.RunParallelAsync(
projects, ProcessSingleProjectAsync, onItemsFound, args: default, cancellationToken).ConfigureAwait(false);
return;

async ValueTask ProcessSingleProjectAsync(
Project project, Action<RoslynNavigateToItem> onItemFound, CancellationToken cancellationToken)
async Task ProcessSingleProjectAsync(
Project project, Action<RoslynNavigateToItem> onItemFound, VoidResult _, CancellationToken cancellationToken)
{
// First generate all the source-gen docs. Then handoff to the standard search routine to find matches in them.
var sourceGeneratedDocs = await project.GetSourceGeneratedDocumentsAsync(cancellationToken).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ public async Task SearchDocumentAsync(
CancellationToken cancellationToken)
{
var solution = document.Project.Solution;
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument: null, onResultsFound, cancellationToken);
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument: null, onResultsFound);

var client = await RemoteHostClient.TryGetClientAsync(document.Project, cancellationToken).ConfigureAwait(false);
if (client != null)
{
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted: null);
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted: null, cancellationToken);
// Don't need to sync the full solution when searching a single document. Just sync the project that doc is in.
await client.TryInvokeAsync<IRemoteNavigateToSearchService>(
document.Project,
Expand All @@ -45,7 +45,11 @@ await client.TryInvokeAsync<IRemoteNavigateToSearchService>(
}

public static async Task SearchDocumentInCurrentProcessAsync(
Document document, string searchPattern, IImmutableSet<string> kinds, Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound, CancellationToken cancellationToken)
Document document,
string searchPattern,
IImmutableSet<string> kinds,
Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> onItemsFound,
CancellationToken cancellationToken)
{
var (patternName, patternContainerOpt) = PatternMatcher.GetNameAndContainer(searchPattern);
var declaredSymbolInfoKindsSet = new DeclaredSymbolInfoKindSet(kinds);
Expand All @@ -55,7 +59,7 @@ await SearchSingleDocumentAsync(
document, patternName, patternContainerOpt, declaredSymbolInfoKindsSet, t => results.Add(t), cancellationToken).ConfigureAwait(false);

if (results.Count > 0)
await onItemsFound(results.ToImmutableArray()).ConfigureAwait(false);
await onItemsFound(results.ToImmutableArray(), default, cancellationToken).ConfigureAwait(false);
}

public async Task SearchProjectsAsync(
Expand All @@ -76,13 +80,13 @@ public async Task SearchProjectsAsync(
Contract.ThrowIfTrue(projects.Select(p => p.Language).Distinct().Count() != 1);

Debug.Assert(priorityDocuments.All(d => projects.Contains(d.Project)));
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound, cancellationToken);
var onItemsFound = GetOnItemsFoundCallback(solution, activeDocument, onResultsFound);

var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false);
if (client != null)
{
var priorityDocumentIds = priorityDocuments.SelectAsArray(d => d.Id);
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted);
var callback = new NavigateToSearchServiceCallback(onItemsFound, onProjectCompleted, cancellationToken);

await client.TryInvokeAsync<IRemoteNavigateToSearchService>(
// Intentionally sync the full solution. When SearchProjectAsync is called, we're searching all
Expand All @@ -105,7 +109,7 @@ public static async Task SearchProjectsInCurrentProcessAsync(
ImmutableArray<Document> priorityDocuments,
string searchPattern,
IImmutableSet<string> kinds,
Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound,
Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> onItemsFound,
Func<Task> onProjectCompleted,
CancellationToken cancellationToken)
{
Expand All @@ -120,17 +124,18 @@ public static async Task SearchProjectsInCurrentProcessAsync(

// Process each project on its own. That way we can tell the client when we are done searching it. Put the
// projects with priority documents ahead of those without so we can get results for those faster.
await PerformParallelSearchAsync(
await ProducerConsumer<RoslynNavigateToItem>.RunParallelAsync(
Prioritize(projects, highPriProjects.Contains),
SearchSingleProjectAsync, onItemsFound, cancellationToken).ConfigureAwait(false);
SearchSingleProjectAsync, onItemsFound, args: default, cancellationToken).ConfigureAwait(false);
return;

async ValueTask SearchSingleProjectAsync(
async Task SearchSingleProjectAsync(
Project project,
Action<RoslynNavigateToItem> onItemFound,
VoidResult _,
CancellationToken cancellationToken)
{
using var _ = GetPooledHashSet(priorityDocuments.Where(d => project == d.Project), out var highPriDocs);
using var _1 = GetPooledHashSet(priorityDocuments.Where(d => project == d.Project), out var highPriDocs);

await RoslynParallel.ForEachAsync(
Prioritize(project.Documents, highPriDocs.Contains),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ internal abstract partial class AbstractNavigateToSearchService : IAdvancedNavig

public bool CanFilter => true;

private static Func<ImmutableArray<RoslynNavigateToItem>, Task> GetOnItemsFoundCallback(
Solution solution, Document? activeDocument, Func<ImmutableArray<INavigateToSearchResult>, Task> onResultsFound, CancellationToken cancellationToken)
private static Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> GetOnItemsFoundCallback(
Solution solution, Document? activeDocument, Func<ImmutableArray<INavigateToSearchResult>, Task> onResultsFound)
{
return async items =>
return async (items, _, cancellationToken) =>
{
using var _ = ArrayBuilder<INavigateToSearchResult>.GetInstance(items.Length, out var results);
using var _1 = ArrayBuilder<INavigateToSearchResult>.GetInstance(items.Length, out var results);

foreach (var item in items)
{
Expand Down Expand Up @@ -81,22 +81,4 @@ private static IEnumerable<T> Prioritize<T>(IEnumerable<T> items, Func<T, bool>
foreach (var item in normalItems)
yield return item;
}

/// <summary>
/// Main utility for searching across items in a solution. The actual code to search the item should be provided in
/// <paramref name="callback"/>. Each item in <paramref name="items"/> will be processed using
/// <code>Parallel.ForEachAsync</code>, allowing for parallel processing of the items, with a preference towards
/// earlier items.
/// </summary>
private static Task PerformParallelSearchAsync<T>(
IEnumerable<T> items,
Func<T, Action<RoslynNavigateToItem>, CancellationToken, ValueTask> callback,
Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound,
CancellationToken cancellationToken)
=> ProducerConsumer<RoslynNavigateToItem>.RunParallelAsync(
source: items,
produceItems: static async (item, onItemFound, args, cancellationToken) => await args.callback(item, onItemFound, cancellationToken).ConfigureAwait(false),
consumeItems: static (items, args, cancellationToken) => args.onItemsFound(items),
args: (items, callback, onItemsFound),
cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Remote;
using Microsoft.CodeAnalysis.Storage;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.NavigateTo;

Expand Down Expand Up @@ -47,14 +48,15 @@ public ValueTask OnProjectCompletedAsync(RemoteServiceCallbackId callbackId)
}

internal sealed class NavigateToSearchServiceCallback(
Func<ImmutableArray<RoslynNavigateToItem>, Task> onItemsFound,
Func<Task>? onProjectCompleted)
Func<ImmutableArray<RoslynNavigateToItem>, VoidResult, CancellationToken, Task> onItemsFound,
Func<Task>? onProjectCompleted,
CancellationToken cancellationToken)
{
public async ValueTask OnItemsFoundAsync(ImmutableArray<RoslynNavigateToItem> items)
{
try
{
await onItemsFound(items).ConfigureAwait(false);
await onItemsFound(items, default, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (FatalError.ReportAndPropagateUnlessCanceled(ex))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixesAndRefactorings;
Expand Down Expand Up @@ -109,40 +110,35 @@ private static async Task<ImmutableDictionary<Document, ImmutableArray<Diagnosti
var cancellationToken = fixAllContext.CancellationToken;

using var _1 = progressTracker.ItemCompletedScope();
using var _2 = ArrayBuilder<Task<(DocumentId, (SyntaxNode? node, SourceText? text))>>.GetInstance(out var tasks);

var docIdToNewRootOrText = new Dictionary<DocumentId, (SyntaxNode? node, SourceText? text)>();
if (!diagnostics.IsEmpty)
{
// Then, process all documents in parallel to get the change for each doc.
foreach (var (document, documentDiagnostics) in diagnostics)
{
if (documentDiagnostics.IsDefaultOrEmpty)
continue;

tasks.Add(Task.Run(async () =>
await ProducerConsumer<(DocumentId, (SyntaxNode? node, SourceText? text))>.RunParallelAsync(
source: diagnostics.Where(kvp => !kvp.Value.IsDefaultOrEmpty),
produceItems: static async (kvp, callback, args, cancellationToken) =>
{
var newDocument = await this.FixAllAsync(fixAllContext, document, documentDiagnostics).ConfigureAwait(false);
var (document, documentDiagnostics) = kvp;

var newDocument = await args.@this.FixAllAsync(args.fixAllContext, document, documentDiagnostics).ConfigureAwait(false);
if (newDocument == null || newDocument == document)
return default;
return;

// For documents that support syntax, grab the tree so that we can clean it up later. If it's a
// language that doesn't support that, then just grab the text.
var node = newDocument.SupportsSyntaxTree ? await newDocument.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false) : null;
var text = newDocument.SupportsSyntaxTree ? null : await newDocument.GetValueTextAsync(cancellationToken).ConfigureAwait(false);

return (document.Id, (node, text));
}, cancellationToken));
}

await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var task in tasks)
{
var (docId, nodeOrText) = await task.ConfigureAwait(false);
if (docId != null)
docIdToNewRootOrText[docId] = nodeOrText;
}
callback((document.Id, (node, text)));
},
consumeItems: static async (results, args, cancellationToken) =>
{
await foreach (var (docId, nodeOrText) in results)
args.docIdToNewRootOrText[docId] = nodeOrText;
},
args: (@this: this, fixAllContext, docIdToNewRootOrText),
cancellationToken).ConfigureAwait(false);
}

return docIdToNewRootOrText;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeFixesAndRefactorings;
using Microsoft.CodeAnalysis.Internal.Log;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CodeFixes;
Expand Down Expand Up @@ -83,22 +84,23 @@ internal static async Task<ImmutableDictionary<Project, ImmutableArray<Diagnosti
case FixAllScope.Solution:
var projectsAndDiagnostics = ImmutableDictionary.CreateBuilder<Project, ImmutableArray<Diagnostic>>();

var tasks = project.Solution.Projects.Select(async p => new
{
Project = p,
Diagnostics = await fixAllContext.GetProjectDiagnosticsAsync(p).ConfigureAwait(false)
}).ToArray();

await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var task in tasks)
{
var projectAndDiagnostics = await task.ConfigureAwait(false);
if (projectAndDiagnostics.Diagnostics.Any())
await ProducerConsumer<(Project project, ImmutableArray<Diagnostic> diagnostics)>.RunParallelAsync(
source: project.Solution.Projects,
produceItems: static async (project, callback, args, cancellationToken) =>
{
var diagnostics = await args.fixAllContext.GetProjectDiagnosticsAsync(project).ConfigureAwait(false);
callback((project, diagnostics));
},
consumeItems: static async (results, args, cancellationToken) =>
{
projectsAndDiagnostics[projectAndDiagnostics.Project] = projectAndDiagnostics.Diagnostics;
}
}
await foreach (var (project, diagnostics) in results)
{
if (diagnostics.Any())
args.projectsAndDiagnostics.Add(project, diagnostics);
}
},
args: (fixAllContext, projectsAndDiagnostics),
fixAllContext.CancellationToken).ConfigureAwait(false);

return projectsAndDiagnostics.ToImmutable();
}
Expand Down
Loading

0 comments on commit 6128e20

Please sign in to comment.