diff --git a/core/src/main/java/edu/ucr/cs/riple/core/Annotator.java b/core/src/main/java/edu/ucr/cs/riple/core/Annotator.java index 16a771152..690c13831 100644 --- a/core/src/main/java/edu/ucr/cs/riple/core/Annotator.java +++ b/core/src/main/java/edu/ucr/cs/riple/core/Annotator.java @@ -48,6 +48,7 @@ import edu.ucr.cs.riple.injector.changes.AddMarkerAnnotation; import edu.ucr.cs.riple.injector.changes.AddSingleElementAnnotation; import edu.ucr.cs.riple.injector.location.OnField; +import edu.ucr.cs.riple.injector.location.OnMethod; import edu.ucr.cs.riple.scanner.Serializer; import java.util.List; import java.util.Objects; @@ -262,11 +263,25 @@ private void forceResolveRemainingErrors( // the method level. Set nullUnMarkedAnnotations = remainingErrors.stream() - // filter non-method regions. - .filter(error -> !error.encMethod().equals("null")) // find the corresponding method nodes. - .map(error -> tree.findNode(error.encMethod(), error.encClass())) - // impossible, just sanity check or future nullness checker hints + .map( + error -> { + if (!error.encMethod().equals("null")) { + return tree.findNode(error.encMethod(), error.encClass()); + } + if (error.nonnullTarget == null) { + // Just a sanity check. + return null; + } + // For methods invoked in an initialization region, where the error is that + // `@Nullable` is being passed as an argument, we add a `@NullUnmarked` annotation + // to the called method. + if (error.messageType.equals("PASS_NULLABLE")) { + OnMethod calledMethod = error.nonnullTarget.toMethod(); + return tree.findNode(calledMethod.method, calledMethod.clazz); + } + return null; + }) .filter(Objects::nonNull) .map(node -> new AddMarkerAnnotation(node.location, config.nullUnMarkedAnnotation)) .collect(Collectors.toSet()); diff --git a/core/src/test/java/edu/ucr/cs/riple/core/CoreTest.java b/core/src/test/java/edu/ucr/cs/riple/core/CoreTest.java index d37c03963..f694586cf 100644 --- a/core/src/test/java/edu/ucr/cs/riple/core/CoreTest.java +++ b/core/src/test/java/edu/ucr/cs/riple/core/CoreTest.java @@ -128,7 +128,7 @@ public void field_assign_nullable() { } @Test - public void field_assign_nullable_constructor() { + public void fieldAssignNullableConstructor() { coreTestHelper .addInputLines( "Main.java", @@ -148,6 +148,28 @@ public void field_assign_nullable_constructor() { .start(); } + @Test + public void fieldAssignNullableConstructorForceResolveEnabled() { + coreTestHelper + .addInputLines( + "Main.java", + "package test;", + "public class Main {", + " Object f;", + " Main(Object f) {", + " this.f = f;", + " }", + "}", + "class C {", + " Main main = new Main(null);", + "}") + .toDepth(1) + .addExpectedReports( + new TReport(new OnParameter("Main.java", "test.Main", "Main(java.lang.Object)", 0), 1)) + .enableForceResolve() + .start(); + } + @Test public void multipleFieldDeclarationTest() { coreTestHelper