Skip to content

Commit

Permalink
Merge pull request #393 from jeroentervoorde/scalashade
Browse files Browse the repository at this point in the history
Rewrite @ScalaSignature when shading
  • Loading branch information
eed3si9n authored May 23, 2020
2 parents 59a008f + 8879855 commit db69ab5
Show file tree
Hide file tree
Showing 23 changed files with 1,105 additions and 25 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ To see the verbose output for shading:
logLevel in assembly := Level.Debug
```

#### Scala libraries

Scala classes contain an annotation which, among other things, contain all symbols referenced in that class. As of sbt-assembly XXX the rename rules
will be applied to these annotations as well which makes it possible to compile or reflect against a shaded library.

This is currently limited to renaming packages. Renaming class names will not work and cause compiler errors when compiling against the shaded library.

Excluding JARs and files
------------------------

Expand Down
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ lazy val root = (project in file(".")).
scalacOptions := Seq("-deprecation", "-unchecked", "-Dscalac.patmat.analysisBudget=1024", "-Xfuture"),
libraryDependencies ++= Seq(
"org.scalactic" %% "scalactic" % "3.0.8",
"org.pantsbuild" % "jarjar" % "1.7.2"
"org.pantsbuild" % "jarjar" % "1.7.2",
"org.scalatest" %% "scalatest" % "3.1.1" % Test,
),
crossSbtVersions := Seq("0.13.18", "1.2.8"), // https://github.com/sbt/sbt/issues/5049
publishArtifact in (Compile, packageBin) := true,
Expand Down
100 changes: 84 additions & 16 deletions src/main/scala/org/pantsbuild/jarjar/JJProcessor.scala
Original file line number Diff line number Diff line change
@@ -1,31 +1,99 @@
package org.pantsbuild.jarjar

import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor}
import java.io.IOException

import org.pantsbuild.jarjar.misplaced.MisplacedClassProcessorFactory
import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor, JarProcessorChain, JarTransformerChain, RemappingClassTransformer, StandaloneJarProcessor}

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
* Creates a new JJProcessor, which automatically generates the standard zap, keep, remap, etc processors.
* This is a copy of the MainProcessor in JarJar with an added ScalaSigProcessor
*
* @param patterns List of rules to parse.
* @param verbose Whether to verbosely log information.
* @param skipManifest If true, omits the manifest file from the processed jar.
* @param misplacedClassStrategy The strategy to use when processing class files that are in the
* wrong package (see MisplacedClassProcessorFactory.STRATEGY_* constants).
*/
class JJProcessor(val patterns: Seq[PatternElement], val verbose: Boolean, val skipManifest: Boolean, val misplacedClassStrategy: String) extends JarProcessor {

val zapList: Seq[Zap] = patterns.collect { case zap: Zap => zap }
val ruleList: Seq[Rule] = patterns.collect { case rule: Rule => rule }
val keepList: Seq[Keep] = patterns.collect { case keep: Keep => keep }
val renames: mutable.Map[String, String] = collection.mutable.HashMap[String, String]()

val kp: KeepProcessor = if (keepList.isEmpty) null else new KeepProcessor(keepList.asJava)

val pr = new PackageRemapper(ruleList.asJava, verbose)

class JJProcessor(val proc: JarProcessor) {
val processors: mutable.ArrayBuffer[JarProcessor] = collection.mutable.ArrayBuffer[JarProcessor]()
if (skipManifest)
processors += ManifestProcessor.getInstance
if (kp != null)
processors += kp

def process(entry: EntryStruct): Boolean = proc.process(entry)
val misplacedClassProcessor: JarProcessor = MisplacedClassProcessorFactory.getInstance.getProcessorForName(misplacedClassStrategy)
processors += new ZapProcessor(zapList.asJava)
processors += misplacedClassProcessor
processors += new JarTransformerChain(Array[RemappingClassTransformer](new RemappingClassTransformer(pr)))

def getExcludes(): Set[String] = {
val field = proc.getClass().getDeclaredField("kp")
field.setAccessible(true)
val keepProcessor = field.get(proc)
val renamer: String => Option[String] = {
val wildcards = PatternElement.createWildcards(ruleList.asJava).asScala

if (keepProcessor == null) Set()
else {
val method = proc.getClass().getDeclaredMethod("getExcludes")
method.setAccessible(true)
method.invoke(proc).asInstanceOf[java.util.Set[String]].asScala.toSet
value: String => {
val result = wildcards.flatMap {
wc =>
val slashed = value.replace('.', '/') // The jarjar wildcards expect slashes instead of dots
// Hack to replace the package object name.
val renamed = Option(wc.replace(slashed)).orElse(Option(wc.replace(slashed + "/")).map(_.dropRight(1)))
renamed.map(_.replace('/', '.')) // Unslash
}.headOption

result
}
}

}
processors += new ScalaSigProcessor(renamer)
processors += new MethodSignatureProcessor(pr)
processors += new ResourceProcessor(pr)
val chain = new JarProcessorChain(processors.toArray)

object JJProcessor {
@throws[IOException]
def strip(file: Nothing): Unit = {
if (kp != null) {
val excludes = getExcludes
if (excludes.nonEmpty) StandaloneJarProcessor.run(file, file, new ExcludeProcessor(excludes.asJava, verbose))
}
}

def apply(patterns: Seq[PatternElement], verbose: Boolean, skipManifest: Boolean): JJProcessor =
new JJProcessor(new MainProcessor(patterns.asJava, verbose, skipManifest))
/**
* Returns the <code>.class</code> files to delete. As well the root-parameter as the rename ones
* are taken in consideration, so that the concerned files are not listed in the result.
*
* @return the paths of the files in the jar-archive, including the <code>.class</code> suffix
*/
def getExcludes: Set[String] = if (kp != null) kp.getExcludes.asScala.map { exclude =>
val name = exclude + ".class"
renames.getOrElse(name, name)
}.toSet else Set.empty

/**
*
* @param struct entry struct to process
* @return <code>true</code> if the entry is to include in the output jar
* @throws IOException
*/
@throws[IOException]
def process(struct: EntryStruct): Boolean = {
val name = struct.name
val keepIt = chain.process(struct)
if (keepIt) if (!name.equals(struct.name)) {
if (kp != null) renames.put(name, struct.name)
if (verbose) System.err.println("Renamed " + name + " -> " + struct.name)
} else if (verbose) System.err.println("Removed " + name)
keepIt
}
}
20 changes: 20 additions & 0 deletions src/main/scala/org/pantsbuild/jarjar/ScalaSigProcessor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.pantsbuild.jarjar

import org.objectweb.asm.{ClassReader, ClassWriter}
import org.pantsbuild.jarjar.util.{EntryStruct, JarProcessor}
import sbtassembly.scalasig.ScalaSigClassVisitor

class ScalaSigProcessor(renamer: String => Option[String]) extends JarProcessor {
override def process(struct: EntryStruct): Boolean = {

if (!struct.name.endsWith(".class") || struct.skipTransform) true
else {
val classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS)
val reader = new ClassReader(struct.data)

reader.accept(new ScalaSigClassVisitor(classWriter, renamer), ClassReader.EXPAND_FRAMES)
struct.data = classWriter.toByteArray
true
}
}
}
9 changes: 4 additions & 5 deletions src/main/scala/sbtassembly/Shader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package sbtassembly

import java.io.File

import org.pantsbuild.jarjar._
import org.pantsbuild.jarjar.{JJProcessor, _}
import org.pantsbuild.jarjar.util.EntryStruct

import sbt._

case class ShadeRule(shadePattern: ShadePattern, targets: Seq[ShadeTarget] = Seq()) {
Expand Down Expand Up @@ -83,7 +82,7 @@ private[sbtassembly] object Shader {
case _ => Nil
}}

val proc = JJProcessor(jjrules, verbose = level == Level.Debug, true)
val proc = new JJProcessor(jjrules, verbose = level == Level.Debug, true, null)

/*
jarjar MisplacedClassProcessor class transforms byte[] to a class using org.objectweb.asm.ClassReader.getClassName
Expand All @@ -104,7 +103,7 @@ private[sbtassembly] object Shader {
IO.write(dir / entry.name, entry.data)
}
}
val excludes = proc.getExcludes()
val excludes = proc.getExcludes
excludes.foreach(exclude => IO.delete(dir / exclude))
}
}
}
22 changes: 22 additions & 0 deletions src/main/scala/sbtassembly/scalasig/ByteArrayReader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package sbtassembly.scalasig

// Utility class to read the content of a single table entry
class ByteArrayReader(bytes: Array[Byte]) extends Nat.Reader {
private var readIndex = 0

/** Read a byte */
override def readByte(): Int = {
val x = bytes(readIndex).toInt
readIndex += 1
x
}

/** Reads a number of bytes into an array */
def readBytes(len: Int): Array[Byte] = {
val result = bytes.slice(readIndex, readIndex + len)
readIndex += len
result
}

def atEnd: Boolean = readIndex == bytes.length
}
150 changes: 150 additions & 0 deletions src/main/scala/sbtassembly/scalasig/EntryTable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package sbtassembly.scalasig

import java.io.ByteArrayOutputStream

import scala.collection.mutable
import scala.reflect.internal.pickling.PickleFormat

/**
* Mutable table of tagged entries
* @param majorVersion major table version
* @param minorVersion minor table version
* @param entries initial table entries
*/
class EntryTable(majorVersion: Int, minorVersion: Int, entries: mutable.Buffer[TaggedEntry]) {
// Mapping of known TermName or TypeNames to their index in the table.
private val nameIndices: mutable.Map[NameEntry, Int] = mutable.HashMap(
entries.zipWithIndex.collect {
case (entry: NameEntry, index) => (entry, index)
}:_*
)

/**
* Return the current table entries as an immutable seq.
* @return table entries
*/
def toSeq: Seq[TaggedEntry] = entries.toVector

/**
* Rename term and type entries in this table according to the renamer function.
* A name or type is referred to by a Ref entry. The existing ref entries are reused to references to them will remain intact.
* Unused entries will not be removed from the table.
*
* @param renamer renames a fully qualified type or term name or return None if it does not match.
*/
def renameEntries(renamer: String => Option[String]): Unit = {

entries.zipWithIndex.collect {
case (ref: RefEntry, index) =>
entries(ref.nameRef) match {
case nameEntry: NameEntry =>
for {
fqName <- resolveRef(ref)
renamed <- renamer(fqName)
} {
val parts = renamed.split('.')

val myOwner = parts.init.foldLeft(Option.empty[Int]) { (owner, part) =>
val nameIndex = getOrAppendNameEntry(NameEntry(PickleFormat.TERMname, part))
val nextOwner = appendEntry(RefEntry(PickleFormat.EXTMODCLASSref, nameIndex, owner))
Some(nextOwner)
}

entries(index) = ref.copy(nameRef = getOrAppendNameEntry(nameEntry.copy(name = parts.last)), ownerRef = myOwner)
}

case other =>
throw new RuntimeException(s"Ref entry does not point to a name but to a ${other.tag}")
}
}
}

// Return existing name entry or append a new one.
private def getOrAppendNameEntry(name: NameEntry): Int = {
nameIndices.getOrElse(name, appendEntry(name))
}

private def appendEntry(entry: TaggedEntry): Int = {
val index = entries.size
entries += entry

entry match {
case name: NameEntry =>
nameIndices.put(name, index)
case _ => // NoOp
}

index
}

// Resolves a ref into a fully qualified name
def resolveRef(extMod: RefEntry): Option[String] = {

val myName = entries(extMod.nameRef) match {
case term: NameEntry => term.name
case raw: RawEntry => throw new RuntimeException(s"Unexpected raw type for nameref ${raw.tag}")
case other => throw new RuntimeException(s"Unexpected type for nameref $other")
}
extMod.ownerRef match {
case None => Some(myName)
case Some(owner) =>
entries(owner) match {
case name: NameEntry =>
Some(s"$name/$myName")
case mod: RefEntry =>
resolveRef(mod).map(p => s"$p.$myName")
case raw: RawEntry if raw.tag == PickleFormat.NONEsym =>
None
case raw: RawEntry =>
throw new RuntimeException(s"Not a known owner type tag for $myName : ${raw.tag}")
}
}
}

/**
* Serializes this entry table into a byte array.
*/
def toBytes: Array[Byte] = {
val os = new ByteArrayOutputStream()
val writer = new Nat.Writer {
override def writeByte(b: Int): Unit = os.write(b)
}

writer.writeNat(majorVersion)
writer.writeNat(minorVersion)
writer.writeNat(entries.size)

entries.foreach { entry =>
val payloadBytes = entry.toBytes
writer.writeNat(entry.tag) // Tag of entry
writer.writeNat(payloadBytes.length) // Size of payload
os.write(payloadBytes)
}

os.toByteArray
}
}

object EntryTable {

/**
* Parse bytes into a EntryTable
*/
def fromBytes(bytes: Array[Byte]): EntryTable = {
val reader = new ByteArrayReader(bytes)

val majorVersion = reader.readNat()
val minorVersion = reader.readNat()

val result = new Array[TaggedEntry](reader.readNat())

result.indices foreach { index =>
val tag = reader.readNat()
val len = reader.readNat()

result(index) = TaggedEntry(tag, reader.readBytes(len))
}

new EntryTable(majorVersion, minorVersion, result.toBuffer)
}
}
Loading

0 comments on commit db69ab5

Please sign in to comment.