@@ -77,6 +77,7 @@ module Data.HashMap.Internal
7777 -- * Difference and intersection
7878 , difference
7979 , differenceWith
80+ , differenceWithKey
8081 , intersection
8182 , intersectionWith
8283 , intersectionWithKey
@@ -1917,13 +1918,187 @@ differenceCollisions !h1 !ary1 t1 !h2 !ary2
19171918-- encountered, the combining function is applied to the values of these keys.
19181919-- If it returns 'Nothing', the element is discarded (proper set difference). If
19191920-- it returns (@'Just' y@), the element is updated with a new value @y@.
1920- differenceWith :: (Eq k , Hashable k ) => (v -> w -> Maybe v ) -> HashMap k v -> HashMap k w -> HashMap k v
1921- differenceWith f a b = foldlWithKey' go empty a
1921+ differenceWith :: Eq k => (v -> w -> Maybe v ) -> HashMap k v -> HashMap k w -> HashMap k v
1922+ differenceWith f = differenceWithKey (const f) -- 'const f' discards the key that differenceWithKey passes to its combining function
1923+ {-# INLINE differenceWith #-}
1924+
1925+ -- | \(O(n \log m)\) Difference with a combining function. When two equal keys are
1926+ -- encountered, the combining function is applied to the key and both of its values.
1927+ -- If it returns 'Nothing', the element is discarded (proper set difference). If
1928+ -- it returns (@'Just' y@), the element is updated with a new value @y@.
1929+ differenceWithKey :: Eq k => (k -> v -> w -> Maybe v ) -> HashMap k v -> HashMap k w -> HashMap k v
1930+ differenceWithKey f = go_differenceWithKey 0 -- 0 is the initial value of the shift argument s
19221931 where
1923- go m k v = case lookup k b of
1924- Nothing -> unsafeInsert k v m
1925- Just w -> maybe m (\ y -> unsafeInsert k y m) (f v w)
1926- {-# INLINABLE differenceWith #-}
1932+ go_differenceWithKey ! _s Empty _tB = Empty -- left side is empty: the result is empty
1933+ go_differenceWithKey _s a Empty = a -- right side is empty: a is returned as is
1934+ go_differenceWithKey s a@ (Leaf hA (L kA vA)) b
1935+ = lookupCont
1936+ (\ _ -> a) -- presumably the continuation for "kA not found in b" (lookupCont is defined elsewhere): keep the leaf
1937+ (\ vB _ -> case f kA vA vB of
1938+ Nothing -> Empty -- f discarded the only entry of this leaf
1939+ Just v | v `ptrEq` vA -> a -- new value is `ptrEq` to the old one: reuse the original leaf
1940+ | otherwise -> Leaf hA (L kA v))
1941+ hA kA s b
1942+ go_differenceWithKey _s a@ (Collision hA aryA) (Leaf hB (L kB vB))
1943+ | hA == hB = updateCollision (\ vA -> f kB vA vB) hA kB aryA a -- NOTE(review): f and updateCollision receive kB (b's key) here, while the Leaf case and differenceWithKey_Collisions use kA; confirm this is intended
1944+ | otherwise = a
1945+ go_differenceWithKey s a@ (BitmapIndexed bA aryA) b@ (Leaf hB _)
1946+ | bA .&. m == 0 = a -- the bit selected by hB is not set in bA: a is returned as is
1947+ | otherwise = case A. index# aryA i of
1948+ (# ! stA # ) -> case go_differenceWithKey (nextShift s) stA b of
1949+ Empty | A. length aryA == 2 -- the child became Empty and aryA held exactly two children
1950+ , (# l # ) <- A. index# aryA (otherOfOneOrZero i)
1951+ , isLeafOrCollision l
1952+ -> l -- the other child is a Leaf/Collision: return it directly instead of wrapping it
1953+ | otherwise
1954+ -> BitmapIndexed (bA .&. complement m) (A. delete aryA i) -- clear the child's bit and remove its slot
1955+ stA' | isLeafOrCollision stA' && A. length aryA == 1 -> stA' -- an only child that is a Leaf/Collision replaces this node
1956+ | stA `ptrEq` stA' -> a -- child unchanged: reuse a
1957+ | otherwise -> BitmapIndexed bA (A. update aryA i stA')
1958+ where
1959+ m = mask hB s
1960+ i = sparseIndex bA m
1961+ go_differenceWithKey s a@ (BitmapIndexed bA aryA) b@ (Collision hB _) -- same steps as the BitmapIndexed/Leaf equation, with b being a Collision
1962+ | bA .&. m == 0 = a
1963+ | otherwise =
1964+ case A. index# aryA i of
1965+ (# ! st # ) -> case go_differenceWithKey (nextShift s) st b of
1966+ Empty | A. length aryA == 2
1967+ , (# l # ) <- A. index# aryA (otherOfOneOrZero i)
1968+ , isLeafOrCollision l
1969+ -> l
1970+ | otherwise
1971+ -> BitmapIndexed (bA .&. complement m) (A. delete aryA i)
1972+ st' | isLeafOrCollision st' && A. length aryA == 1 -> st'
1973+ | st `ptrEq` st' -> a
1974+ | otherwise -> BitmapIndexed bA (A. update aryA i st')
1975+ where
1976+ m = mask hB s
1977+ i = sparseIndex bA m
1978+ go_differenceWithKey s a@ (Full aryA) b@ (Leaf hB _)
1979+ = case A. index# aryA i of
1980+ (# ! stA # ) -> case go_differenceWithKey (nextShift s) stA b of
1981+ Empty -> -- a Full node that loses a child is rebuilt as BitmapIndexed with that child's bit cleared
1982+ let aryA' = A. delete aryA i
1983+ bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1984+ in BitmapIndexed bm aryA'
1985+ stA' | stA `ptrEq` stA' -> a -- child unchanged: reuse a
1986+ | otherwise -> Full (updateFullArray aryA i stA')
1987+ where i = index hB s
1988+ go_differenceWithKey s a@ (Full aryA) b@ (Collision hB _) -- same steps as the Full/Leaf equation, with b being a Collision
1989+ = case A. index# aryA i of
1990+ (# ! stA # ) -> case go_differenceWithKey (nextShift s) stA b of
1991+ Empty ->
1992+ let aryA' = A. delete aryA i
1993+ bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1994+ in BitmapIndexed bm aryA'
1995+ stA' | stA `ptrEq` stA' -> a
1996+ | otherwise -> Full (updateFullArray aryA i stA')
1997+ where i = index hB s
1998+ go_differenceWithKey s a@ (Collision hA _) (BitmapIndexed bB aryB)
1999+ | bB .&. m == 0 = a -- the bit selected by hA is not set in bB: a is returned as is
2000+ | otherwise =
2001+ case A. index# aryB (sparseIndex bB m) of
2002+ (# stB # ) -> go_differenceWithKey (nextShift s) a stB -- recurse with the same a against the child of b selected by hA
2003+ where m = mask hA s
2004+ go_differenceWithKey s a@ (Collision hA _) (Full aryB)
2005+ = case A. index# aryB (index hA s) of
2006+ (# stB # ) -> go_differenceWithKey (nextShift s) a stB -- recurse with the same a against the child of b selected by hA
2007+ go_differenceWithKey s a@ (BitmapIndexed bA aryA) (BitmapIndexed bB aryB) -- the next four equations delegate to differenceWithKey_Arrays; a Full node is passed with fullBitmap as its bitmap
2008+ = differenceWithKey_Arrays s bA aryA a bB aryB
2009+ go_differenceWithKey s a@ (Full aryA) (BitmapIndexed bB aryB)
2010+ = differenceWithKey_Arrays s fullBitmap aryA a bB aryB
2011+ go_differenceWithKey s a@ (BitmapIndexed bA aryA) (Full aryB)
2012+ = differenceWithKey_Arrays s bA aryA a fullBitmap aryB
2013+ go_differenceWithKey s a@ (Full aryA) (Full aryB)
2014+ = differenceWithKey_Arrays s fullBitmap aryA a fullBitmap aryB
2015+ go_differenceWithKey _s a@ (Collision hA aryA) (Collision hB aryB)
2016+ = differenceWithKey_Collisions f hA aryA a hB aryB
2017+
2018+ differenceWithKey_Arrays ! s ! bA ! aryA tA ! bB ! aryB -- tA is the node that bA/aryA came from; it is returned when nothing changes
2019+ | bA .&. bB == 0 = tA -- the two bitmaps share no set bit: tA is returned as is
2020+ | otherwise = runST $ do
2021+ mary <- A. new_ $ A. length aryA -- scratch array with one slot per child of aryA
2022+
2023+ -- TODO: i == popCount bResult. Not sure if that would be faster.
2024+ -- Also iA is in some relation with bA'
2025+ let go_dWKA ! i ! iA ! bA' ! bResult ! nChanges -- i: next write index in mary; iA: next read index in aryA; bA': bits of bA not yet visited; bResult: bitmap of the children written so far; nChanges: children dropped or replaced so far
2026+ | bA' == 0 = pure (bResult, nChanges)
2027+ | otherwise = do
2028+ ! stA <- A. indexM aryA iA
2029+ case m .&. bB of
2030+ 0 -> do -- bit m is not set in bB: copy the child unchanged
2031+ A. write mary i stA
2032+ go_dWKA (i + 1 ) (iA + 1 ) nextBA' (bResult .|. m) nChanges
2033+ _ -> do
2034+ ! stB <- A. indexM aryB (sparseIndex bB m)
2035+ case go_differenceWithKey (nextShift s) stA stB of
2036+ Empty -> go_dWKA i (iA + 1 ) nextBA' bResult (nChanges + 1 ) -- child became Empty: write nothing, leave bit m unset, count a change
2037+ st -> do
2038+ A. write mary i st
2039+ let same = I # (Exts. reallyUnsafePtrEquality# st stA)
2040+ let nChanges' = nChanges + (1 - same) -- only a child that is not pointer-identical to stA counts as a change
2041+ go_dWKA (i + 1 ) (iA + 1 ) nextBA' (bResult .|. m) nChanges'
2042+ where
2043+ m = bA' .&. negate bA' -- lowest set bit of bA'
2044+ nextBA' = bA' .&. complement m -- bA' with that bit cleared
2045+
2046+ (bResult, nChanges) <- go_dWKA 0 0 bA 0 0
2047+ if nChanges == 0
2048+ then pure tA -- no child was dropped or replaced: reuse the original node
2049+ else case popCount bResult of
2050+ 0 -> pure Empty
2051+ 1 -> do
2052+ l <- A. read mary 0
2053+ if isLeafOrCollision l
2054+ then pure l -- a single remaining Leaf/Collision child is returned directly
2055+ else BitmapIndexed bResult <$> (A. unsafeFreeze =<< A. shrink mary 1 ) -- any other single child stays wrapped in a BitmapIndexed node
2056+ n -> bitmapIndexedOrFull bResult <$> (A. unsafeFreeze =<< A. shrink mary n)
2057+ {-# INLINE differenceWithKey #-}
2058+
2059+ -- | 'update', specialized to 'Collision' nodes.
2060+ updateCollision
2061+ :: Eq k
2062+ => (v -> Maybe v ) -- ^ Applied to the value found for the key; 'Nothing' removes the entry, @'Just' v'@ replaces its value.
2063+ -> Hash -- ^ Hash stored in the 'Leaf' or 'Collision' node that is built.
2064+ -> k -- ^ Key to look up in the array; also the key written together with a replaced value.
2065+ -> A. Array (Leaf k v ) -- ^ Array that is searched for the key.
2066+ -> HashMap k v
2067+ -- ^ The original Collision node which will be re-used if the array is unchanged.
2068+ -> HashMap k v
2069+ updateCollision f ! h k ! ary orig =
2070+ lookupInArrayCont
2071+ (\ _ -> orig) -- presumably the continuation for "key not found" (lookupInArrayCont is defined elsewhere): return orig
2072+ (\ v i -> case f v of
2073+ Nothing | A. length ary == 2 -- removing one entry from a two-element array leaves a single entry
2074+ , (# l # ) <- A. index# ary (otherOfOneOrZero i)
2075+ -> Leaf h l -- the remaining entry is returned as a Leaf instead of a Collision
2076+ | otherwise -> Collision h (A. delete ary i)
2077+ Just v' | v' `ptrEq` v -> orig -- new value is `ptrEq` to the old one: reuse orig
2078+ | otherwise -> Collision h (A. update ary i (L k v'))) -- NOTE(review): writes the caller-supplied k, not the key already stored at index i; confirm the two are interchangeable for all callers
2079+ k ary
2080+ {-# INLINABLE updateCollision #-}
2081+
2081+ -- TODO: This could be faster if we would keep track of which elements of ary2
2082+ -- we've already matched. Those could be skipped when we check the following
2083+ -- elements of ary1.
2084+ -- TODO: Return tA when the array is unchanged.
2085+ differenceWithKey_Collisions :: Eq k => (k -> v -> w -> Maybe v ) -> Word -> A. Array (Leaf k v ) -> HashMap k v -> Word -> A. Array (Leaf k w ) -> HashMap k v
2086+ differenceWithKey_Collisions f ! hA ! aryA ! tA ! hB ! aryB -- hA/aryA and hB/aryB are the hash and entries of the two collision nodes; tA is returned as is when the hashes differ
2087+ | hA == hB =
2088+ let f' l@ (L kA vA) = -- decides for one entry of aryA whether it is kept, replaced or dropped
2089+ lookupInArrayCont
2090+ (\ _ -> Just l) -- keep the entry unchanged (presumably the "kA not found in aryB" continuation)
2091+ (\ vB _ -> L kA <$> f kA vA vB) -- kA found in aryB with value vB: f decides; 'Nothing' drops the entry
2092+ kA aryB
2093+ ary = A. mapMaybe f' aryA
2094+ in case A. length ary of
2095+ 0 -> Empty
2096+ 1 -> case A. index# ary 0 of -- a single remaining entry is returned as a Leaf, not as a one-element Collision
2097+ (# l # ) -> Leaf hA l
2098+ _ -> Collision hA ary
2099+ | otherwise = tA
2100+ {-# INLINABLE differenceWithKey_Collisions #-}
19272102
19282103-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
19292104-- map for keys existing in the second.
0 commit comments