This repository has been archived by the owner on Dec 7, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 841
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: test asset distribution to include all tags on test/train split (#…
…823) * fix: test asset distribution to include all tags on test/train split The test assets may not include all tags when exporting with the test/train split option in the current version (2.1.0). * Extract the same split logic into helper function * Formatting * Inverting if statement
- Loading branch information
1 parent
c0201ca
commit 9d64f4a
Showing
6 changed files
with
229 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import _ from "lodash"; | ||
import { | ||
IAssetMetadata, AssetState, IRegion, | ||
RegionType, IPoint, IExportProviderOptions, | ||
} from "../../models/applicationState"; | ||
import MockFactory from "../../common/mockFactory"; | ||
import { splitTestAsset } from "./testAssetsSplitHelper"; | ||
import { appInfo } from "../../common/appInfo"; | ||
|
||
describe("splitTestAsset Helper tests", () => {

    describe("Test Train Splits", () => {
        /**
         * Runs splitTestAsset over assets created with MockFactory.createTestAssets(13, 0),
         * each given a single region with a single tag, then checks the size of the
         * returned test set and that both tags are present in it.
         * @param testTrainSplit Percentage subtracted from 100 to obtain the test share
         */
        async function testTestTrainSplit(testTrainSplit: number): Promise<void> {
            const assets = MockFactory.createTestAssets(13, 0);
            const tags = MockFactory.createTestTags(2);
            for (const asset of assets) {
                asset.state = AssetState.Tagged;
            }

            const testRatio = (100 - testTrainSplit) / 100;
            const secondTagCount = Math.ceil(testRatio * assets.length);
            const firstTagCount = assets.length - secondTagCount;

            // Leading assets get a region with the first tag, the remaining ones the second tag
            const assetMetadatas = assets.map((asset, index) => {
                const tagName = index < firstTagCount ? tags[0].name : tags[1].name;
                const region = MockFactory.createTestRegion("Region" + index, [tagName]);
                return MockFactory.createTestAssetMetadata(asset, [region]);
            });
            const testAssetsNames = splitTestAsset(assetMetadatas, tags, testRatio);

            // Assets whose name was returned by the helper are test assets, the rest are train assets
            const [testAssetsArray, trainAssetsArray] = _.partition(assetMetadatas,
                (assetMetadata) => testAssetsNames.indexOf(assetMetadata.asset.name) >= 0);

            // The helper takes ceil(ratio * groupSize) assets from each tag group
            const expectedTestCount =
                Math.ceil(testRatio * secondTagCount) + Math.ceil(testRatio * firstTagCount);
            expect(testAssetsNames).toHaveLength(expectedTestCount);
            expect(trainAssetsArray.length + testAssetsArray.length).toEqual(assetMetadatas.length);
            expect(testAssetsArray).toHaveLength(expectedTestCount);

            const countTestAssetsTagged = (tagName: string): number =>
                testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tagName).length;
            expect(countTestAssetsTagged(tags[0].name)).toBeGreaterThan(0);
            expect(countTestAssetsTagged(tags[1].name)).toBeGreaterThan(0);
        }

        const splitCases: Array<[string, number]> = [
            ["Correctly generated files based on 50/50 test / train split", 50],
            ["Correctly generated files based on 60/40 test / train split", 60],
            ["Correctly generated files based on 80/20 test / train split", 80],
            ["Correctly generated files based on 90/10 test / train split", 90],
        ];

        splitCases.forEach(([title, testTrainSplit]) => {
            it(title, async () => {
                await testTestTrainSplit(testTrainSplit);
            });
        });
    });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import { IAssetMetadata, ITag } from "../../models/applicationState"; | ||
|
||
/**
 * Selects the assets that make up the test set of a test/train split.
 *
 * Assets are grouped by every known tag found in their regions, and the first
 * `ceil(groupSize * testSplitRatio)` assets of each group are selected, so every
 * tag that has at least one tagged asset is represented in the test set.
 *
 * @param allAssets Asset metadata to split; only `asset.name` and `regions[].tags` are read
 * @param tags Tags to group by; region tags that are not in this list are ignored
 * @param testSplitRatio Fraction of each tag's assets to select for testing, expected in (0, 1]
 * @returns Unique names of the assets selected for testing; empty when the ratio is <= 0 or > 1
 */
export function splitTestAsset(allAssets: IAssetMetadata[], tags: ITag[], testSplitRatio: number): string[] {
    if (testSplitRatio <= 0 || testSplitRatio > 1) { return []; }

    // A Map (instead of a plain object) keeps region tag names such as "constructor"
    // from resolving to Object.prototype members and crashing the lookup below.
    const assetNamesByTag = new Map<string, Set<string>>();
    tags.forEach((tag) => assetNamesByTag.set(tag.name, new Set<string>()));

    allAssets.forEach((assetMetadata) => {
        assetMetadata.regions.forEach((region) => {
            region.tags.forEach((tagName) => {
                const assetNames = assetNamesByTag.get(tagName);
                if (assetNames) {
                    assetNames.add(assetMetadata.asset.name);
                }
            });
        });
    });

    // An asset carrying several tags can be selected once per tag;
    // collecting into a Set keeps every returned name unique.
    const testAssets = new Set<string>();
    assetNamesByTag.forEach((assetNames) => {
        const testCount = Math.ceil(assetNames.size * testSplitRatio);
        Array.from(assetNames).slice(0, testCount).forEach((name) => testAssets.add(name));
    });
    return Array.from(testAssets);
}