[ADAM-1576] Allow translation between two different GenomicRDD types. #1598

Merged

763 changes: 756 additions & 7 deletions adam-core/src/main/scala/org/bdgenomics/adam/rdd/ADAMContext.scala

Large diffs are not rendered by default.
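The bulk of this diff, which is not rendered, adds the pairwise implicit conversion functions that transmute and transmuteDataset (below) resolve through their implicit convFn parameter: one function per source/target GenomicRDD pair. As a rough sketch of the shape involved, rather than the actual code from this diff, a contigs-to-coverage conversion could look like the following, assuming a CoverageRDD can be constructed directly from an RDD[Coverage] plus a sequence dictionary:

implicit def contigsToCoverageConversionFn(
  gRdd: NucleotideContigFragmentRDD,
  rdd: RDD[Coverage]): CoverageRDD = {
  // Rebuild the target type around the new RDD, copying the
  // sequence dictionary over from the source unchanged.
  CoverageRDD(rdd, gRdd.sequences)
}

The conversion receives both the source GenomicRDD (for its metadata) and the freshly transformed Spark RDD, which is exactly the (U, RDD[X]) => Y shape that transmute's implicit parameter requires.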

35 changes: 31 additions & 4 deletions adam-core/src/main/scala/org/bdgenomics/adam/rdd/GenomicRDD.scala
@@ -187,6 +187,19 @@ trait GenomicRDD[T, U <: GenomicRDD[T, U]] extends Logging {
replaceRdd(tFn(rdd))
}

/**
* Applies a function that transmutes the underlying RDD into a new RDD of a
* different type.
*
* @param tFn A function that transforms the underlying RDD.
* @return A new genomic RDD where the RDD of genomic data has been replaced, but the
* metadata (sequence dictionary, etc.) is copied without modification.
*/
def transmute[X, Y <: GenomicRDD[X, Y]](tFn: RDD[T] => RDD[X])(
implicit convFn: (U, RDD[X]) => Y): Y = {
convFn(this.asInstanceOf[U], tFn(rdd))
}

// The partition map is structured as follows:
// The outer option is for whether or not there is a partition map.
// - This is None in the case that we don't know the bounds on each
@@ -918,12 +931,12 @@ trait GenomicRDD[T, U <: GenomicRDD[T, U]] extends Logging {
* @return Returns a new genomic RDD containing all pairs of keys that
* overlapped in the genomic coordinate space.
*/
- def shuffleRegionJoin[X, Y <: GenomicRDD[X, Y], Z <: GenomicRDD[(T, X), Z]](
+ def shuffleRegionJoin[X, Y <: GenomicRDD[X, Y]](
genomicRdd: GenomicRDD[X, Y],
optPartitions: Option[Int] = None)(
implicit tTag: ClassTag[T],
xTag: ClassTag[X],
- txTag: ClassTag[(T, X)]): GenomicRDD[(T, X), Z] = InnerShuffleJoin.time {
+ txTag: ClassTag[(T, X)]): GenericGenomicRDD[(T, X)] = InnerShuffleJoin.time {

val (leftRddToJoin, rightRddToJoin) =
prepareForShuffleRegionJoin(genomicRdd, optPartitions)
@@ -937,7 +950,7 @@ trait GenomicRDD[T, U <: GenomicRDD[T, U]] extends Logging {
combinedSequences,
kv => {
getReferenceRegions(kv._1) ++ genomicRdd.getReferenceRegions(kv._2)
- }).asInstanceOf[GenomicRDD[(T, X), Z]]
+ })
}
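// Illustrative usage, not part of this diff: with the concrete
// GenericGenomicRDD[(T, X)] return type, callers no longer need to bind a
// Z type parameter up front. For hypothetical reads: AlignmentRecordRDD
// and features: FeatureRDD:
//   val pairs: GenericGenomicRDD[(AlignmentRecord, Feature)] =
//     reads.shuffleRegionJoin(features)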

/**
@@ -1281,7 +1294,7 @@ trait GenomicRDD[T, U <: GenomicRDD[T, U]] extends Logging {
}
}

- private case class GenericGenomicRDD[T](
+ case class GenericGenomicRDD[T] private[rdd] (
rdd: RDD[T],
sequences: SequenceDictionary,
regionFn: T => Seq[ReferenceRegion],
@@ -1383,6 +1396,20 @@ trait GenomicDataset[T, U <: Product, V <: GenomicDataset[T, U, V]] extends Geno
* metadata (sequence dictionary, etc.) is copied without modification.
*/
def transformDataset(tFn: Dataset[U] => Dataset[U]): V

/**
* Applies a function that transmutes the underlying Dataset into a new Dataset
* of a different type.
*
* @param tFn A function that transforms the underlying Dataset.
* @return A new genomic dataset where the Dataset of genomic data has been
* replaced, but the metadata (sequence dictionary, etc.) is copied without
* modification.
*/
def transmuteDataset[X <: Product, Y <: GenomicDataset[_, X, Y]](
tFn: Dataset[U] => Dataset[X])(
implicit convFn: (V, Dataset[X]) => Y): Y = {
convFn(this.asInstanceOf[V], tFn(dataset))
}
}

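The call pattern these methods enable is exercised at length in the test suite below: map the records inside transmute, and let an implicit conversion imported from ADAMContext rebuild the target GenomicRDD type around the result. A condensed sketch, assuming contigs is a NucleotideContigFragmentRDD and featFn maps a NucleotideContigFragment to a Feature (both are defined in the tests that follow):

import org.apache.spark.sql.SQLContext
import org.bdgenomics.adam.rdd.ADAMContext._ // implicit (U, RDD[X]) => Y conversion functions

// RDD path: the implicit conversion turns the mapped RDD[Feature]
// back into a FeatureRDD, carrying the sequence dictionary along.
val features: FeatureRDD = contigs.transmute(rdd => {
  rdd.map(featFn)
})

// Dataset path: the same idea over the Product (Spark SQL) representation.
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val featuresDs: FeatureRDD = contigs.transmuteDataset(ds => {
  ds.map(r => FeatureProduct.fromAvro(featFn(r.toAvro)))
})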
adam-core/src/test/scala/org/bdgenomics/adam/rdd/contig/NucleotideContigFragmentRDDSuite.scala
@@ -20,12 +20,85 @@ package org.bdgenomics.adam.rdd.contig
import java.io.File

import com.google.common.io.Files
import org.apache.spark.sql.SQLContext
import org.bdgenomics.adam.models._
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.adam.rdd.feature.{ CoverageRDD, FeatureRDD }
import org.bdgenomics.adam.rdd.fragment.FragmentRDD
import org.bdgenomics.adam.rdd.read.AlignmentRecordRDD
import org.bdgenomics.adam.rdd.variant.{
GenotypeRDD,
VariantRDD,
VariantContextRDD
}
import org.bdgenomics.adam.sql.{
AlignmentRecord => AlignmentRecordProduct,
Feature => FeatureProduct,
Fragment => FragmentProduct,
Genotype => GenotypeProduct,
NucleotideContigFragment => NucleotideContigFragmentProduct,
Variant => VariantProduct
}
import org.bdgenomics.adam.util.ADAMFunSuite
import org.bdgenomics.formats.avro._
import scala.collection.mutable.ListBuffer

object NucleotideContigFragmentRDDSuite extends Serializable {

def covFn(ncf: NucleotideContigFragment): Coverage = {
Coverage(ncf.getContigName,
ncf.getStart,
ncf.getEnd,
1)
}

def featFn(ncf: NucleotideContigFragment): Feature = {
Feature.newBuilder
.setContigName(ncf.getContigName)
.setStart(ncf.getStart)
.setEnd(ncf.getEnd)
.build
}

def fragFn(ncf: NucleotideContigFragment): Fragment = {
Fragment.newBuilder
.setReadName(ncf.getContigName)
.build
}

def genFn(ncf: NucleotideContigFragment): Genotype = {
Genotype.newBuilder
.setContigName(ncf.getContigName)
.setStart(ncf.getStart)
.setEnd(ncf.getEnd)
.build
}

def readFn(ncf: NucleotideContigFragment): AlignmentRecord = {
AlignmentRecord.newBuilder
.setContigName(ncf.getContigName)
.setStart(ncf.getStart)
.setEnd(ncf.getEnd)
.build
}

def varFn(ncf: NucleotideContigFragment): Variant = {
Variant.newBuilder
.setContigName(ncf.getContigName)
.setStart(ncf.getStart)
.setEnd(ncf.getEnd)
.build
}

def vcFn(ncf: NucleotideContigFragment): VariantContext = {
VariantContext(Variant.newBuilder
.setContigName(ncf.getContigName)
.setStart(ncf.getStart)
.setEnd(ncf.getEnd)
.build)
}
}

class NucleotideContigFragmentRDDSuite extends ADAMFunSuite {

sparkTest("union two ncf rdds together") {
@@ -553,4 +626,189 @@ class NucleotideContigFragmentRDDSuite extends ADAMFunSuite {
optPredicate = Some(ReferenceRegion("HLA-DQB1*05:01:01:02", 500L, 1500L).toPredicate))
assert(fragments3.rdd.count === 2)
}

sparkTest("transform contigs to coverage rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(coverage: CoverageRDD) {
val tempPath = tmpLocation(".bed")
coverage.save(tempPath, false, false)

assert(sc.loadCoverage(tempPath).rdd.count === 8)
}

val coverage: CoverageRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.covFn)
})

checkSave(coverage)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val coverageDs: CoverageRDD = contigs.transmuteDataset(ds => {
ds.map(r => NucleotideContigFragmentRDDSuite.covFn(r.toAvro))
})

checkSave(coverageDs)
}

sparkTest("transform contigs to feature rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(features: FeatureRDD) {
val tempPath = tmpLocation(".bed")
features.saveAsBed(tempPath)

assert(sc.loadFeatures(tempPath).rdd.count === 8)
}

val features: FeatureRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.featFn)
})

checkSave(features)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val featuresDs: FeatureRDD = contigs.transmuteDataset(ds => {
ds.map(r => {
FeatureProduct.fromAvro(
NucleotideContigFragmentRDDSuite.featFn(r.toAvro))
})
})

checkSave(featuresDs)
}

sparkTest("transform contigs to fragment rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(fragments: FragmentRDD) {
val tempPath = tmpLocation(".adam")
fragments.saveAsParquet(tempPath)

assert(sc.loadFragments(tempPath).rdd.count === 8)
}

val fragments: FragmentRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.fragFn)
})

checkSave(fragments)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val fragmentsDs: FragmentRDD = contigs.transmuteDataset(ds => {
ds.map(r => {
FragmentProduct.fromAvro(
NucleotideContigFragmentRDDSuite.fragFn(r.toAvro))
})
})

checkSave(fragmentsDs)
}

sparkTest("transform contigs to read rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(reads: AlignmentRecordRDD) {
val tempPath = tmpLocation(".adam")
reads.saveAsParquet(tempPath)

assert(sc.loadAlignments(tempPath).rdd.count === 8)
}

val reads: AlignmentRecordRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.readFn)
})

checkSave(reads)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val readsDs: AlignmentRecordRDD = contigs.transmuteDataset(ds => {
ds.map(r => {
AlignmentRecordProduct.fromAvro(
NucleotideContigFragmentRDDSuite.readFn(r.toAvro))
})
})

checkSave(readsDs)
}

sparkTest("transform contigs to genotype rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(genotypes: GenotypeRDD) {
val tempPath = tmpLocation(".adam")
genotypes.saveAsParquet(tempPath)

assert(sc.loadGenotypes(tempPath).rdd.count === 8)
}

val genotypes: GenotypeRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.genFn)
})

checkSave(genotypes)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val genotypesDs: GenotypeRDD = contigs.transmuteDataset(ds => {
ds.map(r => {
GenotypeProduct.fromAvro(
NucleotideContigFragmentRDDSuite.genFn(r.toAvro))
})
})

checkSave(genotypesDs)
}

sparkTest("transform contigs to variant rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(variants: VariantRDD) {
val tempPath = tmpLocation(".adam")
variants.saveAsParquet(tempPath)

assert(sc.loadVariants(tempPath).rdd.count === 8)
}

val variants: VariantRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.varFn)
})

checkSave(variants)

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val variantsDs: VariantRDD = contigs.transmuteDataset(ds => {
ds.map(r => {
VariantProduct.fromAvro(
NucleotideContigFragmentRDDSuite.varFn(r.toAvro))
})
})

checkSave(variantsDs)
}

sparkTest("transform contigs to variant context rdd") {
val contigs = sc.loadFasta(testFile("HLA_DQB1_05_01_01_02.fa"), 1000L)

def checkSave(variantContexts: VariantContextRDD) {
assert(variantContexts.rdd.count === 8)
}

val variantContexts: VariantContextRDD = contigs.transmute(rdd => {
rdd.map(NucleotideContigFragmentRDDSuite.vcFn)
})

checkSave(variantContexts)
}
}