From 009c2a49e61f5deae419ad1dcae5f3ee46f94956 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Wed, 23 Oct 2024 20:10:42 +0200 Subject: [PATCH] Experiment: restrict allowed trees in type annotations --- .../dotty/tools/dotc/core/Annotations.scala | 50 +++++++++++++++++++ .../src/dotty/tools/dotc/core/Types.scala | 1 + .../annot-forbidden-type-annotations.scala | 16 ++++++ tests/pos/annot-17939b.scala | 10 ---- tests/pos/annotDepMethType.scala | 1 - tests/printing/annot-19846b.check | 33 ------------ tests/printing/annot-19846b.scala | 7 --- 7 files changed, 67 insertions(+), 51 deletions(-) create mode 100644 tests/neg/annot-forbidden-type-annotations.scala delete mode 100644 tests/pos/annot-17939b.scala delete mode 100644 tests/printing/annot-19846b.check delete mode 100644 tests/printing/annot-19846b.scala diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index d6a99b12e3b3..3be72c99c9ac 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -3,6 +3,8 @@ package dotc package core import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.* +import Decorators.i +import StdNames.nme import ast.tpd, tpd.* import util.Spans.Span import printing.{Showable, Printer} @@ -106,6 +108,54 @@ object Annotations { go(metaSyms) || orNoneOf.nonEmpty && !go(orNoneOf) } + /** True if this annotation can be used as a type annotation, false otherwise. + * + * An annotation is a valid type annotation if its tree is one a `Literal`. + * + * Can be overridden. + */ + def checkValidTypeAnnotation()(using Context): Unit = + def isTupleModule(sym: Symbol): Boolean = + ctx.definitions.isTupleClass(sym.companionClass) + + def isFunctionAllowed(t: Tree): Boolean = + t match + case Select(qual, nme.apply) => qual.symbol == defn.ArrayModule || isTupleModule(qual.symbol) + case TypeApply(fun, _) => isFunctionAllowed(fun) + case _ => false + + def check(t: Tree): Boolean = + t match + case Literal(_) => true + case Typed(expr, _) => check(expr) + case SeqLiteral(elems, _) => elems.forall(check) + case Apply(fun, args) => isFunctionAllowed(fun) && args.forall(check) + case NamedArg(_, arg) => check(arg) + case _ => + t.tpe.stripped match + case _: SingletonType => true + // We need to handle type refs for these test cases: + // - tests/pos/dependent-annot.scala + // - tests/pos/i16208.scala + // - tests/run/java-ann-super-class.scala + // - tests/run/java-ann-super-class-separate.scala + // - tests/neg/i19470.scala (@retains) + // Why do we get type refs in these cases? + case _: TypeRef => true + case _: TypeParamRef => true + case tp => false + + val uncheckedAnnots = Set[Symbol](defn.RetainsAnnot, defn.RetainsByNameAnnot) + if uncheckedAnnots(symbol) then return + + for arg <- arguments if !check(arg) do + report.error( + s"""Implementation restriction: not a valid type annotation argument. + | Argument: $arg + | Type: ${arg.tpe}""".stripMargin, arg.srcPos) + + () + /** Operations for hash-consing, can be overridden */ def hash: Int = System.identityHashCode(this) def eql(that: Annotation) = this eq that diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 8a9d44cb8d25..71296db08824 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -5754,6 +5754,7 @@ object Types extends TypeUtils { def make(underlying: Type, annots: List[Annotation])(using Context): Type = annots.foldLeft(underlying)(apply(_, _)) def apply(parent: Type, annot: Annotation)(using Context): AnnotatedType = + annot.checkValidTypeAnnotation() unique(CachedAnnotatedType(parent, annot)) end AnnotatedType diff --git a/tests/neg/annot-forbidden-type-annotations.scala b/tests/neg/annot-forbidden-type-annotations.scala new file mode 100644 index 000000000000..bc3bab30f634 --- /dev/null +++ b/tests/neg/annot-forbidden-type-annotations.scala @@ -0,0 +1,16 @@ +class annot[T](arg: T) extends scala.annotation.Annotation + +def m1(x: Any): Unit = () +val x: Int = ??? + +def m2(): String @annot(m1(2)) = ??? // error +def m3(): String @annot(throw new Error()) = ??? // error +def m4(): String @annot((x: Int) => x) = ??? // error +def m5(): String @annot(x + 1) = ??? // error + +def main = + @annot(m1(2)) val x1: String = ??? // ok + @annot(throw new Error()) val x2: String = ??? // ok + @annot((x: Int) => x) val x3: String = ??? // ok + @annot(x + 1) val x4: String = ??? // ok + () diff --git a/tests/pos/annot-17939b.scala b/tests/pos/annot-17939b.scala deleted file mode 100644 index a48f4690d0b2..000000000000 --- a/tests/pos/annot-17939b.scala +++ /dev/null @@ -1,10 +0,0 @@ -import scala.annotation.Annotation -class myRefined(f: ? => Boolean) extends Annotation - -def test(axes: Int) = true - -trait Tensor: - def mean(axes: Int): Int @myRefined(_ => test(axes)) - -class TensorImpl() extends Tensor: - def mean(axes: Int) = ??? diff --git a/tests/pos/annotDepMethType.scala b/tests/pos/annotDepMethType.scala index 079ca6224cea..ca8c9be55235 100644 --- a/tests/pos/annotDepMethType.scala +++ b/tests/pos/annotDepMethType.scala @@ -3,5 +3,4 @@ case class pc(calls: Any*) extends annotation.TypeConstraint object Main { class C0 { def baz: String = "" } class C1 { def bar(c0: C0): String @pc(c0.baz) = c0.baz } - def trans(c1: C1): String @pc(c1.bar(throw new Error())) = c1.bar(new C0) } diff --git a/tests/printing/annot-19846b.check b/tests/printing/annot-19846b.check deleted file mode 100644 index 3f63a46c4286..000000000000 --- a/tests/printing/annot-19846b.check +++ /dev/null @@ -1,33 +0,0 @@ -[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala -package { - class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(), - annotation.StaticAnnotation { - private[this] val g: () => Int - } - final lazy module val Test: Test = new Test() - final module class Test() extends Object() { this: Test.type => - val y: Int = ??? - val z: - Int @lambdaAnnot( - { - def $anonfun(): Int = Test.y - closure($anonfun) - } - ) - = f(Test.y) - } - final lazy module val annot-19846b$package: annot-19846b$package = - new annot-19846b$package() - final module class annot-19846b$package() extends Object() { - this: annot-19846b$package.type => - def f(x: Int): - Int @lambdaAnnot( - { - def $anonfun(): Int = x - closure($anonfun) - } - ) - = x - } -} - diff --git a/tests/printing/annot-19846b.scala b/tests/printing/annot-19846b.scala deleted file mode 100644 index 951a3c8116ff..000000000000 --- a/tests/printing/annot-19846b.scala +++ /dev/null @@ -1,7 +0,0 @@ -class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation - -def f(x: Int): Int @lambdaAnnot(() => x) = x - -object Test: - val y: Int = ??? - val z /* : Int @lambdaAnnot(() => y) */ = f(y)