Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emit CA2251 when comparing string.Compare with 0 using Equals #6727

Merged
merged 4 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)

var root = await document.GetRequiredSyntaxRootAsync(token).ConfigureAwait(false);
var node = root.FindNode(context.Span, getInnermostNodeForTie: true);
if (semanticModel.GetOperation(node, token) is not IBinaryOperation violation)
var violation = semanticModel.GetOperation(node, token);
if (violation is not (IBinaryOperation or IInvocationOperation))
return;

// Get the replacer that applies to the reported violation.
Expand All @@ -50,9 +51,9 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)

// Local functions

async Task<Document> CreateChangedDocument(CancellationToken token)
async Task<Document> CreateChangedDocument(CancellationToken cancellationToken)
{
var editor = await DocumentEditor.CreateAsync(document, token).ConfigureAwait(false);
var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
var replacementNode = replacer.CreateReplacementExpression(violation, editor.Generator);
editor.ReplaceNode(violation.Syntax, replacementNode);

Expand Down Expand Up @@ -87,38 +88,55 @@ protected OperationReplacer(RequiredSymbols symbols)
/// </summary>
/// <param name="violation">The <see cref="IBinaryOperation"/> at the location reported by the analyzer.</param>
/// <returns>True if the current <see cref="OperationReplacer"/> applies to the specified violation.</returns>
public abstract bool IsMatch(IBinaryOperation violation);
public abstract bool IsMatch(IOperation violation);

/// <summary>
/// Creates a replacement node for a violation that the current <see cref="OperationReplacer"/> applies to.
/// Asserts if the current <see cref="OperationReplacer"/> does not apply to the specified violation.
/// </summary>
/// <param name="violation">The <see cref="IBinaryOperation"/> obtained at the location reported by the analyzer.
/// <see cref="IsMatch(IBinaryOperation)"/> must return <see langword="true"/> for this operation.</param>
/// <see cref="IsMatch(IOperation)"/> must return <see langword="true"/> for this operation.</param>
/// <param name="generator"></param>
/// <returns></returns>
public abstract SyntaxNode CreateReplacementExpression(IBinaryOperation violation, SyntaxGenerator generator);
public abstract SyntaxNode CreateReplacementExpression(IOperation violation, SyntaxGenerator generator);

protected SyntaxNode CreateEqualsMemberAccess(SyntaxGenerator generator)
{
var stringTypeExpression = generator.TypeExpressionForStaticMemberAccess(Symbols.StringType);
return generator.MemberAccessExpression(stringTypeExpression, nameof(string.Equals));
}

protected static IInvocationOperation GetInvocation(IBinaryOperation violation)
protected IInvocationOperation GetInvocation(IOperation violation)
{
var result = UseStringEqualsOverStringCompare.GetInvocationFromEqualityCheckWithLiteralZero(violation);
var result = violation switch
{
IBinaryOperation b => UseStringEqualsOverStringCompare.GetInvocationFromEqualityCheckWithLiteralZero(b),
IInvocationOperation i => UseStringEqualsOverStringCompare.GetInvocationFromEqualsCheckWithLiteralZero(i, Symbols.IntEquals),
_ => throw new NotSupportedException()
};

RoslynDebug.Assert(result is not null);

return result;
}

protected static SyntaxNode InvertIfNotEquals(SyntaxNode stringEqualsInvocationExpression, IBinaryOperation equalsOrNotEqualsOperation, SyntaxGenerator generator)
protected static SyntaxNode InvertIfNotEquals(SyntaxNode stringEqualsInvocationExpression, IOperation equalsOrNotEqualsOperation, SyntaxGenerator generator)
{
return equalsOrNotEqualsOperation.OperatorKind is BinaryOperatorKind.NotEquals ?
generator.LogicalNotExpression(stringEqualsInvocationExpression) :
stringEqualsInvocationExpression;
if (equalsOrNotEqualsOperation is IBinaryOperation b)
{
return b.OperatorKind is BinaryOperatorKind.NotEquals
? generator.LogicalNotExpression(stringEqualsInvocationExpression)
: stringEqualsInvocationExpression;
}

if (equalsOrNotEqualsOperation is IInvocationOperation i)
{
return i.Instance?.Parent is IUnaryOperation { OperatorKind: UnaryOperatorKind.Not }
? generator.LogicalNotExpression(stringEqualsInvocationExpression)
: stringEqualsInvocationExpression;
}

throw new NotSupportedException();
}
}

