diff --git a/ghcide/src/Development/IDE/GHC/Compat.hs b/ghcide/src/Development/IDE/GHC/Compat.hs index 8091bdb9c1..1a72edba53 100644 --- a/ghcide/src/Development/IDE/GHC/Compat.hs +++ b/ghcide/src/Development/IDE/GHC/Compat.hs @@ -58,7 +58,7 @@ module Development.IDE.GHC.Compat( applyPluginsParsedResultAction, module Compat.HieTypes, module Compat.HieUtils, - + dropForAll ) where #if MIN_GHC_API_VERSION(8,10,0) @@ -283,3 +283,12 @@ pattern ExposePackage s a mr <- DynFlags.ExposePackage s a _ mr #else pattern ExposePackage s a mr = DynFlags.ExposePackage s a mr #endif + +-- | Take AST representation of type signature and drop `forall` part from it (if any), returning just type's body +dropForAll :: LHsType pass -> LHsType pass +#if MIN_GHC_API_VERSION(8,10,0) +dropForAll = snd . GHC.splitLHsForAllTyInvis +#else +dropForAll = snd . GHC.splitLHsForAllTy +#endif + diff --git a/ghcide/src/Development/IDE/Plugin/CodeAction.hs b/ghcide/src/Development/IDE/Plugin/CodeAction.hs index 3c1a73e752..9b3bcd5703 100644 --- a/ghcide/src/Development/IDE/Plugin/CodeAction.hs +++ b/ghcide/src/Development/IDE/Plugin/CodeAction.hs @@ -803,12 +803,13 @@ suggestFunctionConstraint ParsedModule{pm_parsed_source = L _ HsModule{hsmodDecl | L _ (SigD _ (TypeSig _ identifiers (HsWC _ (HsIB _ locatedType)))) <- hsmodDecls , any (`isSameName` T.unpack typeSignatureName) $ fmap unLoc identifiers ] - srcSpanToRange $ case splitLHsQualTy locatedType of + let typeBody = dropForAll locatedType + srcSpanToRange $ case splitLHsQualTy typeBody of (L contextSrcSpan _ , _) -> if isGoodSrcSpan contextSrcSpan then contextSrcSpan -- The type signature has explicit context - else -- No explicit context, return SrcSpan at the start of type sig where we can write context - let start = srcSpanStart $ getLoc locatedType in mkSrcSpan start start + else -- No explicit context, return SrcSpan at the start of type (after a potential `forall`) + let start = srcSpanStart $ getLoc typeBody in mkSrcSpan start start isSameName :: IdP GhcPs -> String -> Bool isSameName x name = showSDocUnsafe (ppr x) == name diff --git a/ghcide/test/exe/Main.hs b/ghcide/test/exe/Main.hs index 7173bbb69d..bfff9cac4f 100644 --- a/ghcide/test/exe/Main.hs +++ b/ghcide/test/exe/Main.hs @@ -1934,6 +1934,28 @@ addFunctionConstraintTests = let , "eq x y = x == y" ] + missingConstraintWithForAllSourceCode :: T.Text -> T.Text + missingConstraintWithForAllSourceCode constraint = + T.unlines + [ "{-# LANGUAGE ExplicitForAll #-}" + , "module Testing where" + , "" + , "eq :: forall a. " <> constraint <> "a -> a -> Bool" + , "eq x y = x == y" + ] + + incompleteConstraintWithForAllSourceCode :: T.Text -> T.Text + incompleteConstraintWithForAllSourceCode constraint = + T.unlines + [ "{-# LANGUAGE ExplicitForAll #-}" + , "module Testing where" + , "" + , "data Pair a b = Pair a b" + , "" + , "eq :: " <> constraint <> " => Pair a b -> Pair a b -> Bool" + , "eq (Pair x y) (Pair x' y') = x == x' && y == y'" + ] + incompleteConstraintSourceCode :: T.Text -> T.Text incompleteConstraintSourceCode constraint = T.unlines @@ -1978,8 +2000,8 @@ addFunctionConstraintTests = let , "eq (Pair x y) (Pair x' y') = x == x' && y == y'" ] - check :: T.Text -> T.Text -> T.Text -> TestTree - check actionTitle originalCode expectedCode = testSession (T.unpack actionTitle) $ do + check :: String -> T.Text -> T.Text -> T.Text -> TestTree + check testName actionTitle originalCode expectedCode = testSession testName $ do doc <- createDoc "Testing.hs" "haskell" originalCode _ <- waitForDiagnostics actionsOrCommands <- getCodeActions doc (Range (Position 6 0) (Position 6 maxBound)) @@ -1990,22 +2012,37 @@ addFunctionConstraintTests = let in testGroup "add function constraint" [ check + "no preexisting constraint" "Add `Eq a` to the context of the type signature for `eq`" (missingConstraintSourceCode "") (missingConstraintSourceCode "Eq a => ") , check + "no preexisting constraint, with forall" + "Add `Eq a` to the context of the type signature for `eq`" + (missingConstraintWithForAllSourceCode "") + (missingConstraintWithForAllSourceCode "Eq a => ") + , check + "preexisting constraint, no parenthesis" "Add `Eq b` to the context of the type signature for `eq`" (incompleteConstraintSourceCode "Eq a") (incompleteConstraintSourceCode "(Eq a, Eq b)") , check + "preexisting constraints in parenthesis" "Add `Eq c` to the context of the type signature for `eq`" (incompleteConstraintSourceCode2 "(Eq a, Eq b)") (incompleteConstraintSourceCode2 "(Eq a, Eq b, Eq c)") + , check + "preexisting constraints with forall" + "Add `Eq b` to the context of the type signature for `eq`" + (incompleteConstraintWithForAllSourceCode "Eq a") + (incompleteConstraintWithForAllSourceCode "(Eq a, Eq b)") , check + "preexisting constraint, with extra spaces in context" "Add `Eq b` to the context of the type signature for `eq`" (incompleteConstraintSourceCodeWithExtraCharsInContext "( Eq a )") (incompleteConstraintSourceCodeWithExtraCharsInContext "(Eq a, Eq b)") , check + "preexisting constraint, with newlines in type signature" "Add `Eq b` to the context of the type signature for `eq`" (incompleteConstraintSourceCodeWithNewlinesInTypeSignature "(Eq a)") (incompleteConstraintSourceCodeWithNewlinesInTypeSignature "(Eq a, Eq b)")