
Allow users to specify VQSLOD sensitivity and apply threshold in ExtractCohort (#7194)

* basic filtering on vqslod or sensitivity cutoffs

* move tranche parsing to ExtractCohort, add option to add extra headers

* refactor to FilterSensitivityTools, add tests

* flesh out tests, use BadInput exceptions

* clean up commented code

* nullable check for vqslod cutoffs
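
The hunks loaded below show how the thresholds are wired through, but not the tranche lookup itself (FilterSensitivityTools did not load on this page). As a rough illustration only, with the method name, lookup rule, and numbers all assumed rather than taken from this commit, converting a truth-sensitivity target into a VQSLOD cutoff amounts to picking the min_vqslod of the tranche at or just above the requested sensitivity:

import java.util.Map;
import java.util.TreeMap;

public class TrancheThresholdSketch {
    // Hypothetical stand-in for FilterSensitivityTools.getVqslodThreshold:
    // tranches maps target_truth_sensitivity -> min_vqslod for one model.
    static Double vqslodFor(TreeMap<Double, Double> tranches, double targetSensitivity) {
        Map.Entry<Double, Double> tranche = tranches.ceilingEntry(targetSensitivity);
        if (tranche == null) {
            throw new IllegalArgumentException(
                    "No tranche covers truth sensitivity " + targetSensitivity);
        }
        return tranche.getValue();
    }

    public static void main(String[] args) {
        TreeMap<Double, Double> snpTranches = new TreeMap<>();
        snpTranches.put(99.0, 0.25);   // illustrative numbers only
        snpTranches.put(99.9, -4.17);
        System.out.println(vqslodFor(snpTranches, 99.7)); // prints -4.17
    }
}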
mmorgantaylor authored Apr 12, 2021
1 parent 3a2cb47 commit a418fb1
Showing 8 changed files with 467 additions and 26 deletions.
CommonCode.java
@@ -79,11 +79,13 @@ public static VCFHeader generateRawArrayVcfHeader(Set<String> sampleNames, final
return header;
}

public static VCFHeader generateVcfHeader(Set<String> sampleNames,//) { //final Set<VCFHeaderLine> defaultHeaderLines,
final SAMSequenceDictionary sequenceDictionary) {
public static VCFHeader generateVcfHeader(Set<String> sampleNames,
final SAMSequenceDictionary sequenceDictionary,
final Set<VCFHeaderLine> extraHeaders) {
final Set<VCFHeaderLine> headerLines = new HashSet<>();

headerLines.addAll( getEvoquerVcfHeaderLines() );
headerLines.addAll( extraHeaders );
// headerLines.addAll( defaultHeaderLines );

final VCFHeader header = new VCFHeader(headerLines, sampleNames);
@@ -92,6 +94,13 @@ public static VCFHeader generateVcfHeader(Set<String> sampleNames,//) { //final
return header;
}

public static VCFHeader generateVcfHeader(Set<String> sampleNames,
final SAMSequenceDictionary sequenceDictionary) {

Set<VCFHeaderLine> noExtraHeaders = new HashSet<>();
return generateVcfHeader(sampleNames, sequenceDictionary, noExtraHeaders);
}


// TODO is this specific for cohort extract? if so name it such
public static Set<VCFHeaderLine> getEvoquerVcfHeaderLines() {
@@ -155,8 +164,6 @@ public static Set<VCFHeaderLine> getEvoquerVcfHeaderLines() {

headerLines.add(GATKVCFHeaderLines.getInfoLine(GATKVCFConstants.SB_TABLE_KEY));

headerLines.add(GATKVCFHeaderLines.getFilterLine(GATKVCFConstants.VQSR_TRANCHE_SNP));
headerLines.add(GATKVCFHeaderLines.getFilterLine(GATKVCFConstants.VQSR_TRANCHE_INDEL));
headerLines.add(GATKVCFHeaderLines.getFilterLine(GATKVCFConstants.NAY_FROM_YNG));
headerLines.add(GATKVCFHeaderLines.getFilterLine(GATKVCFConstants.EXCESS_HET_KEY));

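With the change above, existing call sites of generateVcfHeader keep working, while ExtractCohort can pass the VQSR FILTER lines it builds in onStartup. A minimal sketch of the two call patterns (the sample names, dictionary, and header set are placeholders, not values from this commit):

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import java.util.Set;
import org.broadinstitute.hellbender.tools.variantdb.CommonCode;

class HeaderUsageSketch {
    static VCFHeader build(Set<String> sampleNames,
                           SAMSequenceDictionary dict,
                           Set<VCFHeaderLine> vqsrHeaderLines) {
        // Old two-argument form still compiles; it now delegates with an empty set.
        VCFHeader plain = CommonCode.generateVcfHeader(sampleNames, dict);
        // New three-argument form threads extra FILTER lines into the header.
        return CommonCode.generateVcfHeader(sampleNames, dict, vqsrHeaderLines);
    }
}
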
SchemaUtils.java
@@ -67,6 +67,12 @@ public class SchemaUtils {
public static final String TRAINING_LABEL = "training_label";
public static final String YNG_STATUS = "yng_status";

//Tranches table
public static final String TARGET_TRUTH_SENSITIVITY = "target_truth_sensitivity";
public static final String MIN_VQSLOD = "min_vqslod";
public static final String TRANCHE_FILTER_NAME = "filter_name";
public static final String TRANCHE_MODEL = "model";

public static final List<String> COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, CALL_GT, CALL_GQ, CALL_RGQ, QUALapprox, AS_QUALapprox, CALL_PL);//, AS_VarDP);
public static final List<String> ARRAY_COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, CALL_GT, CALL_GQ);

@@ -79,6 +85,7 @@ public class SchemaUtils {

public static final List<String> SAMPLE_FIELDS = Arrays.asList(SchemaUtils.SAMPLE_NAME_FIELD_NAME, SchemaUtils.SAMPLE_ID_FIELD_NAME);
public static final List<String> YNG_FIELDS = Arrays.asList(FILTER_SET_NAME, LOCATION_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, VQSLOD, YNG_STATUS);
public static final List<String> TRANCHE_FIELDS = Arrays.asList(TARGET_TRUTH_SENSITIVITY, MIN_VQSLOD, TRANCHE_FILTER_NAME, TRANCHE_MODEL);


public static final List<String> PET_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_ID_FIELD_NAME, STATE_FIELD_NAME);
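TRANCHE_FIELDS mirrors the columns of the new tranches table. A hedged sketch of the kind of query the tranche parsing implies; the table path is invented, and the filter_set_name predicate is an assumption based on getTrancheMaps taking a filter set name, not code from this commit:

import java.util.Arrays;
import java.util.List;

class TrancheQuerySketch {
    public static void main(String[] args) {
        // Same columns as SchemaUtils.TRANCHE_FIELDS.
        List<String> trancheFields = Arrays.asList(
                "target_truth_sensitivity", "min_vqslod", "filter_name", "model");
        String sql = "SELECT " + String.join(", ", trancheFields)
                + " FROM `my-project.my_dataset.tranches`"    // hypothetical table
                + " WHERE filter_set_name = 'my_filter_set'"; // assumed predicate
        System.out.println(sql);
    }
}
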
ExtractCohort.java
@@ -1,21 +1,30 @@
package org.broadinstitute.hellbender.tools.variantdb.nextgen;

import htsjdk.variant.vcf.VCFFilterHeaderLine;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import org.apache.avro.generic.GenericRecord;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.variantdb.CommonCode;
import org.broadinstitute.hellbender.tools.variantdb.nextgen.FilterSensitivityTools;
import org.broadinstitute.hellbender.tools.variantdb.SampleList;
import org.broadinstitute.hellbender.tools.variantdb.SchemaUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.bigquery.StorageAPIAvroReader;
import org.broadinstitute.hellbender.utils.bigquery.TableReference;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;

import java.util.*;



@CommandLineProgramProperties(
summary = "(\"ExtractCohort\") - Filter and extract variants out of big query.",
oneLineSummary = "Tool to extract variants out of big query for a subset of samples",
@@ -33,6 +42,13 @@ public class ExtractCohort extends ExtractTool {
)
private String filteringFQTableName = null;

@Argument(
fullName = "tranches-table",
doc = "Fully qualified name of the tranches table to use for cohort extraction",
optional = true
)
private String tranchesTableName = null;

@Argument(
fullName = "cohort-extract-table",
doc = "Fully qualified name of the table where the cohort data exists (already subsetted)",
@@ -61,14 +77,61 @@ public class ExtractCohort extends ExtractTool {
)
private boolean emitPLs = false;


@Argument(
fullName="snps-truth-sensitivity-filter-level",
doc="The truth sensitivity level at which to start filtering SNPs",
optional=true
)
private Double truthSensitivitySNPThreshold = null;

@Argument(
fullName="indels-truth-sensitivity-filter-level",
doc="The truth sensitivity level at which to start filtering INDELs",
optional=true
)
private Double truthSensitivityINDELThreshold = null;

@Advanced
@Argument(
fullName="snps-lod-score-cutoff",
doc="The VQSLOD score below which to start filtering SNPs",
optional=true)
private Double vqsLodSNPThreshold = null;

@Advanced
@Argument(
fullName="indels-lod-score-cutoff",
doc="The VQSLOD score below which to start filtering INDELs",
optional=true)
private Double vqsLodINDELThreshold = null;


@Override
protected void onStartup() {
super.onStartup();

Set<VCFHeaderLine> vqsrHeaderLines = new HashSet<>();
if (filteringFQTableName != null) {
FilterSensitivityTools.validateFilteringCutoffs(truthSensitivitySNPThreshold, truthSensitivityINDELThreshold, vqsLodSNPThreshold, vqsLodINDELThreshold, tranchesTableName);
Map<String, Map<Double, Double>> trancheMaps = FilterSensitivityTools.getTrancheMaps(filterSetName, tranchesTableName, projectID);

if (vqsLodSNPThreshold != null) { // we already have vqslod thresholds directly
vqsrHeaderLines.add(FilterSensitivityTools.getVqsLodHeader(vqsLodSNPThreshold, GATKVCFConstants.SNP));
vqsrHeaderLines.add(FilterSensitivityTools.getVqsLodHeader(vqsLodINDELThreshold, GATKVCFConstants.INDEL));
} else { // using sensitivity threshold inputs; need to convert these to vqslod thresholds
vqsLodSNPThreshold = FilterSensitivityTools.getVqslodThreshold(trancheMaps.get(GATKVCFConstants.SNP), truthSensitivitySNPThreshold, GATKVCFConstants.SNP);
vqsLodINDELThreshold = FilterSensitivityTools.getVqslodThreshold(trancheMaps.get(GATKVCFConstants.INDEL), truthSensitivityINDELThreshold, GATKVCFConstants.INDEL);
// set headers
vqsrHeaderLines.add(FilterSensitivityTools.getTruthSensitivityHeader(truthSensitivitySNPThreshold, vqsLodSNPThreshold, GATKVCFConstants.SNP));
vqsrHeaderLines.add(FilterSensitivityTools.getTruthSensitivityHeader(truthSensitivityINDELThreshold, vqsLodINDELThreshold, GATKVCFConstants.INDEL));
}
}

SampleList sampleList = new SampleList(sampleTableName, sampleFileName, projectID, printDebugInformation);
Set<String> sampleNames = new HashSet<>(sampleList.getSampleNames());

VCFHeader header = CommonCode.generateVcfHeader(sampleNames, reference.getSequenceDictionary());
VCFHeader header = CommonCode.generateVcfHeader(sampleNames, reference.getSequenceDictionary(), vqsrHeaderLines);

final List<SimpleInterval> traversalIntervals = getTraversalIntervals();

@@ -82,7 +145,6 @@ protected void onStartup() {
throw new UserException("min-location and max-location should not be used together with intervals (-L).");
}


engine = new ExtractCohortEngine(
projectID,
vcfWriter,
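The vqsrHeaderLines built in onStartup are FILTER lines describing the cutoffs in effect. The exact IDs and wording live in FilterSensitivityTools, which did not load here; the sketch below only shows the general shape using htsjdk, with both IDs and descriptions invented for illustration:

import htsjdk.variant.vcf.VCFFilterHeaderLine;

class VqsrHeaderSketch {
    public static void main(String[] args) {
        // When VQSLOD cutoffs are given directly (hypothetical ID and description):
        VCFFilterHeaderLine vqslodStyle = new VCFFilterHeaderLine(
                "VQSRFailureSNP",
                "Site failed SNP model VQSLOD cutoff of -2.5");
        // When derived from a truth-sensitivity target (also hypothetical):
        VCFFilterHeaderLine sensitivityStyle = new VCFFilterHeaderLine(
                "VQSRFailureINDEL",
                "Site failed INDEL model at truth sensitivity 99.0 (VQSLOD < -4.1)");
        System.out.println(vqslodStyle);
        System.out.println(sensitivityStyle);
    }
}
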
ExtractCohortEngine.java
@@ -16,6 +17,7 @@
import org.broadinstitute.hellbender.engine.ProgressMeter;
import org.broadinstitute.hellbender.engine.ReferenceDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.variantdb.CommonCode;
import org.broadinstitute.hellbender.tools.variantdb.SchemaUtils;
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
@@ -46,8 +47,8 @@ public class ExtractCohortEngine {
private final Long maxLocation;
private final TableReference filteringTableRef;
private final ReferenceDataSource refSource;
private double vqsLodSNPThreshold = 0;
private double vqsLodINDELThreshold = 0;
private Double vqsLodSNPThreshold;
private Double vqsLodINDELThreshold;

private final ProgressMeter progressMeter;
private final String projectID;
@@ -88,8 +89,8 @@ public ExtractCohortEngine(final String projectID,
final String filteringTableName,
final int localSortMaxRecordsInRam,
final boolean printDebugInformation,
final double vqsLodSNPThreshold,
final double vqsLodINDELThreshold,
final Double vqsLodSNPThreshold,
final Double vqsLodINDELThreshold,
final ProgressMeter progressMeter,
final ExtractCohort.QueryMode queryMode,
final String filterSetName,
@@ -141,6 +142,12 @@ public void traverse() {
boolean noFilteringRequested = (filteringTableRef == null);

if (!noFilteringRequested) {
// ensure vqslod filters are defined. this really shouldn't ever happen, said the engineer.
if (vqsLodSNPThreshold == null || vqsLodINDELThreshold == null) {
throw new UserException("Vqslod filtering thresholds for SNPs and INDELs must be defined.");
}

// get filter info (vqslod & yng values)
final String rowRestrictionWithFilterSetName = rowRestriction + " AND " + SchemaUtils.FILTER_SET_NAME + " = '" + filterSetName + "'";

final StorageAPIAvroReader filteringTableAvroReader = new StorageAPIAvroReader(filteringTableRef, rowRestrictionWithFilterSetName, projectID);
@@ -186,6 +193,7 @@ public void traverse() {
}
}


public SortingCollection<GenericRecord> getAvroSortingCollection(org.apache.avro.Schema schema, int localSortMaxRecordsInRam) {
final SortingCollection.Codec<GenericRecord> sortingCollectionCodec = new AvroSortingCollectionCodec(schema);
final Comparator<GenericRecord> sortingCollectionComparator = new Comparator<GenericRecord>() {
@@ -475,29 +483,26 @@ private VariantContext filterVariants(VariantContext mergedVC, HashMap<Allele, H
final VariantContextBuilder builder = new VariantContextBuilder(mergedVC);

builder.attribute(GATKVCFConstants.AS_VQS_LOD_KEY, relevantVqsLodMap.values().stream().map(val -> val.equals(Double.NaN) ? VCFConstants.EMPTY_INFO_FIELD : val.toString()).collect(Collectors.toList()));
builder.attribute(GATKVCFConstants.AS_YNG_STATUS_KEY, relevantYngMap.values().stream().collect(Collectors.toList()));
builder.attribute(GATKVCFConstants.AS_YNG_STATUS_KEY, new ArrayList<>(relevantYngMap.values()));

int refLength = mergedVC.getReference().length();

// if there are any Yays, the site is PASS
if (remappedYngMap.values().contains("Y")) {
if (remappedYngMap.containsValue("Y")) {
builder.passFilters();
} else if (remappedYngMap.values().contains("N")) {
} else if (remappedYngMap.containsValue("N")) {
builder.filter(GATKVCFConstants.NAY_FROM_YNG);
} else {
// if it doesn't trigger any of the filters below, we assume it passes.
builder.passFilters();
if (remappedYngMap.values().contains("G")) {
// TODO change the initial query to include the filtername from the tranches tables
if (remappedYngMap.containsValue("G")) {
Optional<Double> snpMax = relevantVqsLodMap.entrySet().stream().filter(entry -> entry.getKey().length() == refLength).map(entry -> entry.getValue().equals(Double.NaN) ? 0.0 : entry.getValue()).max(Double::compareTo);
if (snpMax.isPresent() && snpMax.get() < vqsLodSNPThreshold) {
// TODO: add in sensitivities
builder.filter(GATKVCFConstants.VQSR_TRANCHE_SNP);
builder.filter(GATKVCFConstants.VQSR_FAILURE_SNP);
}
Optional<Double> indelMax = relevantVqsLodMap.entrySet().stream().filter(entry -> entry.getKey().length() != refLength).map(entry -> entry.getValue().equals(Double.NaN) ? 0.0 : entry.getValue()).max(Double::compareTo);
if (indelMax.isPresent() && indelMax.get() < vqsLodINDELThreshold) {
// TODO: add in sensitivities
builder.filter(GATKVCFConstants.VQSR_TRANCHE_INDEL);
builder.filter(GATKVCFConstants.VQSR_FAILURE_INDEL);
}
} else {
// If VQSR dropped this site (there's no YNG or VQSLOD) then we'll filter it as a NAY.
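The filterVariants logic above classifies each alt allele by length (same length as the reference means SNP, a length change means indel), takes the maximum VQSLOD per class, and applies the corresponding cutoff. A self-contained sketch of that decision, simplified to plain strings and returning only the first failing filter, whereas the real code can apply both:

import java.util.Map;
import java.util.Optional;

class VqslodFilterSketch {
    // Simplified model of the site-level decision in filterVariants:
    // alleles are plain strings, and NaN VQSLODs are treated as 0.0 as in the diff.
    static Optional<String> failedFilter(Map<String, Double> vqslodByAltAllele,
                                         int refLength,
                                         double snpCutoff, double indelCutoff) {
        Optional<Double> snpMax = vqslodByAltAllele.entrySet().stream()
                .filter(e -> e.getKey().length() == refLength)  // same length => SNP
                .map(e -> e.getValue().isNaN() ? 0.0 : e.getValue())
                .max(Double::compareTo);
        if (snpMax.isPresent() && snpMax.get() < snpCutoff) {
            return Optional.of("VQSR_FAILURE_SNP");
        }
        Optional<Double> indelMax = vqslodByAltAllele.entrySet().stream()
                .filter(e -> e.getKey().length() != refLength)  // length change => indel
                .map(e -> e.getValue().isNaN() ? 0.0 : e.getValue())
                .max(Double::compareTo);
        if (indelMax.isPresent() && indelMax.get() < indelCutoff) {
            return Optional.of("VQSR_FAILURE_INDEL");
        }
        return Optional.empty(); // passes both cutoffs
    }
}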
(Diffs for the remaining changed files did not load on this page.)
