diff --git a/src/Compiler/Service/FSharpParseFileResults.fs b/src/Compiler/Service/FSharpParseFileResults.fs index 427a01b3c4aa..f5d31725cb45 100644 --- a/src/Compiler/Service/FSharpParseFileResults.fs +++ b/src/Compiler/Service/FSharpParseFileResults.fs @@ -1064,3 +1064,35 @@ type FSharpParseFileResults(diagnostics: FSharpDiagnostic[], input: ParsedInput, member scope.ValidateBreakpointLocation pos = // This does not need to be run on the background thread scope.ValidateBreakpointLocationImpl pos + + member _.RangeOfReturnTypeDefinition(symbolUseStart: pos, ?skipLambdas) = + let skipLambdas = defaultArg skipLambdas true + + SyntaxTraversal.Traverse( + symbolUseStart, + input, + { new SyntaxVisitorBase<_>() with + member _.VisitExpr(_path, _traverseSynExpr, defaultTraverse, expr) = defaultTraverse expr + + override _.VisitBinding(_path, defaultTraverse, binding) = + match binding with + | SynBinding(expr = SynExpr.Lambda _) when skipLambdas -> defaultTraverse binding + | SynBinding(expr = SynExpr.DotLambda _) when skipLambdas -> defaultTraverse binding + ////I need the : before the Return Info + //| SynBinding(expr = SynExpr.Typed _) -> defaultTraverse binding + + // Dont skip manually type-annotated bindings + | SynBinding(returnInfo = Some (SynBindingReturnInfo (_, r, _, _))) -> Some r + + // Let binding + | SynBinding (trivia = { EqualsRange = Some equalsRange }; range = range) when range.Start = symbolUseStart -> + Some equalsRange.StartRange + + // Member binding + | SynBinding (headPat = SynPat.LongIdent(longDotId = SynLongIdent(id = _ :: ident :: _)) + trivia = { EqualsRange = Some equalsRange }) when ident.idRange.Start = symbolUseStart -> + Some equalsRange.StartRange + + | _ -> defaultTraverse binding + } + ) diff --git a/src/Compiler/Service/FSharpParseFileResults.fsi b/src/Compiler/Service/FSharpParseFileResults.fsi index e892c78aa89d..7232ab74cc57 100644 --- a/src/Compiler/Service/FSharpParseFileResults.fsi +++ b/src/Compiler/Service/FSharpParseFileResults.fsi @@ -95,6 +95,8 @@ type public FSharpParseFileResults = /// Indicates if any errors occurred during the parse member ParseHadErrors: bool + member RangeOfReturnTypeDefinition: symbolUseStart: pos * ?skipLambdas: bool -> range option + internal new: diagnostics: FSharpDiagnostic[] * input: ParsedInput * parseHadErrors: bool * dependencyFiles: string[] -> FSharpParseFileResults diff --git a/vsintegration/src/FSharp.Editor/Refactor/RemoveExplicitReturnType.fs b/vsintegration/src/FSharp.Editor/Refactor/RemoveExplicitReturnType.fs index 3e7edf612c88..1ad3aa39495f 100644 --- a/vsintegration/src/FSharp.Editor/Refactor/RemoveExplicitReturnType.fs +++ b/vsintegration/src/FSharp.Editor/Refactor/RemoveExplicitReturnType.fs @@ -25,38 +25,6 @@ open InternalOptionBuilder type internal RemoveExplicitReturnType [] () = inherit CodeRefactoringProvider() - static member RangeOfReturnTypeDefinition(input: ParsedInput, symbolUseStart: pos, ?skipLambdas) = - let skipLambdas = defaultArg skipLambdas true - - SyntaxTraversal.Traverse( - symbolUseStart, - input, - { new SyntaxVisitorBase<_>() with - member _.VisitExpr(_path, _traverseSynExpr, defaultTraverse, expr) = defaultTraverse expr - - override _.VisitBinding(_path, defaultTraverse, binding) = - match binding with - | SynBinding(expr = SynExpr.Lambda _) when skipLambdas -> defaultTraverse binding - | SynBinding(expr = SynExpr.DotLambda _) when skipLambdas -> defaultTraverse binding - ////I need the : before the Return Info - //| SynBinding(expr = SynExpr.Typed _) -> defaultTraverse binding - - // Dont skip manually type-annotated bindings - | SynBinding(returnInfo = Some (SynBindingReturnInfo (_, r, _, _))) -> Some r - - // Let binding - | SynBinding (trivia = { EqualsRange = Some equalsRange }; range = range) when range.Start = symbolUseStart -> - Some equalsRange.StartRange - - // Member binding - | SynBinding (headPat = SynPat.LongIdent(longDotId = SynLongIdent(id = _ :: ident :: _)) - trivia = { EqualsRange = Some equalsRange }) when ident.idRange.Start = symbolUseStart -> - Some equalsRange.StartRange - - | _ -> defaultTraverse binding - } - ) - static member RangeIncludingColon(range: TextSpan, sourceText: SourceText) = let lineUntilType = TextSpan.FromBounds(0, range.Start) @@ -71,7 +39,7 @@ type internal RemoveExplicitReturnType [] () = (funcOrValue: FSharpMemberOrFunctionOrValue) = let returnTypeHintAlreadyPresent = - RemoveExplicitReturnType.RangeOfReturnTypeDefinition(parseFileResults.ParseTree, symbolUse.Range.Start, false) + parseFileResults.RangeOfReturnTypeDefinition(symbolUse.Range.Start, false) |> Option.isSome let isLambdaIfFunction = @@ -97,7 +65,7 @@ type internal RemoveExplicitReturnType [] () = let getChangedText (sourceText: SourceText) = let newSourceText = - RemoveExplicitReturnType.RangeOfReturnTypeDefinition(parseFileResults.ParseTree, symbolUse.Range.Start, false) + parseFileResults.RangeOfReturnTypeDefinition(symbolUse.Range.Start, false) |> Option.map (fun range -> RoslynHelpers.FSharpRangeToTextSpan(sourceText, range)) |> Option.map (fun textSpan -> RemoveExplicitReturnType.RangeIncludingColon(textSpan, sourceText)) |> Option.map (fun textSpan -> sourceText.Replace(textSpan, "")) diff --git a/vsintegration/tests/FSharp.Editor.Tests/Refactors/RefactorTestFramework.fs b/vsintegration/tests/FSharp.Editor.Tests/Refactors/RefactorTestFramework.fs index 125832743f6f..8d940869b6d7 100644 --- a/vsintegration/tests/FSharp.Editor.Tests/Refactors/RefactorTestFramework.fs +++ b/vsintegration/tests/FSharp.Editor.Tests/Refactors/RefactorTestFramework.fs @@ -64,8 +64,7 @@ let TryGetRangeOfExplicitReturnType (symbolName: string) (document: Document) ct let range = symbol - |> Option.bind (fun sym -> - RemoveExplicitReturnType.RangeOfReturnTypeDefinition(parseFileResults.ParseTree, sym.DeclarationLocation.Start, false)) + |> Option.bind (fun sym -> parseFileResults.RangeOfReturnTypeDefinition(sym.DeclarationLocation.Start, false)) return range }