Expand All @@ -131,9 +149,9 @@ public StringStringCaseReplacer(RequiredSymbols symbols)
: base(symbols)
{ }

public override bool IsMatch(IBinaryOperation violation) => UseStringEqualsOverStringCompare.IsStringStringCase(violation, Symbols);
public override bool IsMatch(IOperation violation) => UseStringEqualsOverStringCompare.IsStringStringCase(violation, Symbols);

public override SyntaxNode CreateReplacementExpression(IBinaryOperation violation, SyntaxGenerator generator)
public override SyntaxNode CreateReplacementExpression(IOperation violation, SyntaxGenerator generator)
{
RoslynDebug.Assert(IsMatch(violation));

Expand All @@ -155,9 +173,9 @@ public StringStringBoolReplacer(RequiredSymbols symbols)
: base(symbols)
{ }

public override bool IsMatch(IBinaryOperation violation) => UseStringEqualsOverStringCompare.IsStringStringBoolCase(violation, Symbols);
public override bool IsMatch(IOperation violation) => UseStringEqualsOverStringCompare.IsStringStringBoolCase(violation, Symbols);

public override SyntaxNode CreateReplacementExpression(IBinaryOperation violation, SyntaxGenerator generator)
public override SyntaxNode CreateReplacementExpression(IOperation violation, SyntaxGenerator generator)
{
RoslynDebug.Assert(IsMatch(violation));

Expand Down Expand Up @@ -197,9 +215,9 @@ public StringStringStringComparisonReplacer(RequiredSymbols symbols)
: base(symbols)
{ }

public override bool IsMatch(IBinaryOperation violation) => UseStringEqualsOverStringCompare.IsStringStringStringComparisonCase(violation, Symbols);
public override bool IsMatch(IOperation violation) => UseStringEqualsOverStringCompare.IsStringStringStringComparisonCase(violation, Symbols);

public override SyntaxNode CreateReplacementExpression(IBinaryOperation violation, SyntaxGenerator generator)
public override SyntaxNode CreateReplacementExpression(IOperation violation, SyntaxGenerator generator)
{
RoslynDebug.Assert(IsMatch(violation));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,18 @@ private static void OnCompilationStart(CompilationStartAnalysisContext context)
{
if (!RequiredSymbols.TryGetSymbols(context.Compilation, out var symbols))
return;
context.RegisterOperationAction(AnalyzeOperation, OperationKind.Binary);
context.RegisterOperationAction(AnalyzeOperation, OperationKind.Binary, OperationKind.Invocation);
return;

// Local functions

void AnalyzeOperation(OperationAnalysisContext context)
{
var operation = (IBinaryOperation)context.Operation;
foreach (var selector in CaseSelectors)
{
if (selector(operation, symbols))
if (selector(context.Operation, symbols))
{
context.ReportDiagnostic(operation.CreateDiagnostic(Rule));
context.ReportDiagnostic(context.Operation.CreateDiagnostic(Rule));
return;
}
}
Expand All @@ -81,7 +80,8 @@ private RequiredSymbols(
IMethodSymbol? compareStringStringBool,
IMethodSymbol? compareStringStringStringComparison,
IMethodSymbol? equalsStringString,
IMethodSymbol? equalsStringStringStringComparison)
IMethodSymbol? equalsStringStringStringComparison,
IMethodSymbol intEquals)
{
StringType = stringType;
BoolType = boolType;
Expand All @@ -91,6 +91,7 @@ private RequiredSymbols(
CompareStringStringStringComparison = compareStringStringStringComparison;
EqualsStringString = equalsStringString;
EqualsStringStringStringComparison = equalsStringStringStringComparison;
IntEquals = intEquals;
}

public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] out RequiredSymbols? symbols)
Expand All @@ -103,7 +104,8 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
if (stringType is null || boolType is null)
return false;

if (!compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemStringComparison, out var stringComparisonType))
var typeProvider = WellKnownTypeProvider.GetOrCreate(compilation);
if (!typeProvider.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemStringComparison, out var stringComparisonType))
return false;

var compareMethods = stringType.GetMembers(nameof(string.Compare))
Expand All @@ -118,6 +120,15 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
.Where(x => x.IsStatic);
var equalsStringString = equalsMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType);
var equalsStringStringStringComparison = equalsMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType, stringComparisonType);
var intType = typeProvider.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemInt32);
var intEquals = intType
?.GetMembers(nameof(int.Equals))
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.GetParameters() is [var param] && param.Type.Equals(intType, SymbolEqualityComparer.Default));
if (intEquals is null)
{
return false;
}

