Skip to content

Commit

Permalink
feat: allow passing a callback to cross-validation methods
Browse files Browse the repository at this point in the history
* feat: allow passing a callback to cross-validation methods

* Update readme with usage example with callback

* readme: fix link

* fix grammar error in readme
  • Loading branch information
stropitek authored and targos committed Nov 4, 2017
1 parent 35ad6fd commit 32501e2
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 56 deletions.
26 changes: 22 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ Cross-validation methods:

[API documentation](https://mljs.github.io/cross-validation/).

A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section, but you could also use your own. Cross validations methods return a ConfusionMatrix ([https://github.com/mljs/confusion-matrix](https://github.com/mljs/confusion-matrix) that can be used to calculate metrics on your classification result.
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section, but you could also use your own. Cross validations methods return a ConfusionMatrix ([https://github.com/mljs/confusion-matrix](https://github.com/mljs/confusion-matrix)) that can be used to calculate metrics on your classification result.

## Installation
```bash
npm i -s ml-cross-validation
```

## Example
## Example using a ml classification library
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
Expand All @@ -29,12 +29,29 @@ const confusionMatrix = crossValidation.leaveOneOut(KNN, dataSet, labels);
const accuracy = confusionMatrix.getAccuracy();
```

## Use your own classification library with ml-cross-validation
To be used with ml-cross-validation, your classification library must implement
## Example using a classifier with its own specific API
If you have a library that does not comply with the ML Classifier conventions, you can use a callback to perform the classification.
The callback will take the train features and labels, and the test features. The callback should return the array of predicted labels.
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(dataset, labels, function(trainFeatures, trainLabels, testFeatures) {
const knn = new KNN(trainFeatures, trainLabels);
return knn.predict(testFeatures);
});
const accuracy = confusionMatrix.getAccuracy();
```

## ML classifier API conventions
You can write your classification library so that it can be used with ml-cross-validation as described [here](#example-using-a-ml-classification-library).
For that, your classification library must implement
- A constructor. The constructor can be passed options as a single argument.
- A `train` method. The `train` method is passed the data as a first argument and the labels as a second.
- A `predict` method. The `predict` method is passed test data and should return a predicted label.

### Example
```js
class MyClassifier {
constructor(options) {
Expand All @@ -49,6 +66,7 @@ class MyClassifier {
}
}
```
###

[npm-image]: https://img.shields.io/npm/v/ml-cross-validation.svg?style=flat-square
[npm-url]: https://npmjs.org/package/ml-cross-validation
Expand Down
48 changes: 20 additions & 28 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"eslint": "^4.1.1",
"eslint-config-cheminfo": "^1.0.0",
"eslint-plugin-no-only-tests": "^2.0.0",
"mocha": "^3.1.2",
"mocha": "^3.5.3",
"mocha-better-spec-reporter": "^3.0.2",
"should": "^11.1.1"
},
Expand Down
93 changes: 73 additions & 20 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,23 @@ const CV = {};
const combinations = require('ml-combinations');

/**
* Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the validation
* set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a special case
* of LPO-CV. @see leavePout
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
* Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the
* validation set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a
* special case of LPO-CV. @see leavePout
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
* api.
* @param {Array} features - The features for all samples of the data-set
* @param {Array} labels - The classification class of all samples of the data-set
* @param {object} classifierOptions - The classifier options with which the classifier should be instantiated.
* @return {ConfusionMatrix} - The cross-validation confusion matrix
*/
CV.leaveOneOut = function (Classifier, features, labels, classifierOptions) {
    // Callback form: leaveOneOut(features, labels, callback). In that case the
    // positional arguments are shifted one to the left, so simply forward them
    // to leavePOut in its own callback form with p = 1.
    if (typeof labels === 'function') {
        return CV.leavePOut(Classifier, features, 1, labels);
    }
    // Classifier form: LOO-CV is LPO-CV with p = 1.
    return CV.leavePOut(Classifier, features, labels, classifierOptions, 1);
};

Expand All @@ -25,14 +32,21 @@ CV.leaveOneOut = function (Classifier, features, labels, classifierOptions) {
* validation set while the rest is used as the training set. This is repeated as many times as there are possible
* ways to combine p observations from the set (unordered without replacement). Be aware that for relatively small
* data-set size this can require a very large number of training and testing to do!
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
* api.
* @param {Array} features - The features for all samples of the data-set
* @param {Array} labels - The classification class of all samples of the data-set
* @param {object} classifierOptions - The classifier options with which the classifier should be instantiated.
* @param {number} p - The size of the validation sub-samples' set
* @return {ConfusionMatrix} - The cross-validation confusion matrix
*/
CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
if (typeof classifierOptions === 'function') {
var callback = classifierOptions;
p = labels;
labels = features;
features = Classifier;
}
check(features, labels);
const distinct = getDistinct(labels);
const confusionMatrix = initMatrix(distinct.length, distinct.length);
Expand All @@ -50,7 +64,12 @@ CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
trainIdx.splice(testIdx[i], 1);
}

validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
if (callback) {
validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
} else {
validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
}

}

return new ConfusionMatrix(confusionMatrix, distinct);
Expand All @@ -68,6 +87,12 @@ CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
* @return {ConfusionMatrix} - The cross-validation confusion matrix
*/
CV.kFold = function (Classifier, features, labels, classifierOptions, k) {
if (typeof classifierOptions === 'function') {
var callback = classifierOptions;
k = labels;
labels = features;
features = Classifier;
}
check(features, labels);
const distinct = getDistinct(labels);
const confusionMatrix = initMatrix(distinct.length, distinct.length);
Expand Down Expand Up @@ -101,7 +126,11 @@ CV.kFold = function (Classifier, features, labels, classifierOptions, k) {
if (j !== i) trainIdx = trainIdx.concat(folds[j]);
}

validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
if (callback) {
validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
} else {
validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
}
}

return new ConfusionMatrix(confusionMatrix, distinct);
Expand All @@ -126,18 +155,7 @@ function getDistinct(arr) {
}

function validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct) {
var testFeatures = testIdx.map(function (index) {
return features[index];
});
var trainFeatures = trainIdx.map(function (index) {
return features[index];
});
var testLabels = testIdx.map(function (index) {
return labels[index];
});
var trainLabels = trainIdx.map(function (index) {
return labels[index];
});
const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);

var classifier;
if (Classifier.prototype.train) {
Expand All @@ -148,9 +166,44 @@ function validate(Classifier, features, labels, classifierOptions, testIdx, trai
}

var predictedLabels = classifier.predict(testFeatures);
updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}

// Runs one validation round using a user-supplied prediction callback instead
// of a classifier constructor, then accumulates the results into the matrix.
function validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback) {
    const split = getTrainTest(features, labels, testIdx, trainIdx);
    // The callback receives (trainFeatures, trainLabels, testFeatures) and
    // must return the predicted labels for the test set.
    const predicted = callback(split.trainFeatures, split.trainLabels, split.testFeatures);
    updateConfusionMatrix(confusionMatrix, split.testLabels, predicted, distinct);
}

/**
 * Accumulates one round of predictions into the confusion matrix.
 * @param {Array<Array<number>>} confusionMatrix - Square matrix indexed by [actual][predicted].
 * @param {Array} testLabels - The true labels of the test samples.
 * @param {Array} predictedLabels - The labels returned by the classifier/callback.
 * @param {Array} distinct - The list of distinct labels; positions define matrix indices.
 */
function updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct) {
    for (var i = 0; i < predictedLabels.length; i++) {
        const actualIdx = distinct.indexOf(testLabels[i]);
        const predictedIdx = distinct.indexOf(predictedLabels[i]);
        if (actualIdx < 0 || predictedIdx < 0) {
            // Skip unknown labels entirely: without this `continue` the code
            // fell through and dereferenced confusionMatrix[-1], throwing a
            // TypeError instead of ignoring the sample as the warning promises.
            // eslint-disable-next-line no-console
            console.warn(`ignore unknown predicted label ${predictedLabels[i]}`);
            continue;
        }
        confusionMatrix[actualIdx][predictedIdx]++;
    }
}


// Splits the data-set into its train/test features and labels according to the
// given index lists. Pure function: the input arrays are never mutated.
function getTrainTest(features, labels, testIdx, trainIdx) {
    const pick = (arr, indices) => indices.map((i) => arr[i]);
    return {
        testFeatures: pick(features, testIdx),
        trainFeatures: pick(features, trainIdx),
        testLabels: pick(labels, testIdx),
        trainLabels: pick(labels, trainIdx)
    };
}

module.exports = CV;
46 changes: 43 additions & 3 deletions test/crossValidation.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
const Dummy = require('./DummyClassifier');
const CV = require('..');

var LOO = require('./data/LOO-CV');
var LPO = require('./data/LPO-CV');
var KF = require('./data/KF-CV');


describe('basic', function () {
it('basic leave-one-out cross-validation', function () {
var LOO = require('./data/LOO-CV');
for (let i = 0; i < LOO.length; i++) {
var CM = CV.leaveOneOut(Dummy, LOO[i].features, LOO[i].labels, LOO[i].classifierOptions);
CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
Expand All @@ -14,7 +18,6 @@ describe('basic', function () {
});

it('basic leave-p-out cross-validation', function () {
var LPO = require('./data/LPO-CV');
for (let i = 0; i < LPO.length; i++) {
var CM = CV.leavePOut(Dummy, LPO[i].features, LPO[i].labels, LPO[i].classifierOptions, LPO[i].p);
CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
Expand All @@ -23,11 +26,48 @@ describe('basic', function () {
});

it('basic k-fold cross-validation', function () {
var KF = require('./data/KF-CV');
for (let i = 0; i < KF.length; i++) {
var CM = CV.kFold(Dummy, KF[i].features, KF[i].labels, KF[i].classifierOptions, KF[i].k);
CM.getMatrix().should.deepEqual(KF[i].result.matrix);
CM.getLabels().should.deepEqual(KF[i].result.labels);
}
});
});

describe('with a callback', function () {
    // Builds the prediction callback expected by the cross-validation methods:
    // it trains a fresh Dummy classifier on the train split and predicts the
    // test split. Shared by all three tests below to avoid triplicating it.
    function dummyCallback(classifierOptions) {
        return function (trainFeatures, trainLabels, testFeatures) {
            const classifier = new Dummy(classifierOptions);
            classifier.train(trainFeatures, trainLabels);
            return classifier.predict(testFeatures);
        };
    }

    // Fixed typo in the test title: "leave-on-out" -> "leave-one-out".
    it('basic leave-one-out cross-validation with callback', function () {
        for (let i = 0; i < LOO.length; i++) {
            const CM = CV.leaveOneOut(LOO[i].features, LOO[i].labels, dummyCallback(LOO[i].classifierOptions));
            CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
            CM.getLabels().should.deepEqual(LOO[i].result.labels);
        }
    });

    it('basic leave-p-out cross-validation with callback', function () {
        for (let i = 0; i < LPO.length; i++) {
            const CM = CV.leavePOut(LPO[i].features, LPO[i].labels, LPO[i].p, dummyCallback(LPO[i].classifierOptions));
            CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
            CM.getLabels().should.deepEqual(LPO[i].result.labels);
        }
    });

    it('basic k-fold cross-validation with callback', function () {
        for (let i = 0; i < KF.length; i++) {
            const CM = CV.kFold(KF[i].features, KF[i].labels, KF[i].k, dummyCallback(KF[i].classifierOptions));
            CM.getMatrix().should.deepEqual(KF[i].result.matrix);
            CM.getLabels().should.deepEqual(KF[i].result.labels);
        }
    });
});

0 comments on commit 32501e2

Please sign in to comment.