Skip to content

Commit

Permalink
Increase jitter for AS_MQ to solve production pipeline convergence
Browse files Browse the repository at this point in the history
problems for exomes with AS annotations
Add VQSR debug arg
  • Loading branch information
ldgauthier committed Nov 18, 2019
1 parent a632a05 commit e6de27a
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,16 @@ public UnimplementedFeature(String message){
super(message);
}
}

public static final class VQSRPositiveModelFailure extends UserException {
private static final long serialVersionUID = 0L;

public VQSRPositiveModelFailure(String message) { super(message); }
}

public static final class VQSRNegativeModelFailure extends UserException {
private static final long serialVersionUID = 0L;

public VQSRNegativeModelFailure(String message) { super(message); }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.commons.math3.special.Gamma;

import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.MathUtils;

Expand Down Expand Up @@ -89,8 +90,7 @@ public void divideEqualsMu( final double x ) {
private void precomputeInverse() {
try {
cachedSigmaInverse = sigma.inverse();
} catch( Exception e ) {
//TODO: there must be something narrower than Exception to catch here
} catch( RuntimeException e ) {
throw new UserException(
"Error during clustering. Most likely there are too few variants used during Gaussian mixture " +
"modeling. Please consider raising the number of variants used to train the negative "+
Expand All @@ -104,6 +104,9 @@ private void precomputeInverse() {
public void precomputeDenominatorForEvaluation() {
precomputeInverse();
cachedDenomLog10 = Math.log10(Math.pow(2.0 * Math.PI, -1.0 * ((double) mu.length) / 2.0)) + Math.log10(Math.pow(sigma.det(), -0.5)) ;
if (Double.isNaN(cachedDenomLog10)) {
throw new GATKException("Denominator for gaussian evaluation cannot be computed. One or more annotations (usually MQ) may have insufficient variance.");
}
}

public void precomputeDenominatorForVariationalBayes( final double sumHyperParameterLambda ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.vcf.VCFConstants;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.AnnotationUtils;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

Expand All @@ -26,6 +27,7 @@ public class VariantDataManager {
private double[] varianceVector; // this is really the standard deviation
public List<String> annotationKeys;
private final VariantRecalibratorArgumentCollection VRAC;
private final boolean debugStdevThresholding;
protected final static Logger logger = LogManager.getLogger(VariantDataManager.class);
protected final List<TrainingSet> trainingSets;
private static final double SAFETY_OFFSET = 0.01; //To use for example as 1/(X + SAFETY_OFFSET) to protect against dividing or taking log of X=0.
Expand All @@ -38,6 +40,7 @@ public VariantDataManager( final List<String> annotationKeys, final VariantRecal
meanVector = new double[this.annotationKeys.size()];
varianceVector = new double[this.annotationKeys.size()];
trainingSets = new ArrayList<>();
debugStdevThresholding = VRAC.debugStdevThresholding;
}

public void setData( final List<VariantDatum> data ) {
Expand Down Expand Up @@ -224,6 +227,8 @@ public List<VariantDatum> getTrainingData() {
for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold ) {
trainingData.add( datum );
} else if (datum.failingSTDThreshold && debugStdevThresholding) {
logger.warn("Datum at " + datum.loc + " with ref " + datum.referenceAllele + " and alt " + datum.alternateAllele + " failing std thresholding: " + Arrays.toString(datum.annotations));
}
}
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
Expand Down Expand Up @@ -352,7 +357,7 @@ private static double decodeAnnotation( final String annotationKey, final Varian
if( jitter && (annotationKey.equalsIgnoreCase(GATKVCFConstants.FISHER_STRAND_KEY) || annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_FILTER_STATUS_KEY)) && MathUtils.compareDoubles(value, 0.0, PRECISION) == 0 ) { value += 0.01 * Utils.getRandomGenerator().nextGaussian(); }
if( jitter && annotationKey.equalsIgnoreCase(GATKVCFConstants.INBREEDING_COEFFICIENT_KEY) && MathUtils.compareDoubles(value, 0.0, PRECISION) == 0 ) { value += 0.01 * Utils.getRandomGenerator().nextGaussian(); }
if( jitter && (annotationKey.equalsIgnoreCase(GATKVCFConstants.STRAND_ODDS_RATIO_KEY) || annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_STRAND_ODDS_RATIO_KEY)) && MathUtils.compareDoubles(value, LOG_OF_TWO, PRECISION) == 0 ) { value += 0.01 * Utils.getRandomGenerator().nextGaussian(); } //min SOR is 2.0, then we take ln
if( jitter && (annotationKey.equalsIgnoreCase(VCFConstants.RMS_MAPPING_QUALITY_KEY) || annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_RMS_MAPPING_QUALITY_KEY))) {
if( jitter && (annotationKey.equalsIgnoreCase(VCFConstants.RMS_MAPPING_QUALITY_KEY))) {
if( vrac.MQ_CAP > 0) {
value = logitTransform(value, -SAFETY_OFFSET, vrac.MQ_CAP + SAFETY_OFFSET);
if (MathUtils.compareDoubles(value, logitTransform(vrac.MQ_CAP, -SAFETY_OFFSET, vrac.MQ_CAP + SAFETY_OFFSET), PRECISION) == 0 ) {
Expand All @@ -362,9 +367,11 @@ private static double decodeAnnotation( final String annotationKey, final Varian
value += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
}
}
} catch( Exception e ) {
//TODO: what exception is this handling ? it seems overly broad
value = Double.NaN; // The VQSR works with missing data by marginalizing over the missing dimension when evaluating the Gaussian mixture model
if( jitter && (annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_RMS_MAPPING_QUALITY_KEY))){
value += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
}
} catch( NumberFormatException e ) {
value = Double.NaN; // VQSR works with missing data by marginalizing over the missing dimension when evaluating the Gaussian mixture model
}

return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,14 +650,20 @@ public Object onTraversalSuccess() {
// Generate the positive model using the training data and evaluate each variant
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
engine.evaluateData(dataManager.getData(), goodModel, false);
if (goodModel.failedToConverge) {
throw new UserException.VQSRPositiveModelFailure("Positive training model failed to converge. One or more annotations " +
"(usually MQ) may have insufficient variance. Please consider lowering the maximum number" +
" of Gaussians allowed for use in the model (via --max-gaussians 4, for example).");
}
// Generate the negative model using the worst performing data and evaluate each variant contrastively
negativeTrainingData = dataManager.selectWorstVariants();
badModel = engine.generateModel(negativeTrainingData,
Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));

if (badModel.failedToConverge || goodModel.failedToConverge) {
throw new UserException(
"NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minimum-bad-variants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --max-gaussians 4, for example)."));
if (badModel.failedToConverge) {
throw new UserException.VQSRNegativeModelFailure(
"NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe." +
" Please consider raising the number of variants used to train the negative model " +
"(via --minimum-bad-variants 5000, for example).");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,9 @@ public enum Mode {
@Argument(fullName="mq-jitter", doc="Amount of jitter (as a factor to a Normal(0,1) noise) to add to the MQ capped values", optional = true)
public double MQ_JITTER = 0.05;

@Hidden
@Advanced
@Argument(fullName = "debug-stdev-thresholding", doc="Output variants that fail standard deviation thresholding to the log for debugging purposes.", optional = true)
public boolean debugStdevThresholding = false;

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.Utils;

import java.util.List;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void evaluateData( final List<VariantDatum> data, final GaussianMixtureMo
if( !model.isModelReadyForEvaluation ) {
try {
model.precomputeDenominatorForEvaluation();
} catch( Exception e ) {
} catch( GATKException e ) {
logger.warn("Model could not pre-compute denominators."); //this happened when we were reading in VQSR models that didn't have enough precision
model.failedToConverge = true;
return;
Expand All @@ -63,7 +64,6 @@ public void evaluateData( final List<VariantDatum> data, final GaussianMixtureMo
for( final VariantDatum datum : data ) {
final double thisLod = evaluateDatum( datum, model );
if( Double.isNaN(thisLod) ) {
logger.warn("Evaluate datum returned a NaN.");
model.failedToConverge = true;
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,38 @@
*/
public class VariantRecalibratorIntegrationTest extends CommandLineProgramTest {

private final String[] VQSRParamsWithResources =
new String[] {
"--variant",
getLargeVQSRTestDataDir() + "phase1.projectConsensus.chr20.1M-10M.raw.snps.vcf",
"-L","20:1,000,000-10,000,000",
"--resource:known,known=true,prior=10.0",
getLargeVQSRTestDataDir() + "dbsnp_132_b37.leftAligned.20.1M-10M.vcf",
"--resource:truth_training1,truth=true,training=true,prior=15.0",
getLargeVQSRTestDataDir() + "sites_r27_nr.b37_fwd.20.1M-10M.vcf",
"--resource:truth_training2,training=true,truth=true,prior=12.0",
getLargeVQSRTestDataDir() + "Omni25_sites_1525_samples.b37.20.1M-10M.vcf",
"-an", "QD", "-an", "HaplotypeScore", "-an", "HRun",
"--trust-all-polymorphic", // for speed
"-mode", "SNP",
"--" + StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false"
};

private final String[] alleleSpecificVQSRParams =
new String[] {
"--variant",
getLargeVQSRTestDataDir() + "chr1snippet.doctoredMQ.sites_only.vcf.gz",
"-L","chr1:1-10,000,000",
"-resource:same,known=false,training=true,truth=true,prior=15",
getLargeVQSRTestDataDir() + "chr1snippet.doctoredMQ.sites_only.vcf.gz",
"-an", "AS_QD", "-an", "AS_ReadPosRankSum", "-an", "AS_MQ", "-an", "AS_SOR", //AS_MQRankSum has zero variance and AS_FS is nearly constant; also different annotation orders may not converge
"--trust-all-polymorphic", // for speed
"--use-allele-specific-annotations",
"-mode", "SNP",
"--" + StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false",
"--max-gaussians", "6"
};

@Override
public String getToolTestDataDir(){
return toolsTestDir + "walkers/VQSR/";
Expand All @@ -49,26 +81,46 @@ public void initializeVariantRecalTests() {
public Object[][] getVarRecalSNPData() {
return new Object[][] {
{
new String[] {
"--variant",
getLargeVQSRTestDataDir() + "phase1.projectConsensus.chr20.1M-10M.raw.snps.vcf",
"-L","20:1,000,000-10,000,000",
"--resource:known,known=true,prior=10.0",
getLargeVQSRTestDataDir() + "dbsnp_132_b37.leftAligned.20.1M-10M.vcf",
"--resource:truth_training1,truth=true,training=true,prior=15.0",
getLargeVQSRTestDataDir() + "sites_r27_nr.b37_fwd.20.1M-10M.vcf",
"--resource:truth_training2,training=true,truth=true,prior=12.0",
getLargeVQSRTestDataDir() + "Omni25_sites_1525_samples.b37.20.1M-10M.vcf",
"-an", "QD", "-an", "HaplotypeScore", "-an", "HRun",
"--trust-all-polymorphic", // for speed
"-mode", "SNP",
"--" + StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false"
}
VQSRParamsWithResources,
getLargeVQSRTestDataDir() + "expected/SNPDefaultTranches.txt",
getLargeVQSRTestDataDir() + "snpRecal.vcf"
},
{
alleleSpecificVQSRParams,
getToolTestDataDir() + "expected.AS.tranches",
getLargeVQSRTestDataDir() + "expected/expected.AS.recal.vcf"
}

};
}

private void doSNPTest(final String[] params, final String expectedTranchesFile) throws IOException {
@DataProvider(name="VarRecalSNPAlternateTranches")
public Object[][] getVarRecalSNPAlternateTranchesData() {
return new Object[][] {
{
VQSRParamsWithResources,
getLargeVQSRTestDataDir() + "expected/SNPAlternateTranches.txt",
getLargeVQSRTestDataDir() + "snpRecal.vcf"
},
{
alleleSpecificVQSRParams,
getToolTestDataDir() + "expected.AS.alternate.tranches",
getLargeVQSRTestDataDir() + "expected/expected.AS.recal.vcf"
}

};
}

@DataProvider(name="SNPRecalCommand")
public Object[][] getSNPRecalCommand() {
return new Object[][] {
{
VQSRParamsWithResources
}
};
}

private void doSNPTest(final String[] params, final String expectedTranchesFile, final String expectedRecalFile) throws IOException {
//NOTE: The number of iterations required to ensure we have enough negative training data to proceed,
//as well as the test results themselves, are both very sensitive to the state of the random number
//generator at the time the tool starts to execute. Sampling a single integer from the RNG at the
Expand Down Expand Up @@ -98,17 +150,17 @@ private void doSNPTest(final String[] params, final String expectedTranchesFile)

// the expected vcf is not in the expected dir because its used
// as input for the ApplyVQSR test
IntegrationTestSpec.assertEqualTextFiles(recalOut, new File(getLargeVQSRTestDataDir() + "snpRecal.vcf"));
IntegrationTestSpec.assertEqualTextFiles(recalOut, new File(expectedRecalFile));
IntegrationTestSpec.assertEqualTextFiles(tranchesOut, new File(expectedTranchesFile));
}

@Test(dataProvider = "VarRecalSNP")
public void testVariantRecalibratorSNP(final String[] params) throws IOException {
doSNPTest(params, getLargeVQSRTestDataDir() + "expected/SNPDefaultTranches.txt");
public void testVariantRecalibratorSNP(final String[] params, final String tranchesPath, final String recalPath) throws IOException {
doSNPTest(params, tranchesPath, recalPath);
}

@Test(dataProvider = "VarRecalSNP")
public void testVariantRecalibratorSNPAlternateTranches(final String[] params) throws IOException {
@Test(dataProvider = "VarRecalSNPAlternateTranches")
public void testVariantRecalibratorSNPAlternateTranches(final String[] params, final String tranchesPath, final String recalPath) throws IOException {
// same as testVariantRecalibratorSNP but with specific tranches
List<String> args = new ArrayList<>(params.length);
Stream.of(params).forEach(arg -> args.add(arg));
Expand All @@ -128,11 +180,11 @@ public void testVariantRecalibratorSNPAlternateTranches(final String[] params) t
"-tranche", "90.0"
)
);
doSNPTest(args.toArray(new String[args.size()]), getLargeVQSRTestDataDir() + "expected/SNPAlternateTranches.txt");
doSNPTest(args.toArray(new String[args.size()]), tranchesPath, recalPath);
}

@Test(dataProvider = "VarRecalSNP")
public void testVariantRecalibratorSNPMaxAttempts(final String[] params) throws IOException {
public void testVariantRecalibratorSNPMaxAttempts(final String[] params, final String a, final String b) throws IOException {
// For this test, we deliberately *DON'T* sample a single random int as above; this causes
// the tool to require 4 attempts to acquire enough negative training data to succeed

Expand Down Expand Up @@ -323,7 +375,7 @@ public Object[][] getVarRecalSNPScatteredData() {
@Test(dataProvider = "VarRecalSNPScattered")
//the only way the recal file will match here is if we use the doSNPTest infrastructure -- as an IntegrationTestSpec it doesn't match for some reason
public void testVariantRecalibratorSNPscattered(final String[] params) throws IOException {
doSNPTest(params, getLargeVQSRTestDataDir() + "/snpTranches.scattered.txt"); //this isn't in the expected/ directory because it's input to GatherTranchesIntegrationTest
doSNPTest(params, getLargeVQSRTestDataDir() + "/snpTranches.scattered.txt", getLargeVQSRTestDataDir() + "snpRecal.vcf"); //tranches file isn't in the expected/ directory because it's input to GatherTranchesIntegrationTest
}


Expand Down
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/test/resources/large/VQSR/expected/expected.AS.recal.vcf
Git LFS file not shown
Git LFS file not shown
Loading

0 comments on commit e6de27a

Please sign in to comment.