Skip to content

Commit

Permalink
SE - Nullable: Add support for nullable.GetValueOrDefault() (#6929)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-mikula-sonarsource committed Mar 20, 2023
1 parent 44b81b6 commit 65df768
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/

using SonarAnalyzer.SymbolicExecution.Constraints;
using SonarAnalyzer.SymbolicExecution.Roslyn.Checks;

namespace SonarAnalyzer.SymbolicExecution.Roslyn.OperationProcessors;

Expand All @@ -36,7 +37,8 @@ protected override ProgramState[] Process(SymbolicContext context, IInvocationOp
var state = context.State;
if (!invocation.TargetMethod.IsStatic // Also applies to C# extensions
&& !invocation.TargetMethod.IsExtensionMethod // VB extensions in modules are not marked as static
&& invocation.Instance.TrackedSymbol() is { } symbol)
&& invocation.Instance.TrackedSymbol() is { } symbol
&& !IsNullableGetValueOrDefault(invocation))
{
state = state.SetSymbolConstraint(symbol, ObjectConstraint.NotNull);
}
Expand All @@ -50,6 +52,7 @@ protected override ProgramState[] Process(SymbolicContext context, IInvocationOp
}
return invocation switch
{
_ when IsNullableGetValueOrDefault(invocation) => ProcessNullableGetValueOrDefault(context, invocation).ToArray(),
_ when invocation.TargetMethod.Is(KnownType.Microsoft_VisualBasic_Information, "IsNothing") => ProcessInformationIsNothing(context, invocation),
_ when invocation.TargetMethod.Is(KnownType.System_Diagnostics_Debug, nameof(Debug.Assert)) => ProcessDebugAssert(context, invocation),
_ when invocation.TargetMethod.ContainingType.IsAny(KnownType.System_Linq_Enumerable, KnownType.System_Linq_Queryable) => ProcessLinqEnumerableAndQueryable(context, invocation),
Expand Down Expand Up @@ -196,6 +199,24 @@ private static ProgramState[] ProcessEquals(SymbolicContext context, IInvocation
return context.State.ToArray();
}

private static ProgramState ProcessNullableGetValueOrDefault(SymbolicContext context, IInvocationOperationWrapper invocation)
{
return context.State[invocation.Instance] switch
{
{ } instanceValue when instanceValue.HasConstraint(ObjectConstraint.Null) => NullableDefaultState(),
{ } instanceValue => context.State.SetOperationValue(invocation, instanceValue),
_ => context.State
};

ProgramState NullableDefaultState()
{
var valueType = ((INamedTypeSymbol)invocation.Instance.Type).TypeArguments.Single();
return ConstantCheck.ConstraintFromType(valueType) is { } orDefaultConstraint
? context.SetOperationConstraint(orDefaultConstraint)
: context.State;
}
}

private static bool IsThrowHelper(IMethodSymbol method) =>
method.Is(KnownType.System_Diagnostics_Debug, nameof(Debug.Fail))
|| method.IsAny(KnownType.System_Environment, nameof(Environment.FailFast), nameof(Environment.Exit))
Expand All @@ -216,4 +237,7 @@ _ when invocation.Arguments[0].TrackedSymbol() is { } argumentSymbol => new[]
},
_ => context.State.ToArray()
};

private static bool IsNullableGetValueOrDefault(IInvocationOperationWrapper invocation) =>
invocation.TargetMethod.Is(KnownType.System_Nullable_T, nameof(Nullable<int>.GetValueOrDefault));
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,100 @@ public void Nullable_Conversion_PropagateConstraints()
validator.ValidateTag("ToNullableAs", x => x.HasConstraint(BoolConstraint.True).Should().BeTrue());
validator.ValidateTag("ToBoolExplicit", x => x.HasConstraint(BoolConstraint.True).Should().BeTrue());
}

