Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
fix: test asset distribution to include all tags on test/train split (#…
Browse files Browse the repository at this point in the history
…823)

* fix: test asset distribution to include all tags on test/train split

The test asset may not included all tags when export with test/train split option in current venison (2.1.0).

* Extract the same split logic into helper function

* Formatting

* Inverting if statement
  • Loading branch information
hermanho authored and JacopoMangiavacchi committed Aug 23, 2019
1 parent c0201ca commit 9d64f4a
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 41 deletions.
34 changes: 31 additions & 3 deletions src/providers/export/cntk.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,37 @@ describe("CNTK Export Provider", () => {

const assetsToExport = await getAssetsSpy.mock.results[0].value;
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
const testCount = Math.ceil(assetsToExport.length * testSplit);
const testArray = assetsToExport.slice(0, testCount);
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);

const trainArray = [];
const testArray = [];
const tagsAssestList: {
[index: string]: {
assetSet: Set<string>,
testArray: string[],
trainArray: string[],
},
} = {};
testProject.tags.forEach((tag) =>
tagsAssestList[tag.name] = {
assetSet: new Set(), testArray: [],
trainArray: [],
});
assetsToExport.forEach((assetMetadata) => {
assetMetadata.regions.forEach((region) => {
region.tags.forEach((tagName) => {
if (tagsAssestList[tagName]) {
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
}
});
});
});

for (const tagKey of Object.keys(tagsAssestList)) {
const assetSet = tagsAssestList[tagKey].assetSet;
const testCount = Math.ceil(assetSet.size * testSplit);
testArray.push(...Array.from(assetSet).slice(0, testCount));
trainArray.push(...Array.from(assetSet).slice(testCount, assetSet.size));
}

const storageProviderMock = LocalFileSystemProxy as any;
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
Expand Down
11 changes: 8 additions & 3 deletions src/providers/export/cntk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { ExportProvider, IExportResults } from "./exportProvider";
import { IAssetMetadata, IExportProviderOptions, IProject } from "../../models/applicationState";
import HtmlFileReader from "../../common/htmlFileReader";
import Guard from "../../common/guard";
import { splitTestAsset } from "./testAssetsSplitHelper";

enum ExportSplit {
Test,
Expand Down Expand Up @@ -33,13 +34,17 @@ export class CntkExportProvider extends ExportProvider<ICntkExportProviderOption
public async export(): Promise<IExportResults> {
await this.createFolderStructure();
const assetsToExport = await this.getAssetsForExport();
const testAssets: string[] = [];

const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100;
const testCount = Math.ceil(assetsToExport.length * testSplit);
const testArray = assetsToExport.slice(0, testCount);
if (testSplit > 0 && testSplit <= 1) {
const splittedAssets = splitTestAsset(assetsToExport, this.project.tags, testSplit);
testAssets.push(...splittedAssets);
}

const results = await assetsToExport.mapAsync(async (assetMetadata) => {
try {
const exportSplit = testArray.find((am) => am.asset.id === assetMetadata.asset.id)
const exportSplit = testAssets.find((am) => am === assetMetadata.asset.id)
? ExportSplit.Test
: ExportSplit.Train;

Expand Down
67 changes: 56 additions & 11 deletions src/providers/export/pascalVOC.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ describe("PascalVOC Json Export Provider", () => {
beforeEach(() => {
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
const mockTag = MockFactory.createTestTag();
const mockTag1 = MockFactory.createTestTag("1");
const mockTag2 = MockFactory.createTestTag("2");
const mockTag = Number(asset.id.split("-")[1]) > 7 ? mockTag1 : mockTag2;
const mockRegion1 = MockFactory.createTestRegion("region-1", [mockTag.name]);
const mockRegion2 = MockFactory.createTestRegion("region-2", [mockTag.name]);

Expand Down Expand Up @@ -352,27 +354,70 @@ describe("PascalVOC Json Export Provider", () => {
};

const testProject = { ...baseTestProject };
const testAssets = MockFactory.createTestAssets(10, 0);
const testAssets = MockFactory.createTestAssets(13, 0);
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
testProject.tags = [MockFactory.createTestTag("1")];
testProject.tags = MockFactory.createTestTags(3);

const exportProvider = new PascalVOCExportProvider(testProject, options);
const getAssetsSpy = jest.spyOn(exportProvider, "getAssetsForExport");

await exportProvider.export();

const storageProviderMock = LocalFileSystemProxy as any;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];

const valDataIndex = writeTextFileCalls
const valDataIndex1 = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"));
const trainDataIndex = writeTextFileCalls
const trainDataIndex1 = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"));

const expectedTrainCount = (testTrainSplit / 100) * testAssets.length;
const expectedTestCount = ((100 - testTrainSplit) / 100) * testAssets.length;

expect(writeTextFileCalls[valDataIndex][1].split("\n")).toHaveLength(expectedTestCount);
expect(writeTextFileCalls[trainDataIndex][1].split("\n")).toHaveLength(expectedTrainCount);
const valDataIndex2 = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_val.txt"));
const trainDataIndex2 = writeTextFileCalls
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_train.txt"));

const assetsToExport = await getAssetsSpy.mock.results[0].value;
const trainArray = [];
const testArray = [];
const tagsAssestList: {
[index: string]: {
assetSet: Set<string>,
testArray: string[],
trainArray: string[],
},
} = {};
testProject.tags.forEach((tag) =>
tagsAssestList[tag.name] = {
assetSet: new Set(), testArray: [],
trainArray: [],
});
assetsToExport.forEach((assetMetadata) => {
assetMetadata.regions.forEach((region) => {
region.tags.forEach((tagName) => {
if (tagsAssestList[tagName]) {
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
}
});
});
});

for (const tagKey of Object.keys(tagsAssestList)) {
const assetSet = tagsAssestList[tagKey].assetSet;
const testCount = Math.ceil(((100 - testTrainSplit) / 100) * assetSet.size);
tagsAssestList[tagKey].testArray = Array.from(assetSet).slice(0, testCount);
tagsAssestList[tagKey].trainArray = Array.from(assetSet).slice(testCount, assetSet.size);
testArray.push(...tagsAssestList[tagKey].testArray);
trainArray.push(...tagsAssestList[tagKey].trainArray);
}

expect(writeTextFileCalls[valDataIndex1][1].split(/\r?\n/).filter((line) =>
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].testArray.length);
expect(writeTextFileCalls[trainDataIndex1][1].split(/\r?\n/).filter((line) =>
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].trainArray.length);
expect(writeTextFileCalls[valDataIndex2][1].split(/\r?\n/).filter((line) =>
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].testArray.length);
expect(writeTextFileCalls[trainDataIndex2][1].split(/\r?\n/).filter((line) =>
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].trainArray.length);
}

it("Correctly generated files based on 50/50 test / train split", async () => {
Expand Down
67 changes: 43 additions & 24 deletions src/providers/export/pascalVOC.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import HtmlFileReader from "../../common/htmlFileReader";
import { itemTemplate, annotationTemplate, objectTemplate } from "./pascalVOC/pascalVOCTemplates";
import { interpolate } from "../../common/strings";
import os from "os";
import { splitTestAsset } from "./testAssetsSplitHelper";

interface IObjectInfo {
name: string;
Expand Down Expand Up @@ -253,40 +254,58 @@ export class PascalVOCExportProvider extends ExportProvider<IPascalVOCExportProv
}
});

// Save ImageSets
await tags.forEachAsync(async (tag) => {
const tagInstances = tagUsage.get(tag.name) || 0;
if (!exportUnassignedTags && tagInstances === 0) {
return;
}
if (testSplit > 0 && testSplit <= 1) {
const tags = this.project.tags;
const testAssets: string[] = splitTestAsset(allAssets, tags, testSplit);

const assetList = [];
assetUsage.forEach((tags, assetName) => {
if (tags.has(tag.name)) {
assetList.push(`${assetName} 1`);
} else {
assetList.push(`${assetName} -1`);
await tags.forEachAsync(async (tag) => {
const tagInstances = tagUsage.get(tag.name) || 0;
if (!exportUnassignedTags && tagInstances === 0) {
return;
}
});

if (testSplit > 0 && testSplit <= 1) {
// Split in Test and Train sets
const totalAssets = assetUsage.size;
const testCount = Math.ceil(totalAssets * testSplit);

const testArray = assetList.slice(0, testCount);
const trainArray = assetList.slice(testCount, totalAssets);
const testArray = [];
const trainArray = [];
assetUsage.forEach((tags, assetName) => {
let assetString = "";
if (tags.has(tag.name)) {
assetString = `${assetName} 1`;
} else {
assetString = `${assetName} -1`;
}
if (testAssets.find((am) => am === assetName)) {
testArray.push(assetString);
} else {
trainArray.push(assetString);
}
});

const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`;
await this.storageProvider.writeText(testImageSetFileName, testArray.join(os.EOL));

const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`;
await this.storageProvider.writeText(trainImageSetFileName, trainArray.join(os.EOL));
});
} else {

// Save ImageSets
await tags.forEachAsync(async (tag) => {
const tagInstances = tagUsage.get(tag.name) || 0;
if (!exportUnassignedTags && tagInstances === 0) {
return;
}

const assetList = [];
assetUsage.forEach((tags, assetName) => {
if (tags.has(tag.name)) {
assetList.push(`${assetName} 1`);
} else {
assetList.push(`${assetName} -1`);
}
});

} else {
const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`;
await this.storageProvider.writeText(imageSetFileName, assetList.join(os.EOL));
}
});
});
}
}
}
61 changes: 61 additions & 0 deletions src/providers/export/testAssetsSplitHelper.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import _ from "lodash";
import {
IAssetMetadata, AssetState, IRegion,
RegionType, IPoint, IExportProviderOptions,
} from "../../models/applicationState";
import MockFactory from "../../common/mockFactory";
import { splitTestAsset } from "./testAssetsSplitHelper";
import { appInfo } from "../../common/appInfo";

describe("splitTestAsset Helper tests", () => {

describe("Test Train Splits", () => {
async function testTestTrainSplit(testTrainSplit: number): Promise<void> {
const assetArray = MockFactory.createTestAssets(13, 0);
const tags = MockFactory.createTestTags(2);
assetArray.forEach((asset) => asset.state = AssetState.Tagged);

const testSplit = (100 - testTrainSplit) / 100;
const testCount = Math.ceil(testSplit * assetArray.length);

const assetMetadatas = assetArray.map((asset, i) =>
MockFactory.createTestAssetMetadata(asset,
i < (assetArray.length - testCount) ?
[MockFactory.createTestRegion("Region" + i, [tags[0].name])] :
[MockFactory.createTestRegion("Region" + i, [tags[1].name])]));
const testAssetsNames = splitTestAsset(assetMetadatas, tags, testSplit);

const trainAssetsArray = assetMetadatas.filter((assetMetadata) =>
testAssetsNames.indexOf(assetMetadata.asset.name) < 0);
const testAssetsArray = assetMetadatas.filter((assetMetadata) =>
testAssetsNames.indexOf(assetMetadata.asset.name) >= 0);

const expectedTestCount = Math.ceil(testSplit * testCount) +
Math.ceil(testSplit * (assetArray.length - testCount));
expect(testAssetsNames).toHaveLength(expectedTestCount);
expect(trainAssetsArray.length + testAssetsArray.length).toEqual(assetMetadatas.length);
expect(testAssetsArray).toHaveLength(expectedTestCount);

expect(testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tags[0].name).length)
.toBeGreaterThan(0);
expect(testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tags[1].name).length)
.toBeGreaterThan(0);
}

it("Correctly generated files based on 50/50 test / train split", async () => {
await testTestTrainSplit(50);
});

it("Correctly generated files based on 60/40 test / train split", async () => {
await testTestTrainSplit(60);
});

it("Correctly generated files based on 80/20 test / train split", async () => {
await testTestTrainSplit(80);
});

it("Correctly generated files based on 90/10 test / train split", async () => {
await testTestTrainSplit(90);
});
});
});
30 changes: 30 additions & 0 deletions src/providers/export/testAssetsSplitHelper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { IAssetMetadata, ITag } from "../../models/applicationState";

/**
* A helper function to split train and test assets
* @param template String containing variables
* @param params Params containing substitution values
*/
export function splitTestAsset(allAssets: IAssetMetadata[], tags: ITag[], testSplitRatio: number): string[] {
if (testSplitRatio <= 0 || testSplitRatio > 1) { return []; }

const testAssets: string[] = [];
const tagsAssetDict: { [index: string]: { assetList: Set<string> } } = {};
tags.forEach((tag) => tagsAssetDict[tag.name] = { assetList: new Set() });
allAssets.forEach((assetMetadata) => {
assetMetadata.regions.forEach((region) => {
region.tags.forEach((tagName) => {
if (tagsAssetDict[tagName]) {
tagsAssetDict[tagName].assetList.add(assetMetadata.asset.name);
}
});
});
});

for (const tagKey of Object.keys(tagsAssetDict)) {
const assetList = tagsAssetDict[tagKey].assetList;
const testCount = Math.ceil(assetList.size * testSplitRatio);
testAssets.push(...Array.from(assetList).slice(0, testCount));
}
return testAssets;
}

0 comments on commit 9d64f4a

Please sign in to comment.