Skip to content

Commit

Permalink
Improve MA0100 to detect all awaitable types, not only Task and Value…
Browse files Browse the repository at this point in the history
…Task (#507)
  • Loading branch information
meziantou authored Apr 26, 2023
1 parent df707da commit 0395ec6
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/Rules/MA0100.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# MA0100 - Await task before disposing of resources

The rule detects `Task`, `Task<T>`, `ValueTask`, `ValueTask<T>`, or any type that follows the awaitable pattern.

````csharp
using System;
using System.Threading.Tasks;
Expand Down
115 changes: 115 additions & 0 deletions src/Meziantou.Analyzer/Internals/AwaitableTypes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
using System;
using System.Collections.Generic;
using Microsoft.CodeAnalysis;

namespace Meziantou.Analyzer.Internals;
internal sealed class AwaitableTypes
{
private readonly INamedTypeSymbol[] _taskLikeSymbols;

public AwaitableTypes(Compilation compilation)
{
INotifyCompletionSymbol = compilation.GetBestTypeByMetadataName("System.Runtime.CompilerServices.INotifyCompletion");

if (INotifyCompletionSymbol != null)
{
var taskLikeSymbols = new List<INamedTypeSymbol>(4);
taskLikeSymbols.AddIfNotNull(compilation.GetBestTypeByMetadataName("System.Threading.Tasks.Task"));
taskLikeSymbols.AddIfNotNull(compilation.GetBestTypeByMetadataName("System.Threading.Tasks.Task`1"));
taskLikeSymbols.AddIfNotNull(compilation.GetBestTypeByMetadataName("System.Threading.Tasks.ValueTask"));
taskLikeSymbols.AddIfNotNull(compilation.GetBestTypeByMetadataName("System.Threading.Tasks.ValueTask`1"));
_taskLikeSymbols = taskLikeSymbols.ToArray();
}
else
{
_taskLikeSymbols = Array.Empty<INamedTypeSymbol>();
}
}

private INamedTypeSymbol? INotifyCompletionSymbol { get; }

// https://github.com/dotnet/roslyn/blob/248e85149427c534c4a156a436ecff69bab83b59/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs#L347
public bool IsAwaitable(ITypeSymbol? symbol, SemanticModel semanticModel, int position)
{
if (symbol == null)
return false;

if (INotifyCompletionSymbol == null)
return false;

if (symbol.SpecialType is SpecialType.System_Void || symbol.TypeKind is TypeKind.Dynamic)
return false;

if (IsTaskLike(symbol))
return true;

foreach (var potentialSymbol in semanticModel.LookupSymbols(position, container: symbol, name: "GetAwaiter", includeReducedExtensionMethods: true))
{
if (potentialSymbol is not IMethodSymbol getAwaiterMethod)
continue;

if (!semanticModel.IsAccessible(position, getAwaiterMethod))
continue;

if (!getAwaiterMethod.Parameters.IsEmpty)
continue;

if (!ConformsToAwaiterPattern(getAwaiterMethod.ReturnType))
continue;

return true;
}

return false;
}

private bool IsTaskLike(ITypeSymbol? symbol)
{
if (symbol is null)
return false;

var originalDefinition = symbol.OriginalDefinition;
foreach (var taskLikeSymbol in _taskLikeSymbols)
{
if (originalDefinition.IsEqualTo(taskLikeSymbol))
return true;
}

return false;
}

private bool ConformsToAwaiterPattern(ITypeSymbol typeSymbol)
{
if (typeSymbol is null)
return false;

var hasGetResultMethod = false;
var hasIsCompletedProperty = false;

if (!typeSymbol.Implements(INotifyCompletionSymbol))
return false;

foreach (var member in typeSymbol.GetMembers())
{
if (member is IMethodSymbol { Name: "GetResult", Parameters.IsEmpty: true, TypeParameters.IsEmpty: true, IsStatic: false })
{
hasGetResultMethod = true;
}
else if (member is IPropertySymbol { Name: "IsCompleted", IsStatic: false, Type.SpecialType: SpecialType.System_Boolean, GetMethod: not null })
{
hasIsCompletedProperty = true;
}
else
{
continue;
}

if (hasGetResultMethod && hasIsCompletedProperty)
{
return true;
}
}

return false;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.Immutable;
using System.Threading.Tasks;
using Meziantou.Analyzer.Internals;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
Expand Down Expand Up @@ -36,21 +36,16 @@ public override void Initialize(AnalysisContext context)

private sealed class AnalyzerContext
{
private readonly INamedTypeSymbol[] _taskLikeSymbols;
private readonly AwaitableTypes _awaitableTypes;

public AnalyzerContext(Compilation compilation)
{
_awaitableTypes = new AwaitableTypes(compilation);

TaskSymbol = compilation.GetBestTypeByMetadataName("System.Threading.Tasks.Task");
TaskOfTSymbol = compilation.GetBestTypeByMetadataName("System.Threading.Tasks.Task`1");
ValueTaskSymbol = compilation.GetBestTypeByMetadataName("System.Threading.Tasks.ValueTask");
ValueTaskOfTSymbol = compilation.GetBestTypeByMetadataName("System.Threading.Tasks.ValueTask`1");

var taskLikeSymbols = new List<INamedTypeSymbol>(4);
taskLikeSymbols.AddIfNotNull(TaskSymbol);
taskLikeSymbols.AddIfNotNull(TaskOfTSymbol);
taskLikeSymbols.AddIfNotNull(ValueTaskSymbol);
taskLikeSymbols.AddIfNotNull(ValueTaskOfTSymbol);
_taskLikeSymbols = taskLikeSymbols.ToArray();
}

public INamedTypeSymbol? TaskSymbol { get; set; }
Expand All @@ -65,17 +60,18 @@ public void AnalyzeReturn(OperationAnalysisContext context)
if (returnedValue is null)
return;

if (IsTaskLike(returnedValue.Type))
{
// Must be in a using block
if (!IsInUsingOperation(op))
return;
var returnType = returnedValue.UnwrapImplicitConversionOperations().Type;
if (!_awaitableTypes.IsAwaitable(returnType, returnedValue.SemanticModel!, returnedValue.Syntax.GetLocation().SourceSpan.End))
return;

if (!NeedAwait(returnedValue))
return;
// Must be in a using block
if (!IsInUsingOperation(op))
return;

context.ReportDiagnostic(s_rule, op);
}
if (!NeedAwait(returnedValue))
return;

context.ReportDiagnostic(s_rule, op);
}

private static bool IsInUsingOperation(IOperation operation)
Expand All @@ -92,21 +88,6 @@ private static bool IsInUsingOperation(IOperation operation)
return false;
}

private bool IsTaskLike(ITypeSymbol? symbol)
{
if (symbol is null)
return false;

var originalDefinition = symbol.OriginalDefinition;
foreach (var taskLikeSymbol in _taskLikeSymbols)
{
if (originalDefinition.IsEqualTo(taskLikeSymbol))
return true;
}

return false;
}

private bool NeedAwait(IOperation operation)
{
while (operation is IConversionOperation conversion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,79 @@ await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task NotAwaitedTaskYieldMethod_InUsing()
{
var originalCode = @"
using System;
using System.Threading.Tasks;
class TestClass
{
object Test()
{
using ((IDisposable)null)
{
// Custom awaitable type (not Task/ValueTask)
[||]return Task.Yield();
}
}
}";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task NotAwaitedExtensionMethodOnInt32_InUsing()
{
var originalCode = @"
using System;
using System.Threading.Tasks;
static class TestClass
{
static object Test()
{
using ((IDisposable)null)
{
// It should detect the extension method
[||]return 1;
}
}
static System.Runtime.CompilerServices.TaskAwaiter GetAwaiter(this int value) => throw null;
}";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task NotAwaitedExtensionMethodOnValueTuple_InUsing()
{
var originalCode = @"
using System;
using System.Threading.Tasks;
static class TestClass
{
static object Test()
{
using ((IDisposable)null)
{
// It should detect the extension method
[||]return (default(Task<int>), default(Task<string>));
}
}
static System.Runtime.CompilerServices.TaskAwaiter<(T1, T2)> GetAwaiter<T1, T2>(this (Task<T1>, Task<T2>) tasks) => throw null;
}";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task NotAwaitedValueTask_InUsing()
Expand Down

0 comments on commit 0395ec6

Please sign in to comment.