Skip to content

Commit a7736a1

Browse files
authored
Add differenceWithKey (#542)
...and define `differenceWith` via `differenceWithKey`. Closes #364, closes #389.
1 parent 5ce9758 commit a7736a1

File tree

7 files changed

+273
-13
lines changed

7 files changed

+273
-13
lines changed

Data/HashMap/Internal.hs

Lines changed: 181 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ module Data.HashMap.Internal
7777
-- * Difference and intersection
7878
, difference
7979
, differenceWith
80+
, differenceWithKey
8081
, intersection
8182
, intersectionWith
8283
, intersectionWithKey
@@ -1917,13 +1918,187 @@ differenceCollisions !h1 !ary1 t1 !h2 !ary2
19171918
-- encountered, the combining function is applied to the values of these keys.
19181919
-- If it returns 'Nothing', the element is discarded (proper set difference). If
19191920
-- it returns (@'Just' y@), the element is updated with a new value @y@.
1920-
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
1921-
differenceWith f a b = foldlWithKey' go empty a
1921+
differenceWith :: Eq k => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWith f = differenceWithKey (\_key vA vB -> f vA vB)
  -- All of the work happens in 'differenceWithKey'; this wrapper only
  -- discards the key before calling the user's combining function.
{-# INLINE differenceWith #-}
1924+
1925+
-- | \(O(n \log m)\) Difference with a combining function. When two equal keys are
-- encountered, the combining function is applied to the key and the values
-- associated with it in both maps.
-- If it returns 'Nothing', the element is discarded (proper set difference). If
-- it returns (@'Just' y@), the element is updated with a new value @y@.
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWithKey f = go_differenceWithKey 0
  where
    -- Simultaneous descent through both trees. The first argument is the
    -- current shift; @a@ is the (sub)tree of the first map, @b@ that of the
    -- second. Wherever nothing changed, the original node @a@ is returned
    -- instead of a rebuilt copy.

    -- Nothing to subtract from / nothing to subtract.
    go_differenceWithKey !_s Empty _tB = Empty
    go_differenceWithKey _s a Empty = a
    -- Single entry on the left: look its key up in @b@. If absent, keep @a@;
    -- if present, let @f@ decide. A result that is pointer-equal to the old
    -- value re-uses @a@.
    go_differenceWithKey s a@(Leaf hA (L kA vA)) b
      = lookupCont
          (\_ -> a)
          (\vB _ -> case f kA vA vB of
              Nothing -> Empty
              Just v | v `ptrEq` vA -> a
                     | otherwise -> Leaf hA (L kA v))
          hA kA s b
    -- Collision node on the left, single entry on the right: only relevant
    -- when the hashes agree.
    -- NOTE(review): @f@ receives kB (the second map's key) here, whereas the
    -- Leaf case above passes kA. This is only observable when keys that
    -- compare equal are still distinguishable -- confirm this is intended.
    go_differenceWithKey _s a@(Collision hA aryA) (Leaf hB (L kB vB))
      | hA == hB = updateCollision (\vA -> f kB vA vB) hA kB aryA a
      | otherwise = a
    -- Sparse node on the left, single entry on the right: recurse into the
    -- one child selected by hB, if @a@ has such a child.
    go_differenceWithKey s a@(BitmapIndexed bA aryA) b@(Leaf hB _)
      | bA .&. m == 0 = a
      | otherwise = case A.index# aryA i of
          (# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
            -- The child vanished. If exactly one sibling remains and it is a
            -- Leaf or Collision, return that sibling directly instead of
            -- wrapping it in a one-element BitmapIndexed node.
            Empty | A.length aryA == 2
                  , (# l #) <- A.index# aryA (otherOfOneOrZero i)
                  , isLeafOrCollision l
                  -> l
                  | otherwise
                  -> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
            -- The child survived: an only child that is now a Leaf or
            -- Collision is returned directly; an unchanged child re-uses @a@.
            stA' | isLeafOrCollision stA' && A.length aryA == 1 -> stA'
                 | stA `ptrEq` stA' -> a
                 | otherwise -> BitmapIndexed bA (A.update aryA i stA')
      where
        m = mask hB s
        i = sparseIndex bA m
    -- Same as the previous case, with a Collision node on the right.
    go_differenceWithKey s a@(BitmapIndexed bA aryA) b@(Collision hB _)
      | bA .&. m == 0 = a
      | otherwise =
          case A.index# aryA i of
            (# !st #) -> case go_differenceWithKey (nextShift s) st b of
              Empty | A.length aryA == 2
                    , (# l #) <- A.index# aryA (otherOfOneOrZero i)
                    , isLeafOrCollision l
                    -> l
                    | otherwise
                    -> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
              st' | isLeafOrCollision st' && A.length aryA == 1 -> st'
                  | st `ptrEq` st' -> a
                  | otherwise -> BitmapIndexed bA (A.update aryA i st')
      where
        m = mask hB s
        i = sparseIndex bA m
    -- Full node on the left, single entry on the right: the child at
    -- @index hB s@ is read without a bitmap check. Losing a child turns the
    -- Full node into a BitmapIndexed node with that child's bit cleared.
    go_differenceWithKey s a@(Full aryA) b@(Leaf hB _)
      = case A.index# aryA i of
          (# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
            Empty ->
              let aryA' = A.delete aryA i
                  bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
              in BitmapIndexed bm aryA'
            stA' | stA `ptrEq` stA' -> a
                 | otherwise -> Full (updateFullArray aryA i stA')
      where i = index hB s
    -- Same as the previous case, with a Collision node on the right.
    go_differenceWithKey s a@(Full aryA) b@(Collision hB _)
      = case A.index# aryA i of
          (# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
            Empty ->
              let aryA' = A.delete aryA i
                  bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
              in BitmapIndexed bm aryA'
            stA' | stA `ptrEq` stA' -> a
                 | otherwise -> Full (updateFullArray aryA i stA')
      where i = index hB s
    -- Collision node on the left, inner node on the right: keep @a@ as it is
    -- and descend into the child of @b@ selected by hA (if there is one).
    go_differenceWithKey s a@(Collision hA _) (BitmapIndexed bB aryB)
      | bB .&. m == 0 = a
      | otherwise =
          case A.index# aryB (sparseIndex bB m) of
            (# stB #) -> go_differenceWithKey (nextShift s) a stB
      where m = mask hA s
    go_differenceWithKey s a@(Collision hA _) (Full aryB)
      = case A.index# aryB (index hA s) of
          (# stB #) -> go_differenceWithKey (nextShift s) a stB
    -- Inner node against inner node: handled child-by-child by
    -- differenceWithKey_Arrays; a Full node is passed with 'fullBitmap'.
    go_differenceWithKey s a@(BitmapIndexed bA aryA) (BitmapIndexed bB aryB)
      = differenceWithKey_Arrays s bA aryA a bB aryB
    go_differenceWithKey s a@(Full aryA) (BitmapIndexed bB aryB)
      = differenceWithKey_Arrays s fullBitmap aryA a bB aryB
    go_differenceWithKey s a@(BitmapIndexed bA aryA) (Full aryB)
      = differenceWithKey_Arrays s bA aryA a fullBitmap aryB
    go_differenceWithKey s a@(Full aryA) (Full aryB)
      = differenceWithKey_Arrays s fullBitmap aryA a fullBitmap aryB
    go_differenceWithKey _s a@(Collision hA aryA) (Collision hB aryB)
      = differenceWithKey_Collisions f hA aryA a hB aryB

    -- Difference of two inner nodes given as bitmap + child array. @tA@ is
    -- the original left node and is returned whenever no child changed.
    differenceWithKey_Arrays !s !bA !aryA tA !bB !aryB
      -- No bit is set in both bitmaps, so no child of A has a counterpart.
      | bA .&. bB == 0 = tA
      | otherwise = runST $ do
          -- Sized for the case that every child of A survives; shrunk below.
          mary <- A.new_ $ A.length aryA

          -- TODO: i == popCount bResult. Not sure if that would be faster.
          -- Also iA is in some relation with bA'
          --
          -- Loop state: @i@ is the next write position in @mary@, @iA@ the
          -- next read position in @aryA@, @bA'@ the bits of A not yet
          -- visited, @bResult@ the bitmap of surviving children, and
          -- @nChanges@ the number of children removed or replaced.
          let go_dWKA !i !iA !bA' !bResult !nChanges
                | bA' == 0 = pure (bResult, nChanges)
                | otherwise = do
                    !stA <- A.indexM aryA iA
                    case m .&. bB of
                      -- B has no child at this position: copy A's child.
                      0 -> do
                        A.write mary i stA
                        go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges
                      _ -> do
                        !stB <- A.indexM aryB (sparseIndex bB m)
                        case go_differenceWithKey (nextShift s) stA stB of
                          -- Child removed: write nothing, leave bit unset.
                          Empty -> go_dWKA i (iA + 1) nextBA' bResult (nChanges + 1)
                          st -> do
                            A.write mary i st
                            -- 1 if the result is pointer-equal to the old
                            -- child (i.e. unchanged), 0 otherwise.
                            let same = I# (Exts.reallyUnsafePtrEquality# st stA)
                            let nChanges' = nChanges + (1 - same)
                            go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges'
                where
                  -- Lowest set bit of bA', and bA' with that bit cleared.
                  m = bA' .&. negate bA'
                  nextBA' = bA' .&. complement m

          (bResult, nChanges) <- go_dWKA 0 0 bA 0 0
          if nChanges == 0
            then pure tA
            else case popCount bResult of
              0 -> pure Empty
              -- A single remaining Leaf or Collision is returned directly
              -- rather than wrapped in an inner node.
              1 -> do
                l <- A.read mary 0
                if isLeafOrCollision l
                  then pure l
                  else BitmapIndexed bResult <$> (A.unsafeFreeze =<< A.shrink mary 1)
              n -> bitmapIndexedOrFull bResult <$> (A.unsafeFreeze =<< A.shrink mary n)
{-# INLINE differenceWithKey #-}
2058+
2059+
-- | 'update', specialized to 'Collision' nodes.
--
-- The entry whose key equals the given key (if any) is passed to the update
-- function:
--
-- * 'Nothing' removes the entry. If that leaves a single entry, a 'Leaf' is
--   returned instead of a one-element 'Collision'.
-- * @'Just' v'@ replaces the value, unless @v'@ is pointer-equal to the old
--   value, in which case the original node is returned.
--
-- When a value is replaced, the key /already stored in the node/ is kept
-- rather than the lookup key. The two compare equal, but callers such as
-- 'differenceWithKey' look up a key taken from the second map, and the
-- result must only ever contain keys of the first map.
updateCollision
  :: Eq k
  => (v -> Maybe v)
  -> Hash
  -> k
  -> A.Array (Leaf k v)
  -> HashMap k v
  -- ^ The original Collision node which will be re-used if the array is unchanged.
  -> HashMap k v
updateCollision f !h k !ary orig =
  lookupInArrayCont
    (\_ -> orig)
    (\v i -> case f v of
        Nothing
          | A.length ary == 2
          , (# l #) <- A.index# ary (otherOfOneOrZero i)
          -> Leaf h l
          | otherwise -> Collision h (A.delete ary i)
        Just v'
          | v' `ptrEq` v -> orig
          | otherwise -> case A.index# ary i of
              -- Re-use the stored key so the node's keys are never swapped
              -- for the (merely equal) lookup key.
              (# L kStored _ #) -> Collision h (A.update ary i (L kStored v')))
    k ary
{-# INLINABLE updateCollision #-}
2081+
2082+
-- Difference of two 'Collision' nodes. Every entry of the first node whose
-- key also occurs in the second node is passed through the combining
-- function; all other entries are kept as they are.
--
-- TODO: This could be faster if we would keep track of which elements of ary2
-- we've already matched. Those could be skipped when we check the following
-- elements of ary1.
-- TODO: Return tA when the array is unchanged.
differenceWithKey_Collisions :: Eq k => (k -> v -> w -> Maybe v) -> Word -> A.Array (Leaf k v) -> HashMap k v -> Word -> A.Array (Leaf k w) -> HashMap k v
differenceWithKey_Collisions f !hA !aryA !tA !hB !aryB
  -- Different hashes: the first node is returned untouched.
  | hA /= hB = tA
  | otherwise =
      case A.length survivors of
        0 -> Empty
        -- A lone survivor is stored as a 'Leaf', not a 'Collision'.
        1 -> case A.index# survivors 0 of
          (# l #) -> Leaf hA l
        _ -> Collision hA survivors
  where
    survivors = A.mapMaybe resolve aryA

    -- Keep the entry if its key is absent from aryB, otherwise let @f@
    -- decide whether it is dropped or gets a new value.
    resolve leafA@(L kA vA) =
      lookupInArrayCont
        (\_ -> Just leafA)
        (\vB _ -> L kA <$> f kA vA vB)
        kA aryB
{-# INLINABLE differenceWithKey_Collisions #-}
19272102

19282103
-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
19292104
-- map for keys existing in the second.

Data/HashMap/Internal/Array.hs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ module Data.HashMap.Internal.Array
7272
, map
7373
, map'
7474
, filter
75+
, mapMaybe
7576
, traverse
7677
, traverse'
7778
, toList
@@ -519,6 +520,25 @@ filter f = \ ary ->
519520
else go_filter ary mary (iAry + 1) iMary n
520521
{-# INLINE filter #-}
521522

523+
-- | Apply a function to every element and collect the 'Just' results, in
-- their original order, into a new array.
mapMaybe :: (a -> Maybe b) -> Array a -> Array b
mapMaybe f = \ src ->
    let !total = length src
    in run $ do
        -- Allocate for the case that every element is kept, then shrink to
        -- the number of elements actually written.
        dst <- new_ total
        kept <- collect src dst 0 0 total
        shrink dst kept
  where
    -- @from@ is the next index to read in @src@, @to@ the next free slot in
    -- @dst@; the final value of @to@ is the number of kept elements.
    collect !src !dst !from !to !total
      | from < total = do
          x <- indexM src from
          case f x of
            Just y -> do
              write dst to y
              collect src dst (from + 1) (to + 1) total
            Nothing -> collect src dst (from + 1) to total
      | otherwise = return to
{-# INLINE mapMaybe #-}
541+
522542
fromList :: Int -> [a] -> Array a
523543
fromList n xs0 =
524544
CHECK_EQ("fromList", n, Prelude.length xs0)

Data/HashMap/Internal/Strict.hs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ module Data.HashMap.Internal.Strict
9090
-- * Difference and intersection
9191
, HM.difference
9292
, differenceWith
93+
, differenceWithKey
9394
, HM.intersection
9495
, intersectionWith
9596
, intersectionWithKey
@@ -617,12 +618,22 @@ traverseWithKey f = go
617618
-- If it returns 'Nothing', the element is discarded (proper set difference). If
618619
-- it returns (@'Just' y@), the element is updated with a new value @y@.
619620
differenceWith :: Eq k => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWith f = HM.differenceWithKey $
  -- The lazy implementation does the traversal (which needs only 'Eq' on the
  -- keys, hence no 'Hashable' constraint). This wrapper forces every new
  -- value to WHNF before it is stored, re-using the 'Just' cell returned by
  -- the user's function.
  \_k vA vB -> case f vA vB of
    Nothing -> Nothing
    x@(Just v) -> v `seq` x
{-# INLINE differenceWith #-}
626+
627+
-- | \(O(n \log m)\) Difference with a key-aware combining function. For every
-- key present in both maps, the function is applied to the key and the two
-- associated values. A result of 'Nothing' discards the element (proper set
-- difference); a result of (@'Just' y@) keeps it with the new value @y@,
-- which is evaluated to WHNF before being stored.
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWithKey f = HM.differenceWithKey forcing
  where
    -- Same decision as @f@, but with the new value forced; the original
    -- 'Just' cell is returned rather than rebuilt.
    forcing key old other =
      case f key old other of
        result@(Just new) -> new `seq` result
        Nothing -> Nothing
{-# INLINE differenceWithKey #-}
626637

627638
-- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps
628639
-- the provided function is used to combine the values from the two

Data/HashMap/Lazy.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ module Data.HashMap.Lazy
7171
-- * Difference and intersection
7272
, difference
7373
, differenceWith
74+
, differenceWithKey
7475
, intersection
7576
, intersectionWith
7677
, intersectionWithKey

Data/HashMap/Strict.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ module Data.HashMap.Strict
7070
-- * Difference and intersection
7171
, difference
7272
, differenceWith
73+
, differenceWithKey
7374
, intersection
7475
, intersectionWith
7576
, intersectionWithKey

benchmarks/FineGrained.hs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ main =
3434
bUnion,
3535
bUnions,
3636
bIntersection,
37-
bDifference
37+
bDifference,
38+
bDifferenceWith
3839
],
3940
bgroup "HashSet" [bSetFromList]
4041
]
@@ -347,6 +348,42 @@ bDifferenceEqual =
347348
where
348349
b size = bench (show size) . whnf (\m -> HM.difference m m)
349350

351+
-- | All @differenceWith@ benchmarks, grouped by how the two input maps relate.
bDifferenceWith :: Benchmark
bDifferenceWith = bgroup "differenceWith" scenarios
  where
    scenarios =
      [ bgroup "disjoint" bDifferenceWithDisjoint,
        bgroup "overlap" bDifferenceWithOverlap,
        bgroup "equal" bDifferenceWithEqual
      ]
359+
360+
-- | Combining function used by the @differenceWith@ benchmarks.
--
-- It always returns 'Just', so shared keys are updated and never removed:
-- these benchmarks do not exercise the deletion path of @differenceWith@.
differenceWithF :: Int -> Int -> Maybe Int
differenceWithF x y = Just (x + y)
362+
363+
-- | @differenceWith@ on map pairs produced by the \"disjoint\" generators.
bDifferenceWithDisjoint :: [Benchmark]
bDifferenceWithDisjoint =
  [ bgroup' "Bytes" genBytesMapsDisjoint runCase,
    bgroup' "Int" genIntMapsDisjoint runCase
  ]
  where
    -- One benchmark per input size, evaluating the result map to WHNF.
    runCase size =
      bench (show size)
        . whnf (\(lhs, rhs) -> HM.differenceWith differenceWithF lhs rhs)
370+
371+
-- | @differenceWith@ on map pairs produced by the \"overlap\" generators.
bDifferenceWithOverlap :: [Benchmark]
bDifferenceWithOverlap =
  [ bgroup' "Bytes" genBytesMapsOverlap runCase,
    bgroup' "Int" genIntMapsOverlap runCase
  ]
  where
    -- One benchmark per input size, evaluating the result map to WHNF.
    runCase size =
      bench (show size)
        . whnf (\(lhs, rhs) -> HM.differenceWith differenceWithF lhs rhs)
378+
379+
-- | @differenceWith@ of a map with itself, so every key is shared.
bDifferenceWithEqual :: [Benchmark]
bDifferenceWithEqual =
  [ bgroup' "Bytes" genBytesMap runCase,
    bgroup' "Int" genIntMap runCase
  ]
  where
    -- One benchmark per input size, evaluating the result map to WHNF.
    runCase size =
      bench (show size)
        . whnf (\hm -> HM.differenceWith differenceWithF hm hm)
386+
350387
bSetFromList :: Benchmark
351388
bSetFromList =
352389
bgroup

tests/Properties/HashMapLazy.hs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,21 @@ tests =
310310
toOrdMap (HM.differenceWith f x y) === M.differenceWith f (toOrdMap x) (toOrdMap y)
311311
, testProperty "valid" $
312312
\(Fn2 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWith f x y)
313+
, testProperty "differenceWith (\\x y -> Just $ f x y) xs ys == intersectionWith f xs ys `union` xs" $
314+
\(Fn2 f) (x :: HMK A) (y :: HMK B) ->
315+
HM.differenceWith (\a b -> Just $ f a b) x y
316+
=== HM.intersectionWith f x y `HM.union` x
317+
]
318+
, testGroup "differenceWithKey"
319+
[ testProperty "model" $
320+
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
321+
toOrdMap (HM.differenceWithKey f x y) === M.differenceWithKey f (toOrdMap x) (toOrdMap y)
322+
, testProperty "valid" $
323+
\(Fn3 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWithKey f x y)
324+
, testProperty "differenceWithKey (\\k x y -> Just $ f k x y) xs ys == intersectionWithKey f xs ys `union` xs" $
325+
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
326+
HM.differenceWithKey (\k a b -> Just $ f k a b) x y
327+
=== HM.intersectionWithKey f x y `HM.union` x
313328
]
314329
, testGroup "intersection"
315330
[ testProperty "model" $

0 commit comments

Comments
 (0)