// Bail if we do not have at least one complete pair of Compare-Equals methods in the compilation.
if ((compareStringString is null || equalsStringString is null) &&
Expand All @@ -130,7 +141,7 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
symbols = new RequiredSymbols(
stringType, boolType, stringComparisonType,
compareStringString, compareStringStringBool, compareStringStringStringComparison,
equalsStringString, equalsStringStringStringComparison);
equalsStringString, equalsStringStringStringComparison, intEquals);
return true;
}

Expand All @@ -142,6 +153,7 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
public IMethodSymbol? CompareStringStringStringComparison { get; }
public IMethodSymbol? EqualsStringString { get; }
public IMethodSymbol? EqualsStringStringStringComparison { get; }
public IMethodSymbol IntEquals { get; }
}

/// <summary>
Expand All @@ -156,9 +168,9 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
/// </summary>
/// <param name="binaryOperation"></param>
/// <returns></returns>
internal static IInvocationOperation? GetInvocationFromEqualityCheckWithLiteralZero(IBinaryOperation binaryOperation)
internal static IInvocationOperation? GetInvocationFromEqualityCheckWithLiteralZero(IBinaryOperation? binaryOperation)
{
if (binaryOperation.OperatorKind is not (BinaryOperatorKind.Equals or BinaryOperatorKind.NotEquals))
if (binaryOperation?.OperatorKind is not (BinaryOperatorKind.Equals or BinaryOperatorKind.NotEquals))
return default;

if (IsLiteralZero(binaryOperation.LeftOperand))
Expand All @@ -176,6 +188,21 @@ static bool IsLiteralZero(IOperation? operation)
}
}

internal static IInvocationOperation? GetInvocationFromEqualsCheckWithLiteralZero(IInvocationOperation? invocation, IMethodSymbol int32Equals)
{
if (!int32Equals.Equals(invocation?.TargetMethod.OriginalDefinition, SymbolEqualityComparer.Default))
{
return default;
}

if (invocation!.Arguments.FirstOrDefault()?.Value is ILiteralOperation { ConstantValue.Value: 0 })
{
return invocation.Instance as IInvocationOperation;
}

return default;
}

