Skip to content

Commit

Permalink
Three handy predicates to check classpath (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
da-tubi authored Mar 31, 2022
1 parent b618b9b commit 56fd8fb
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 9 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,41 @@ optimizedBuffer.reduceToSize(1)
```

You can define a `c` parameter because the `enableIf` annotation accepts either a `Boolean` expression or a `scala.reflect.macros.Context => Boolean` function. You can extract information from the macro context `c`.

## Enable different code for Apache Spark 3.1.x and 3.2.x
For breaking API changes of 3rd-party libraries, simply annotate the target method with the artifactId and the version to make it compatible.

To distinguish Apache Spark 3.1.x and 3.2.x:
``` scala
object XYZ {
@enableIf(classpathMatches(".*spark-catalyst_2\\.\\d+-3\\.2\\..*".r))
private def getFuncName(f: UnresolvedFunction): String = {
// For Spark 3.2.x
f.nameParts.last
}

@enableIf(classpathMatches(".*spark-catalyst_2\\.\\d+-3\\.1\\..*".r))
private def getFuncName(f: UnresolvedFunction): String = {
// For Spark 3.1.x
f.name.funcName
}
}
```

For specific Apache Spark versions:
``` scala
@enableIf(classpathMatchesArtifact(crossScalaBinaryVersion("spark-catalyst"), "3.2.1"))
@enableIf(classpathMatchesArtifact(crossScalaBinaryVersion("spark-catalyst"), "3.1.2"))
```

> NOTICE: `classpathMatchesArtifact` is for classpath without classifiers. For classpath with classifiers like
> `ffmpeg-5.0-1.5.7-android-arm-gpl.jar`, Please use `classpathMactches` or `classpathContains`.

Hints to show the full classpath:
``` bash
sbt "show Compile / fullClasspath"

mill show foo.compileClasspath
```

50 changes: 41 additions & 9 deletions src/main/scala/com/thoughtworks/enableIf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,43 @@ package com.thoughtworks
import scala.annotation.StaticAnnotation
import scala.reflect.internal.annotations.compileTimeOnly
import scala.reflect.macros.Context
import scala.util.matching.Regex


object enableIf {
val classpathRegex = "(.*)/([^/]*)-([^/]*)\\.jar".r

def crossScalaBinaryVersion(artifactId: String): String = {
val scalaBinaryVersion = scala.util.Properties
.versionNumberString
.split("\\.").take(2)
.mkString(".")
s"${artifactId}_${scalaBinaryVersion}"
}

def crossScalaFullVersion(artifactId: String): String = {
val scalaFullVersion = scala.util.Properties.versionNumberString
s"${artifactId}_${scalaFullVersion}"
}

def classpathContains(classpathPart: String): Context => Boolean = {
c => c.classPath.exists(_.getPath.contains(classpathPart))
}

def classpathMatches(regex: Regex): Context => Boolean = {
c => c.classPath.exists { dep =>
regex.pattern.matcher(dep.getPath).matches()
}
}

def classpathMatchesArtifact(artifactId: String, version: String): Context => Boolean = {
c => c.classPath.exists { dep =>
classpathRegex.findAllMatchIn(dep.getPath).exists { m =>
artifactId.equals(m.group(2)) && version.equals(m.group(3))
}
}
}


def isEnabled(c: Context, booleanCondition: Boolean) = booleanCondition

Expand All @@ -14,15 +49,12 @@ object enableIf {
private[enableIf] object Macros {
def macroTransform(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
val Apply(Select(Apply(_, List(condition)), _), List(_ @_*)) =
c.macroApplication
if (
c.eval(c.Expr[Boolean](q"""
_root_.com.thoughtworks.enableIf.isEnabled(${reify(
c
).tree}, $condition)
"""))
) {
val Apply(Select(Apply(_, List(condition)), _), List(_@_*)) = c.macroApplication
if (c.eval(c.Expr[Boolean](
q"""
import _root_.com.thoughtworks.enableIf._
_root_.com.thoughtworks.enableIf.isEnabled(${reify(c).tree}, $condition)
"""))) {
c.Expr(q"..${annottees.map(_.tree)}")
} else {
c.Expr(EmptyTree)
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/thoughtworks/enableMembersIf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package com.thoughtworks
import scala.annotation.StaticAnnotation
import scala.reflect.internal.annotations.compileTimeOnly
import scala.reflect.macros.Context
import scala.util.matching.Regex


object enableMembersIf {

Expand All @@ -20,6 +22,7 @@ object enableMembersIf {
c.macroApplication
if (
c.eval(c.Expr[Boolean](q"""
import _root_.com.thoughtworks.enableIf._
_root_.com.thoughtworks.enableIf.isEnabled(${reify(
c
).tree}, $condition)
Expand Down
28 changes: 28 additions & 0 deletions src/test/scala/com/thoughtworks/EnableMembersIfTest.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.thoughtworks

import com.thoughtworks.enableIf.{classpathMatches, classpathMatchesArtifact, crossScalaBinaryVersion}
import org.scalatest._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -44,4 +45,31 @@ class EnableMembersIfTest extends AnyFreeSpec with Matchers {
assert(whichIsEnabled == "good")

}

"Test Artifact and " in {
@enableMembersIf(classpathMatchesArtifact(crossScalaBinaryVersion("quasiquotes"), "2.1.1"))
object ShouldEnable {
def whichIsEnabled = "good"
}

@enableMembersIf(classpathMatches(".*scala-library-2\\.1[123]\\..*".r))
object ShouldDisable1 {
def whichIsEnabled = "bad"
}

@enableMembersIf(classpathMatches(".*scala-2\\.1[123]\\..*".r))
object ShouldDisable2 {
def whichIsEnabled = "bad"
}

import ShouldEnable._
import ShouldDisable1._
import ShouldDisable2._

if (scala.util.Properties.versionNumberString < "2.11") {
assert(whichIsEnabled == "good")
} else {
assert(whichIsEnabled == "bad")
}
}
}
86 changes: 86 additions & 0 deletions src/test/scala/com/thoughtworks/EnableWithArtifactTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package com.thoughtworks

import org.scalatest._
import enableIf._

import scala.util.control.TailCalls._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers


/**
* @author 沈达 (Darcy Shen) &lt;[email protected]&gt;
*/
class EnableWithArtifactTest extends AnyFreeSpec with Matchers {
"test the constant regex of classpath" in {
assert {
"/path/to/scala-library-2.10.8.jar" match {
case classpathRegex(_, artifactId, version) =>
"scala-library".equals(artifactId) && "2.10.8".equals(version)
}
}
assert {
"/path/to/quasiquotes_2.10-2.1.1.jar" match {
case classpathRegex(_, artifactId, version) =>
"quasiquotes_2.10".equals(artifactId) && "2.1.1".equals(version)
}
}
}

"Test if we are using quasiquotes explicitly" in {

object ExplicitQ {

@enableIf(classpathMatchesArtifact(crossScalaBinaryVersion("quasiquotes"), "2.1.1"))
def whichIsEnabled = "good"
}
object ImplicitQ {
@enableIf(classpathMatches(".*scala-library-2\\.1[123]\\..*".r))
def whichIsEnabled = "bad"

@enableIf(classpathMatches(".*scala-2\\.1[123]\\..*".r))
def whichIsEnabled = "bad"
}


import ExplicitQ._
import ImplicitQ._
if (scala.util.Properties.versionNumberString < "2.11") {
assert(whichIsEnabled == "good")
} else {
assert(whichIsEnabled == "bad")
}
}

"Add TailRec.flatMap for Scala 2.10 " in {

@enableIf(classpathMatches(".*scala-library-2\\.10.*".r))
implicit class FlatMapForTailRec[A](underlying: TailRec[A]) {
final def flatMap[B](f: A => TailRec[B]): TailRec[B] = {
tailcall(f(underlying.result))
}
}

def ten = done(10)

def tenPlusOne = ten.flatMap(i => done(i + 1))

assert(tenPlusOne.result == 11)
}

"Add TailRec.flatMap for Scala 2.10 via classpathContains " in {

@enableIf(classpathContains("scala-library-2.10."))
implicit class FlatMapForTailRec[A](underlying: TailRec[A]) {
final def flatMap[B](f: A => TailRec[B]): TailRec[B] = {
tailcall(f(underlying.result))
}
}

def ten = done(10)

def tenPlusOne = ten.flatMap(i => done(i + 1))

assert(tenPlusOne.result == 11)
}
}
51 changes: 51 additions & 0 deletions src/test/scala/com/thoughtworks/EnableWithClasspathTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.thoughtworks

import org.scalatest._
import enableIf._

import scala.util.control.TailCalls._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers


/**
* @author 沈达 (Darcy Shen) &lt;[email protected]&gt;
*/
class EnableWithClasspathTest extends AnyFreeSpec with Matchers {

"enableWithClasspath by regex" in {

object ShouldEnable {

@enableIf(classpathMatches(".*scala.*".r))
def whichIsEnabled = "good"

}
object ShouldDisable {

@enableIf(classpathMatches(".*should_not_exist.*".r))
def whichIsEnabled = "bad"
}

import ShouldEnable._
import ShouldDisable._
assert(whichIsEnabled == "good")

}

"Add TailRec.flatMap for Scala 2.10 " in {

@enableIf(classpathMatches(".*scala-library-2.10.*".r))
implicit class FlatMapForTailRec[A](underlying: TailRec[A]) {
final def flatMap[B](f: A => TailRec[B]): TailRec[B] = {
tailcall(f(underlying.result))
}
}

def ten = done(10)

def tenPlusOne = ten.flatMap(i => done(i + 1))

assert(tenPlusOne.result == 11)
}
}

0 comments on commit 56fd8fb

Please sign in to comment.