Skip to content

Commit

Permalink
fix(nlu): contextual L0 classification
Browse files Browse the repository at this point in the history
  • Loading branch information
slvnperron committed Jun 26, 2019
1 parent b972d69 commit d521ff9
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions modules/nlu/src/backend/pipelines/intents/svm_classifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ export default class SVMClassifier {
return []
}

if (!includedContexts.length) {
includedContexts = ['global']
}

const input = tokens.join(' ')

const l0Vec = await getSentenceFeatures({
Expand All @@ -287,12 +291,9 @@ export default class SVMClassifier {
langProvider: this.languageProvider,
token2vec: this.token2vec
})
const l0Features = [...l0Vec, tokens.length]
const l0 = await this.l0Predictor.predict(l0Features)

if (!includedContexts.length) {
includedContexts = ['global']
}
const l0Features = [...l0Vec, tokens.length]
const l0 = await this.predictL0Contextually(l0Features, includedContexts)

try {
debugPredict('prediction request %o', { includedContexts, input })
Expand Down Expand Up @@ -350,4 +351,14 @@ export default class SVMClassifier {
throw new VError(e, `Error predicting intent for "${input}"`)
}
}

private async predictL0Contextually(
l0Features: number[],
includedContexts: string[]
): Promise<sdk.MLToolkit.SVM.Prediction[]> {
const allL0 = await this.l0Predictor.predict(l0Features)
const includedL0 = allL0.filter(c => includedContexts.includes(c.label))
const totalL0Confidence = Math.min(1, _.sumBy(includedL0, c => c.confidence))
return includedL0.map(x => ({ ...x, confidence: x.confidence / totalL0Confidence }))
}
}

0 comments on commit d521ff9

Please sign in to comment.