[TestMethod]
public void Nullable_GetValueOrDefault_Int()
{
const string code = """
var value = arg.GetValueOrDefault();
Tag("UnknownArg", arg);
Tag("UnknownValue", value);
arg = null; // Adds DummyConstraint
value = arg.GetValueOrDefault();
Tag("NullArg", arg);
Tag("NullValue", value);
arg = 42; // Adds DummyConstraint
value = arg.GetValueOrDefault();
Tag("NotNullArg", arg);
Tag("NotNullValue", value);
""";
var validator = SETestContext.CreateCS(code, ", int? arg", new LiteralDummyTestCheck()).Validator;
validator.ValidateTag("UnknownArg", x => x.Should().HaveNoConstraints());
validator.ValidateTag("UnknownValue", x => x.Should().HaveOnlyConstraint(ObjectConstraint.NotNull));
validator.ValidateTag("NullArg", x => x.Should().HaveOnlyConstraints(ObjectConstraint.Null, DummyConstraint.Dummy));
validator.ValidateTag("NullValue", x => x.Should().HaveOnlyConstraint(ObjectConstraint.NotNull));
validator.ValidateTag("NotNullArg", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, DummyConstraint.Dummy));
validator.ValidateTag("NotNullValue", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, DummyConstraint.Dummy));
}

[TestMethod]
public void Nullable_GetValueOrDefault_Bool()
{
const string code = """
var value = arg.GetValueOrDefault();
Tag("UnknownArg", arg);
Tag("UnknownValue", value);
arg = null; // Adds DummyConstraint
value = arg.GetValueOrDefault();
Tag("NullArg", arg);
Tag("NullValue", value);
arg = true; // Adds DummyConstraint
value = arg.GetValueOrDefault();
Tag("NotNullArg", arg);
Tag("NotNullValue", value);
""";
var validator = SETestContext.CreateCS(code, ", bool? arg", new LiteralDummyTestCheck()).Validator;
validator.ValidateTag("UnknownArg", x => x.Should().HaveNoConstraints());
validator.ValidateTag("UnknownValue", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull));
validator.ValidateTag("NullArg", x => x.Should().HaveOnlyConstraints(ObjectConstraint.Null, DummyConstraint.Dummy));
validator.ValidateTag("NullValue", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.False));
validator.ValidateTag("NotNullArg", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.True, DummyConstraint.Dummy));
validator.ValidateTag("NotNullValue", x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.True, DummyConstraint.Dummy));
}

[TestMethod]
public void Nullable_GetValueOrDefault_SubExpression()
{
const string code = """
var value = (Condition ? null : (bool?)true).GetValueOrDefault();
Tag("Value", value);
""";
SETestContext.CreateCS(code).Validator.TagValues("Value").Should().SatisfyRespectively(
x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.False),
x => x.Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.True));
}

[TestMethod]
public void Nullable_GetValueOrDefault_SubExpression_Branching()
{
const string code = """
bool? nullable;
if (boolParameter)
nullable = true;
else
nullable = null;
var value = nullable.GetValueOrDefault();
Tag("End");
""";
var validator = SETestContext.CreateCS(code, new PreserveTestCheck("boolParameter", "nullable", "value")).Validator;
var boolParameter = validator.Symbol("boolParameter");
var nullable = validator.Symbol("nullable");
var value = validator.Symbol("value");
validator.TagStates("End").Should().SatisfyRespectively(
x =>
{
x[boolParameter].Should().HaveOnlyConstraints(BoolConstraint.True); // NotNull is missing
x[nullable].Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.True);
x[value].Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.True);
}, x =>
{
x[boolParameter].Should().HaveOnlyConstraints(BoolConstraint.False); // NotNull is missing
x[nullable].Should().HaveOnlyConstraint(ObjectConstraint.Null);
x[value].Should().HaveOnlyConstraints(ObjectConstraint.NotNull, BoolConstraint.False);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@ namespace SonarAnalyzer.UnitTest.TestFramework.SymbolicExecution
{
public class PreserveTestCheck : SymbolicCheck
{
private readonly string symbolName;
private readonly HashSet<string> symbolNames;

public PreserveTestCheck(string symbolName) =>
this.symbolName = symbolName;
public PreserveTestCheck(params string[] symbolNames)
{
if (symbolNames.Length == 0)
{
throw new ArgumentException("Value cannot be empty", nameof(symbolNames));
}
this.symbolNames = new(symbolNames);
}

protected override ProgramState PreProcessSimple(SymbolicContext context) =>
context.Operation.Instance.TrackedSymbol() is { } symbol && symbol.Name == symbolName
context.Operation.Instance.TrackedSymbol() is { } symbol && symbolNames.Contains(symbol.Name)
? context.State.Preserve(symbol)
: context.State;
}
Expand Down

0 comments on commit 65df768

Please sign in to comment.