/// <summary>
/// Returns true if the specified <see cref="IBinaryOperation"/>:
/// <list type="bullet">
Expand All @@ -184,20 +211,21 @@ static bool IsLiteralZero(IOperation? operation)
/// <item>The other operand is any invocation of <see cref="string.Compare(string, string)"/></item>
/// </list>
/// </summary>
/// <param name="binaryOperation"></param>
/// <param name="operation"></param>
/// <param name="symbols"></param>
/// <returns></returns>
internal static bool IsStringStringCase(IBinaryOperation binaryOperation, RequiredSymbols symbols)
internal static bool IsStringStringCase(IOperation operation, RequiredSymbols symbols)
{
// Don't report a diagnostic if either the string.Compare overload or the
// corrasponding string.Equals overload is missing.
// corresponding string.Equals overload is missing.
if (symbols.CompareStringString is null ||
symbols.EqualsStringString is null)
{
return false;
}

var invocation = GetInvocationFromEqualityCheckWithLiteralZero(binaryOperation);
var invocation = GetInvocationFromEqualityCheckWithLiteralZero(operation as IBinaryOperation)
?? GetInvocationFromEqualsCheckWithLiteralZero(operation as IInvocationOperation, symbols.IntEquals);

return invocation is not null &&
invocation.TargetMethod.Equals(symbols.CompareStringString, SymbolEqualityComparer.Default);
Expand All @@ -212,20 +240,21 @@ internal static bool IsStringStringCase(IBinaryOperation binaryOperation, Requir
/// <item>The <c>ignoreCase</c> argument is a boolean literal</item>
/// </list>
/// </summary>
/// <param name="binaryOperation"></param>
/// <param name="operation"></param>
/// <param name="symbols"></param>
/// <returns></returns>
internal static bool IsStringStringBoolCase(IBinaryOperation binaryOperation, RequiredSymbols symbols)
internal static bool IsStringStringBoolCase(IOperation operation, RequiredSymbols symbols)
{
// Don't report a diagnostic if either the string.Compare overload or the
// corrasponding string.Equals overload is missing.
// corresponding string.Equals overload is missing.
if (symbols.CompareStringStringBool is null ||
symbols.EqualsStringStringStringComparison is null)
{
return false;
}

var invocation = GetInvocationFromEqualityCheckWithLiteralZero(binaryOperation);
var invocation = GetInvocationFromEqualityCheckWithLiteralZero(operation as IBinaryOperation)
?? GetInvocationFromEqualsCheckWithLiteralZero(operation as IInvocationOperation, symbols.IntEquals);

// Only report a diagnostic if the 'ignoreCase' argument is a boolean literal.
return invocation is not null &&
Expand All @@ -242,10 +271,10 @@ internal static bool IsStringStringBoolCase(IBinaryOperation binaryOperation, Re
/// <item>The other operand is any invocation of <see cref="string.Compare(string, string, StringComparison)"/></item>
/// </list>
/// </summary>
/// <param name="binaryOperation"></param>
/// <param name="operation"></param>
/// <param name="symbols"></param>
/// <returns></returns>
internal static bool IsStringStringStringComparisonCase(IBinaryOperation binaryOperation, RequiredSymbols symbols)
internal static bool IsStringStringStringComparisonCase(IOperation operation, RequiredSymbols symbols)
{
// Don't report a diagnostic if either the string.Compare overload or the
// corrasponding string.Equals overload is missing.
Expand All @@ -255,17 +284,18 @@ internal static bool IsStringStringStringComparisonCase(IBinaryOperation binaryO
return false;
}

var invocation = GetInvocationFromEqualityCheckWithLiteralZero(binaryOperation);
var invocation = GetInvocationFromEqualityCheckWithLiteralZero(operation as IBinaryOperation)
?? GetInvocationFromEqualsCheckWithLiteralZero(operation as IInvocationOperation, symbols.IntEquals);

return invocation is not null &&
invocation.TargetMethod.Equals(symbols.CompareStringStringStringComparison, SymbolEqualityComparer.Default);
}

// No IOperation instances are being stored here.
#pragma warning disable RS1008 // Avoid storing per-compilation data into the fields of a diagnostic analyzer
private static readonly ImmutableArray<Func<IBinaryOperation, RequiredSymbols, bool>> CaseSelectors =
private static readonly ImmutableArray<Func<IOperation, RequiredSymbols, bool>> CaseSelectors =
#pragma warning restore RS1008 // Avoid storing per-compilation data into the fields of a diagnostic analyzer
ImmutableArray.Create<Func<IBinaryOperation, RequiredSymbols, bool>>(
ImmutableArray.Create<Func<IOperation, RequiredSymbols, bool>>(
IsStringStringCase,
IsStringStringBoolCase,
IsStringStringStringComparisonCase);
Expand Down
Loading