diff --git a/nullaway/src/main/java/com/uber/nullaway/NullAway.java b/nullaway/src/main/java/com/uber/nullaway/NullAway.java index b9d42df44c..a478fc5cd3 100644 --- a/nullaway/src/main/java/com/uber/nullaway/NullAway.java +++ b/nullaway/src/main/java/com/uber/nullaway/NullAway.java @@ -277,6 +277,8 @@ public Config getConfig() { */ private final Map computedNullnessMap = new LinkedHashMap<>(); + private GenericsChecks genericsChecks = new GenericsChecks(); + /** * Error Prone requires us to have an empty constructor for each Plugin, in addition to the * constructor taking an ErrorProneFlags object. This constructor should not be used anywhere @@ -500,7 +502,7 @@ public Description matchAssignment(AssignmentTree tree, VisitorState state) { } // generics check if (lhsType != null && config.isJSpecifyMode()) { - GenericsChecks.checkTypeParameterNullnessForAssignability(tree, this, state); + genericsChecks.checkTypeParameterNullnessForAssignability(tree, this, state); } if (config.isJSpecifyMode() && tree.getVariable() instanceof ArrayAccessTree) { @@ -1494,7 +1496,7 @@ public Description matchVariable(VariableTree tree, VisitorState state) { } VarSymbol symbol = ASTHelpers.getSymbol(tree); if (tree.getInitializer() != null && config.isJSpecifyMode()) { - GenericsChecks.checkTypeParameterNullnessForAssignability(tree, this, state); + genericsChecks.checkTypeParameterNullnessForAssignability(tree, this, state); } if (!config.isLegacyAnnotationLocation()) { checkNullableAnnotationPositionInType( @@ -1662,6 +1664,7 @@ public Description matchClass(ClassTree tree, VisitorState state) { class2Entities.clear(); class2ConstructorUninit.clear(); computedNullnessMap.clear(); + genericsChecks.clearCache(); EnclosingEnvironmentNullness.instance(state.context).clear(); } else if (classAnnotationIntroducesPartialMarking(classSymbol)) { // Handle the case where the top-class is unannotated, but there is a @NullMarked annotation @@ -1880,7 +1883,7 @@ private Description handleInvocation( Nullness.paramHasNullableAnnotation(methodSymbol, i, config) ? Nullness.NULLABLE : ((config.isJSpecifyMode() && tree instanceof MethodInvocationTree) - ? GenericsChecks.getGenericParameterNullnessAtInvocation( + ? genericsChecks.getGenericParameterNullnessAtInvocation( i, methodSymbol, (MethodInvocationTree) tree, state, config) : Nullness.NONNULL); } diff --git a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java index 3318763a44..31149da0b4 100644 --- a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java +++ b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import javax.lang.model.type.ExecutableType; import javax.lang.model.type.TypeKind; import javax.lang.model.type.TypeVariable; @@ -46,8 +47,12 @@ /** Methods for performing checks related to generic types and nullability. */ public final class GenericsChecks { - /** Do not instantiate; all methods should be static */ - private GenericsChecks() {} + /** + * Maps a MethodInvocationTree to a set of type variables that are mapped to their inferred types. + * Any generic type parameter that are not explicitly stated are inferred and cached in this + * field. + */ + private final Map> inferredTypes = new HashMap<>(); /** * Checks that for an instantiated generic type, {@code @Nullable} types are only used for type @@ -413,13 +418,16 @@ private static void reportInvalidOverridingMethodParamTypeError( * @param analysis the analysis object * @param state the visitor state */ - public static void checkTypeParameterNullnessForAssignability( + public void checkTypeParameterNullnessForAssignability( Tree tree, NullAway analysis, VisitorState state) { Config config = analysis.getConfig(); if (!config.isJSpecifyMode()) { return; } Type lhsType = getTreeType(tree, config); + if (lhsType == null) { + return; + } Tree rhsTree; if (tree instanceof VariableTree) { VariableTree varTree = (VariableTree) tree; @@ -428,6 +436,22 @@ public static void checkTypeParameterNullnessForAssignability( AssignmentTree assignmentTree = (AssignmentTree) tree; rhsTree = assignmentTree.getExpression(); } + + if (rhsTree instanceof MethodInvocationTree) { + MethodInvocationTree methodInvocationTree = (MethodInvocationTree) rhsTree; + Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(methodInvocationTree); + // update inferredTypes cache for assignments + // generic method call with no explicit generic arguments + if (methodSymbol.type instanceof Type.ForAll + && methodInvocationTree.getTypeArguments().isEmpty()) { + Type returnType = methodSymbol.getReturnType(); + Map genericNullness = + returnType.accept(new InferTypeVisitor(config), lhsType); + if (genericNullness != null) { + inferredTypes.put(methodInvocationTree, genericNullness); + } + } + } // rhsTree can be null for a VariableTree. Also, we don't need to do a check // if rhsTree is the null literal if (rhsTree == null || rhsTree.getKind().equals(Tree.Kind.NULL_LITERAL)) { @@ -435,7 +459,26 @@ public static void checkTypeParameterNullnessForAssignability( } Type rhsType = getTreeType(rhsTree, config); - if (lhsType != null && rhsType != null) { + if (rhsTree instanceof MethodInvocationTree) { + // recreate rhsType using inferred types + MethodInvocationTree methodInvocationTree = (MethodInvocationTree) rhsTree; + Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(methodInvocationTree); + if (inferredTypes.containsKey(methodInvocationTree)) { + Map genericNullness = inferredTypes.get(methodInvocationTree); + List keyTypeList = + genericNullness.keySet().stream() + .map(typeVar -> (Type) typeVar) + .collect(Collectors.toList()); + com.sun.tools.javac.util.List from = com.sun.tools.javac.util.List.from(keyTypeList); + com.sun.tools.javac.util.List to = + com.sun.tools.javac.util.List.from(genericNullness.values()); + rhsType = + TypeSubstitutionUtils.subst( + state.getTypes(), methodSymbol.getReturnType(), from, to, config); + } + } + + if (rhsType != null) { boolean isAssignmentValid = subtypeParameterNullability(lhsType, rhsType, state, config); if (!isAssignmentValid) { reportInvalidAssignmentInstantiationError(tree, lhsType, rhsType, state, analysis); @@ -929,7 +972,7 @@ private static Type substituteTypeArgsInGenericMethodType( * @return Nullness of parameter at {@code paramIndex}, or {@code NONNULL} if the call does not * invoke an instance method */ - public static Nullness getGenericParameterNullnessAtInvocation( + public Nullness getGenericParameterNullnessAtInvocation( int paramIndex, Symbol.MethodSymbol invokedMethodSymbol, MethodInvocationTree tree, @@ -949,6 +992,21 @@ public static Nullness getGenericParameterNullnessAtInvocation( getTypeNullness(substitutedParamTypes.get(paramIndex), config), Nullness.NULLABLE)) { return Nullness.NULLABLE; } + // check nullness of inferred types + if (inferredTypes.containsKey(tree)) { + Map genericNullness = inferredTypes.get(tree); + List parameters = invokedMethodSymbol.getParameters(); + if (genericNullness.containsKey(parameters.get(paramIndex).type)) { + Type genericType = parameters.get(paramIndex).type; + Type inferredGenericType = genericNullness.get(genericType); + if (inferredGenericType != null + && Objects.equals(getTypeNullness(inferredGenericType, config), Nullness.NULLABLE)) { + return Nullness.NULLABLE; + } else { + return Nullness.NONNULL; + } + } + } } if (!(tree.getMethodSelect() instanceof MemberSelectTree) || invokedMethodSymbol.isStatic()) { @@ -1157,6 +1215,10 @@ public static boolean passingLambdaOrMethodRefWithGenericReturnToUnmarkedCode( return callingUnannotated; } + public void clearCache() { + inferredTypes.clear(); + } + public static boolean isNullableAnnotated(Type type, Config config) { return Nullness.hasNullableAnnotation(type.getAnnotationMirrors().stream(), config); } diff --git a/nullaway/src/main/java/com/uber/nullaway/generics/InferTypeVisitor.java b/nullaway/src/main/java/com/uber/nullaway/generics/InferTypeVisitor.java new file mode 100644 index 0000000000..479ab943d8 --- /dev/null +++ b/nullaway/src/main/java/com/uber/nullaway/generics/InferTypeVisitor.java @@ -0,0 +1,70 @@ +package com.uber.nullaway.generics; + +import com.sun.tools.javac.code.Type; +import com.sun.tools.javac.code.Types; +import com.uber.nullaway.Config; +import com.uber.nullaway.Nullness; +import java.util.HashMap; +import java.util.Map; +import javax.lang.model.type.TypeVariable; +import org.jspecify.annotations.Nullable; + +/** Visitor that uses two types to infer the type of type variables. */ +public class InferTypeVisitor + extends Types.DefaultTypeVisitor<@Nullable Map, Type> { + private final Config config; + + InferTypeVisitor(Config config) { + this.config = config; + } + + @Override + public @Nullable Map visitClassType(Type.ClassType rhsType, Type lhsType) { + Map genericNullness = new HashMap<>(); + com.sun.tools.javac.util.List rhsTypeArguments = rhsType.getTypeArguments(); + com.sun.tools.javac.util.List lhsTypeArguments = + ((Type.ClassType) lhsType).getTypeArguments(); + // get the inferred type for each type arguments and add them to genericNullness + for (int i = 0; i < rhsTypeArguments.size(); i++) { + Type rhsTypeArg = rhsTypeArguments.get(i); + Type lhsTypeArg = lhsTypeArguments.get(i); + Map map = rhsTypeArg.accept(this, lhsTypeArg); + if (map != null) { + genericNullness.putAll(map); + } + } + return genericNullness.isEmpty() ? null : genericNullness; + } + + @Override + public @Nullable Map visitArrayType(Type.ArrayType rhsType, Type lhsType) { + // unwrap the type of the array and call accept on it + Type rhsComponentType = rhsType.elemtype; + Type lhsComponentType = ((Type.ArrayType) lhsType).elemtype; + Map genericNullness = rhsComponentType.accept(this, lhsComponentType); + return genericNullness; + } + + @Override + public Map visitTypeVar(Type.TypeVar rhsType, Type lhsType) { + Map genericNullness = new HashMap<>(); + Boolean isLhsNullable = + Nullness.hasNullableAnnotation(lhsType.getAnnotationMirrors().stream(), config); + Type upperBound = rhsType.getUpperBound(); + Boolean isRhsNullable = + Nullness.hasNullableAnnotation(upperBound.getAnnotationMirrors().stream(), config); + if (!isLhsNullable) { // lhsType is NonNull, we can just use this + genericNullness.put(rhsType, lhsType); + } else if (isRhsNullable) { // lhsType & rhsType are Nullable, can use lhs for inference + genericNullness.put(rhsType, lhsType); + } else { // rhs can't be nullable, use upperbound + genericNullness.put(rhsType, upperBound); + } + return genericNullness; + } + + @Override + public @Nullable Map visitType(Type t, Type type) { + return null; + } +} diff --git a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java index 6f37df6b8e..aba1a4a025 100644 --- a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java +++ b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericMethodTests.java @@ -161,6 +161,234 @@ public void genericMethodAndVoidTypeWithInference() { .doTest(); } + @Test + public void genericInferenceOnAssignments() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + " class Test {", + " static class Foo {", + " Foo(T t) {}", + " static Foo makeNull(U u) {", + " return new Foo<>(u);", + " }", + " static Foo makeNonNull(U u) {", + " return new Foo<>(u);", + " }", + " }", + " static void testLocalAssign() {", + " // legal", + " Foo<@Nullable Object> f1 = Foo.makeNull(null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo f2 = Foo.makeNull(null);", + " Foo<@Nullable Object> f3 = Foo.makeNull(new Object());", + " Foo f4 = Foo.makeNull(new Object());", + " // ILLEGAL: U does not have a @Nullable upper bound", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo<@Nullable Object> f5 = Foo.makeNonNull(null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo f6 = Foo.makeNonNull(null);", + " // BUG: Diagnostic contains: due to mismatched nullability of type parameters", + " Foo<@Nullable Object> f7 = Foo.makeNonNull(new Object());", + " Foo f8 = Foo.makeNonNull(new Object());", + " }", + " }") + .doTest(); + } + + @Test + public void genericInferenceOnAssignmentAfterDeclaration() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + " class Test {", + " static class Foo {", + " Foo(T t) {}", + " static Foo makeNull(U u) {", + " return new Foo<>(u);", + " }", + " static Foo makeNonNull(U u) {", + " return new Foo<>(u);", + " }", + " }", + " static void testAssignAfterDeclaration() {", + " // legal", + " Foo<@Nullable Object> f1; f1 = Foo.makeNull(null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo f2; f2 = Foo.makeNull(null);", + " }", + " }") + .doTest(); + } + + @Test + public void multipleParametersOneNested() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + "class Test {", + " static class Foo {", + " Foo(T t) {}", + " static Foo create(U u, Foo other) {", + " return new Foo<>(u);", + " }", + " static void test(Foo<@Nullable Object> f1, Foo f2) {", + " // no error expected", + " Foo<@Nullable Object> result = Foo.create(null, f1);", + " // BUG: Diagnostic contains: XXX", + " Foo<@Nullable Object> result2 = Foo.create(null, f2);", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void genericInferenceOnAssignmentsMultipleParams() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + "class Test {", + " class Foo {", + " Foo(T t) {}", + " public Foo make(U u, @Nullable String s) {", + " return new Foo<>(u);", + " }", + " }", + " static class Bar {", + " Bar(S s, Z z) {}", + " static Bar make(U u, B b) {", + " return new Bar<>(u, b);", + " }", + " }", + " static class Baz {", + " Baz(S s, Z z) {}", + " static Baz make(U u, B b) {", + " return new Baz<>(u, b);", + " }", + " }", + " public void run(Foo<@Nullable String> foo) {", + " // legal", + " Foo<@Nullable Object> f1 = foo.make(null, new String());", + " Foo<@Nullable Object> f2 = foo.make(null, null);", + " Foo f3 = foo.make(new Object(), null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo f4 = foo.make(null, null);", + " // legal", + " Bar<@Nullable Object, Object> b1 = Bar.make(null, new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Bar<@Nullable Object, Object> b2 = Bar.make(null, null);", + " Bar<@Nullable Object, @Nullable Object> b3 = Bar.make(null, null);", + " Bar b4 = Bar.make(new Object(), null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Bar b5 = Bar.make(null, null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Bar b6 = Bar.make(null, null);", + " // legal", + " Baz baz1 = Baz.make(new String(), new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Baz baz2 = Baz.make(null, new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Baz baz3 = Baz.make(new String(), null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Baz baz4 = Baz.make(null, null);", + " // BUG: Diagnostic contains: Generic type parameter cannot be @Nullable", + " Baz<@Nullable String, Object> baz5 = Baz.make(new String(), new Object());", + " // BUG: Diagnostic contains: Generic type parameter cannot be @Nullable", + " Baz baz6 = Baz.make(new String(), new Object());", + " }", + "}") + .doTest(); + } + + @Test + public void genericsUsedForGenericClasses() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + "import java.util.ArrayList;", + "class Test {", + " abstract class Foo {", + " abstract Foo> nonNullTest();", + " abstract Foo> nullTest();", + " }", + " static void test(Foo f) {", + " Foo> fooNonNull_1 = f.nonNullTest();", + " // BUG: Diagnostic contains: due to mismatched nullability of type parameters", + " Foo> fooNonNull_2 = f.nonNullTest();", + " // BUG: Diagnostic contains: due to mismatched nullability of type parameters", + " Foo<@Nullable Integer, ArrayList> fooNonNull_3 = f.nonNullTest();", + " Foo> fooNull_1 = f.nullTest();", + " Foo> fooNull_2 = f.nullTest();", + " Foo<@Nullable Integer, ArrayList> fooNull_3 = f.nullTest();", + " }", + "}") + .doTest(); + } + + @Test + public void genericInferenceOnAssignmentsWithArrays() { + makeHelper() + .addSourceLines( + "Test.java", + "package com.uber;", + "import org.jspecify.annotations.Nullable;", + " class Test {", + " static class Foo {", + " Foo(T t) {}", + " static Foo[]> test1Null(U u) {", + " return new Foo<>((Foo[]) new Foo[5]);", + " }", + " static Foo[]> test1Nonnull(U u) {", + " return new Foo<>((Foo[]) new Foo[5]);", + " }", + " static Foo[] test2Null(U u) {", + " return (Foo[]) new Foo[5];", + " }", + " static Foo[] test2Nonnull(U u) {", + " return (Foo[]) new Foo[5];", + " }", + " }", + " static void testLocalAssign() {", + " Foo[]> f1 = Foo.test1Null(new Object());", + " Foo[]> f2 = Foo.test1Null(new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo[]> f3 = Foo.test1Null(null);", + " Foo[]> f4 = Foo.test1Null(null);", + " Foo[]> f5 = Foo.test1Nonnull(new Object());", + " // BUG: Diagnostic contains: due to mismatched nullability of type parameters", + " Foo[]> f6 = Foo.test1Nonnull(new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo[]> f7 = Foo.test1Nonnull(null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo[]> f8 = Foo.test1Nonnull(null);", + " Foo[] f9 = Foo.test2Null(new Object());", + " Foo<@Nullable Object>[] f10 = Foo.test2Null(new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo[] f11 = Foo.test2Null(null);", + " Foo<@Nullable Object>[] f12 = Foo.test2Null(null);", + " Foo[] f13 = Foo.test2Nonnull(new Object());", + " // BUG: Diagnostic contains: due to mismatched nullability of type parameters", + " Foo<@Nullable Object>[] f14 = Foo.test2Nonnull(new Object());", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo[] f15 = Foo.test2Nonnull(null);", + " // BUG: Diagnostic contains: passing @Nullable parameter 'null' where @NonNull is required", + " Foo<@Nullable Object>[] f16 = Foo.test2Nonnull(null);", + " }", + " }") + .doTest(); + } + @Test public void issue1035() { makeHelper()