Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment: restrict allowed trees in annotations #21840

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,17 @@ object TreeChecker {
|${mismatch.message}${mismatch.explanation}
|tree = $tree ${tree.className}""".stripMargin
})
checkWellFormedType(tp1)
checkWellFormedType(tp2)

/** Check that the type `tp` is well-formed. Currently this only means
* checking that annotated types have valid annotation arguments.
*/
private def checkWellFormedType(tp: Type)(using Context): Unit =
tp.foreachPart:
case AnnotatedType(underlying, annot) => checkAnnot(annot.tree)
case _ => ()

}

/** Tree checker that can be applied to a local tree. */
Expand Down
70 changes: 64 additions & 6 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1392,12 +1392,21 @@ trait Checking {
if !Inlines.inInlineMethod && !ctx.isInlineContext then
report.error(em"$what can only be used in an inline method", pos)

def checkAnnot(tree: Tree)(using Context): Tree =
tree match
case Ident(tpnme.BOUNDTYPE_ANNOT) =>
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
// `Ident`s, not annotation instances. See `tests/pos/annot-boundtype.scala`.
tree
case _ =>
checkAnnotTree(checkAnnotClass(tree))

/** Check that the class corresponding to this tree is either a Scala or Java annotation.
*
* @return The original tree or an error tree in case `tree` isn't a valid
* annotation or already an error tree.
*/
def checkAnnotClass(tree: Tree)(using Context): Tree =
private def checkAnnotClass(tree: Tree)(using Context): Tree =
if tree.tpe.isError then
return tree
val cls = Annotations.annotClass(tree)
Expand All @@ -1409,8 +1418,8 @@ trait Checking {
errorTree(tree, em"$cls is not a valid Scala annotation: it does not extend `scala.annotation.Annotation`")
else tree

/** Check arguments of compiler-defined annotations */
def checkAnnotArgs(tree: Tree)(using Context): tree.type =
/** Check arguments of annotations */
private def checkAnnotTree(tree: Tree)(using Context): Tree =
val cls = Annotations.annotClass(tree)
tree match
case Apply(tycon, arg :: Nil) if cls == defn.TargetNameAnnot =>
Expand All @@ -1424,8 +1433,57 @@ trait Checking {
arg.tpe.widenTermRefExpr.normalized match
case _: ConstantType => ()
case _ => report.error(em"@${cls.name} requires constant expressions as a parameter", arg.srcPos)
case _ =>
tree
case _ => ()

findInvalidAnnotSubTree(tree) match
case None => tree
case Some(invalidSubTree) =>
errorTree(
EmptyTree,
em"""Expression cannot be used inside an annotation argument.
|Tree: ${invalidSubTree}
|Type: ${invalidSubTree.tpe}""",
invalidSubTree.srcPos
)

private def findInvalidAnnotSubTree(tree: Tree)(using Context): Option[Tree] =
type ValidAnnotTree =
Ident
| Select
| This
| Super
| Apply
| TypeApply
| Literal
| New
| Typed
| NamedArg
| Assign
| Block
| SeqLiteral
| Inlined
| Hole
| Annotated
| EmptyTree.type

val accumulator = new TreeAccumulator[Option[Tree]]:
override def apply(acc: Option[Tree], tree: Tree)(using Context): Option[Tree] =
if acc.isDefined then
acc
else
tree match
case tree if tree.isType => foldOver(acc, tree)
case closureDef(meth) =>
val paramsRes =
meth.paramss.foldLeft(acc): (acc: Option[Tree], params: List[ValDef] | List[TypeDef]) =>
params.foldLeft(acc): (acc: Option[Tree], param: ValDef | TypeDef) =>
foldOver(acc, param)
foldOver(paramsRes, meth.rhs)
case tree: ValidAnnotTree => foldOver(acc, tree)
case _ => Some(tree)

accumulator(None, tree)


/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
Expand Down Expand Up @@ -1674,7 +1732,7 @@ trait NoChecking extends ReChecking {
override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = ()
override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = ()
override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp
override def checkAnnotArgs(tree: Tree)(using Context): tree.type = tree
override def checkAnnot(tree: Tree)(using Context): tree.type = tree
override def checkNoTargetNameConflict(stats: List[Tree])(using Context): Unit = ()
override def checkParentCall(call: Tree, caller: ClassSymbol)(using Context): Unit = ()
override def checkSimpleKinded(tpt: Tree)(using Context): Tree = tpt
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2799,7 +2799,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

def typedAnnotation(annot: untpd.Tree)(using Context): Tree =
checkAnnotClass(checkAnnotArgs(typed(annot)))
checkAnnot(typed(annot))

def registerNowarn(tree: Tree, mdef: untpd.Tree)(using Context): Unit =
val annot = Annotations.Annotation(tree)
Expand Down Expand Up @@ -3330,7 +3330,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
end typedPackageDef

def typedAnnotated(tree: untpd.Annotated, pt: Type)(using Context): Tree = {
val annot1 = checkAnnotClass(typedExpr(tree.annot))
val annot1 = checkAnnot(typedExpr(tree.annot))
val annotCls = Annotations.annotClass(annot1)
if annotCls == defn.NowarnAnnot then
registerNowarn(annot1, tree)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import scala.quoted.*
class Test {
def foo(str: Expr[String])(using Quotes) = '{
@deprecated($str, "")
@deprecated($str, "") // error: expression cannot be used inside an annotation argument
def bar = ???
}
}
5 changes: 0 additions & 5 deletions tests/neg-macros/i7121.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import scala.quoted.*

class annot1[T](x: Expr[T]) extends scala.annotation.Annotation
class annot2[T: Type](x: T) extends scala.annotation.Annotation

class Test()(implicit qtx: Quotes) {
@annot1('{4}) // error
def foo(str: String) = ()

@annot2(4)(using Type.of[Int]) // error
def foo2(str: String) = ()

}
8 changes: 8 additions & 0 deletions tests/neg-macros/i7121b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.quoted.*

class annot1[T](x: Expr[T]) extends scala.annotation.Annotation

class Test()(implicit qtx: Quotes) {
@annot1('{4}) // error: expression cannot be used inside an annotation argument
def foo(str: String) = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ inline def quote[T]: Quoted[T] = ${ quoteImpl[T] }

def quoteImpl[T: Type](using Quotes): Expr[Quoted[T]] = {
val value: Expr[Int] = '{ 42 }
'{ new Quoted[T @Annot($value)] }
'{ new Quoted[T @Annot($value)] } // error: expression cannot be used inside an annotation argument
}
60 changes: 60 additions & 0 deletions tests/neg/annot-invalid.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
-- Error: tests/neg/annot-invalid.scala:9:21 ---------------------------------------------------------------------------
9 | val x1: Int @annot(new Object {}) = 0 // error
| ^^^^^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: final class $anon() extends Object() {}
| Type: Object {...}
-- Error: tests/neg/annot-invalid.scala:10:28 --------------------------------------------------------------------------
10 | val x2: Int @annot({class C}) = 0 // error
| ^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: class C() extends Object() {}
| Type: C
-- Error: tests/neg/annot-invalid.scala:11:26 --------------------------------------------------------------------------
11 | val x3: Int @annot({val y: Int = 2}) = 0 // error
| ^^^^^^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: val y: Int = 2
| Type: (y : Int)
-- Error: tests/neg/annot-invalid.scala:12:26 --------------------------------------------------------------------------
12 | val x4: Int @annot({def f = 2}) = 0 // error
| ^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: def f: Int = 2
| Type: (f : => Int)
-- Error: tests/neg/annot-invalid.scala:14:25 --------------------------------------------------------------------------
14 | val x5: Int @annot('{4}) = 0 // error
| ^
| Expression cannot be used inside an annotation argument.
| Tree: '{4}
| Type: (scala.quoted.Quotes) ?=> scala.quoted.Expr[Int]
-- Error: tests/neg/annot-invalid.scala:16:9 ---------------------------------------------------------------------------
16 | @annot(new Object {}) val y1: Int = 0 // error
| ^^^^^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: final class $anon() extends Object() {}
| Type: Object {...}
-- Error: tests/neg/annot-invalid.scala:17:16 --------------------------------------------------------------------------
17 | @annot({class C}) val y2: Int = 0 // error
| ^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: class C() extends Object() {}
| Type: C
-- Error: tests/neg/annot-invalid.scala:18:14 --------------------------------------------------------------------------
18 | @annot({val y: Int = 2}) val y3: Int = 0 // error
| ^^^^^^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: val y: Int = 2
| Type: (y : Int)
-- Error: tests/neg/annot-invalid.scala:19:14 --------------------------------------------------------------------------
19 | @annot({def f = 2}) val y4: Int = 0 // error
| ^^^^^^^^^
| Expression cannot be used inside an annotation argument.
| Tree: def f: Int = 2
| Type: (f : => Int)
-- Error: tests/neg/annot-invalid.scala:21:13 --------------------------------------------------------------------------
21 | @annot('{4}) val y5: Int = 0 // error
| ^
| Expression cannot be used inside an annotation argument.
| Tree: '{4}
| Type: (scala.quoted.Quotes) ?=> scala.quoted.Expr[Int]
23 changes: 23 additions & 0 deletions tests/neg/annot-invalid.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import scala.quoted.Quotes

class annot[T](arg: T) extends scala.annotation.Annotation

def main =
object O:
def g(x: Int): Int = x

val x1: Int @annot(new Object {}) = 0 // error
val x2: Int @annot({class C}) = 0 // error
val x3: Int @annot({val y: Int = 2}) = 0 // error
val x4: Int @annot({def f = 2}) = 0 // error
def withQuotes(using Quotes) =
val x5: Int @annot('{4}) = 0 // error

@annot(new Object {}) val y1: Int = 0 // error
@annot({class C}) val y2: Int = 0 // error
@annot({val y: Int = 2}) val y3: Int = 0 // error
@annot({def f = 2}) val y4: Int = 0 // error
def withQuotes2(using Quotes) =
@annot('{4}) val y5: Int = 0 // error

()
2 changes: 2 additions & 0 deletions tests/neg/i7740a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class A(a: Any) extends annotation.StaticAnnotation
@A({val x = 0}) trait B // error: expression cannot be used inside an annotation argument
2 changes: 2 additions & 0 deletions tests/neg/i7740b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class A(a: Any) extends annotation.StaticAnnotation
@A({def x = 0}) trait B // error: expression cannot be used inside an annotation argument
4 changes: 4 additions & 0 deletions tests/neg/i9314.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
final class fooAnnot[T](member: T) extends scala.annotation.StaticAnnotation

@fooAnnot(new RecAnnotated {}) // error: expression cannot be used inside an annotation argument
trait RecAnnotated
3 changes: 3 additions & 0 deletions tests/neg/t7426.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class foo(x: Any) extends annotation.StaticAnnotation

@foo(new AnyRef { }) trait A // error: expression cannot be used inside an annotation argument
15 changes: 15 additions & 0 deletions tests/pos/annot-15054.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.annotation.Annotation

class AnAnnotation(function: Int => String) extends Annotation

@AnAnnotation(_.toString)
val a = 1
@AnAnnotation(_.toString.length.toString)
val b = 2

def test =
@AnAnnotation(_.toString)
val a = 1
@AnAnnotation(_.toString.length.toString)
val b = 2
a + b
16 changes: 16 additions & 0 deletions tests/pos/annot-boundtype.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
// `Ident`s, not annotation instances. This is relevant for `Checking.checkAnnot`.
//
// See also:
// - tests/run/t2755.scala
// - tests/neg/i13044.scala

def f(a: Array[?]) =
a match
case x: Array[?] => ()

def f2(t: Tuple) =
t match
case _: (t *: ts) => ()
case _ => ()

48 changes: 48 additions & 0 deletions tests/pos/annot-valid.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
class annot[T](arg: T) extends scala.annotation.Annotation

def main =
val n: Int = 0
def f(x: Any): Unit = ()

object O:
def g(x: Any): Unit = ()

val x1: Int @annot(42) = 0
val x2: Int @annot("hello") = 0
val x3: Int @annot(classOf[Int]) = 0
val x4: Int @annot(Array(1,2)) = 0
val x5: Int @annot(Array(Array(1,2),Array(3,4))) = 0
val x6: Int @annot((1,2)) = 0
val x7: Int @annot((1,2,3)) = 0
val x8: Int @annot(((1,2),3)) = 0
val x9: Int @annot(((1,2),(3,4))) = 0
val x10: Int @annot(Symbol("hello")) = 0
val x11: Int @annot(n + 1) = 0
val x12: Int @annot(f(2)) = 0
val x13: Int @annot(throw new Error()) = 0
val x14: Int @annot(42: Double) = 0
val x15: Int @annot(O.g(2)) = 0
val x16: Int @annot((x: Int) => x) = 0
val x17: Int @annot([T] => (x: T) => x) = 0
val x18: Int @annot(O.g) = 0

@annot(42) val y1: Int = 0
@annot("hello") val y2: Int = 0
@annot(classOf[Int]) val y3: Int = 0
@annot(Array(1,2)) val y4: Int = 0
@annot(Array(Array(1,2),Array(3,4))) val y5: Int = 0
@annot((1,2)) val y6: Int = 0
@annot((1,2,3)) val y7: Int = 0
@annot(((1,2),3)) val y8: Int = 0
@annot(((1,2),(3,4))) val y9: Int = 0
@annot(Symbol("hello")) val y10: Int = 0
@annot(n + 1) val y11: Int = 0
@annot(f(2)) val y12: Int = 0
@annot(throw new Error()) val y13: Int = 0
@annot(42: Double) val y14: Int = 0
@annot(O.g(2)) val y15: Int = 0
@annot((x: Int) => x) val y16: Int = 0
@annot([T] => (x: T) => x) val y17: Int = 0
@annot(O.g) val y18: Int = 0

()
2 changes: 0 additions & 2 deletions tests/pos/i7740a.scala

This file was deleted.

2 changes: 0 additions & 2 deletions tests/pos/i7740b.scala

This file was deleted.

4 changes: 0 additions & 4 deletions tests/pos/i9314.scala

This file was deleted.

3 changes: 0 additions & 3 deletions tests/pos/t7426.scala

This file was deleted.

Loading