diff --git a/src/features/codeLensProvider.ts b/src/features/codeLensProvider.ts index 124a75ab8..3c9352cc3 100644 --- a/src/features/codeLensProvider.ts +++ b/src/features/codeLensProvider.ts @@ -118,20 +118,13 @@ export default class OmniSharpCodeLensProvider extends AbstractProvider implemen } if (node.Kind == "ClassDeclaration" && node.ChildNodes.length > 0) { - this._updateCodeLensForTestClass(bucket, fileName, node); + return this._updateCodeLensForTestClass(bucket, fileName, node); } - let testFeature = node.Features.find(value => (value.Name == 'XunitTestMethod' || value.Name == 'NUnitTestMethod' || value.Name == 'MSTestMethod')); + let featureAndFramework = this._getTestFeatureAndFramework(node); + let testFeature = featureAndFramework.Feature; + let testFrameworkName = featureAndFramework.Framework; if (testFeature) { - // this test method has a test feature - let testFrameworkName = 'xunit'; - if (testFeature.Name == 'NUnitTestMethod') { - testFrameworkName = 'nunit'; - } - else if (testFeature.Name == 'MSTestMethod') { - testFrameworkName = 'mstest'; - } - bucket.push(new vscode.CodeLens( toRange(node.Location), { title: "run test", command: 'dotnet.test.run', arguments: [testFeature.Data, fileName, testFrameworkName] })); @@ -143,39 +136,50 @@ export default class OmniSharpCodeLensProvider extends AbstractProvider implemen } private _updateCodeLensForTestClass(bucket: vscode.CodeLens[], fileName: string, node: protocol.Node) { - //if the class doesnot contain any method then return + // if the class doesnot contain any method then return if (!node.ChildNodes.find(value => (value.Kind == "MethodDeclaration"))) { return; } - let testMethodsInClass = new Array(); + let testMethods = new Array(); let testFrameworkName: string = null; for (let child of node.ChildNodes) { - if (child.Kind == "MethodDeclaration") { - let testFeature = child.Features.find(value => (value.Name == 'XunitTestMethod' || value.Name == 'NUnitTestMethod' || value.Name == 'MSTestMethod')); - if (testFeature) { - // this test method has a test feature - if (testFrameworkName == null) { - testFrameworkName = 'xunit'; - if (testFeature.Name == 'NUnitTestMethod') { - testFrameworkName = 'nunit'; - } - else if (testFeature.Name == 'MSTestMethod') { - testFrameworkName = 'mstest'; - } - } - testMethodsInClass.push(testFeature.Data); + let featureAndFramework = this._getTestFeatureAndFramework(child); + let testFeature = featureAndFramework.Feature; + if (testFeature) { + // this test method has a test feature + if (!testFrameworkName) { + testFrameworkName = featureAndFramework.Framework; } + + testMethods.push(testFeature.Data); } } - if (testMethodsInClass.length) { + if (testMethods.length) { bucket.push(new vscode.CodeLens( toRange(node.Location), - { title: "run all test", command: 'dotnet.classTests.run', arguments: [testMethodsInClass, fileName, testFrameworkName] })); + { title: "run all tests", command: 'dotnet.classTests.run', arguments: [testMethods, fileName, testFrameworkName] })); bucket.push(new vscode.CodeLens( toRange(node.Location), - { title: "debug all test", command: 'dotnet.classTests.debug', arguments: [testMethodsInClass, fileName, testFrameworkName] })); + { title: "debug all tests", command: 'dotnet.classTests.debug', arguments: [testMethods, fileName, testFrameworkName] })); + } + } + + private _getTestFeatureAndFramework(node: protocol.Node) { + let testFeature = node.Features.find(value => (value.Name == 'XunitTestMethod' || value.Name == 'NUnitTestMethod' || value.Name == 'MSTestMethod')); + if (testFeature) { + let testFrameworkName = 'xunit'; + if (testFeature.Name == 'NUnitTestMethod') { + testFrameworkName = 'nunit'; + } + else if (testFeature.Name == 'MSTestMethod') { + testFrameworkName = 'mstest'; + } + + return { Feature: testFeature, Framework: testFrameworkName }; } + + return { Feature: null, Framework: null }; } }