Skip to content

Commit 0c0d1f0

Browse files
committed
Add differenceWithKey
...and define `differenceWith` via `differenceWithKey` Closes #389.
1 parent 5ce9758 commit 0c0d1f0

File tree

7 files changed

+255
-12
lines changed

7 files changed

+255
-12
lines changed

Data/HashMap/Internal.hs

Lines changed: 168 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ module Data.HashMap.Internal
7777
-- * Difference and intersection
7878
, difference
7979
, differenceWith
80+
, differenceWithKey
8081
, intersection
8182
, intersectionWith
8283
, intersectionWithKey
@@ -1917,14 +1918,175 @@ differenceCollisions !h1 !ary1 t1 !h2 !ary2
19171918
-- encountered, the combining function is applied to the values of these keys.
19181919
-- If it returns 'Nothing', the element is discarded (proper set difference). If
19191920
-- it returns (@'Just' y@), the element is updated with a new value @y@.
1920-
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
1921-
differenceWith f a b = foldlWithKey' go empty a
1922-
where
1923-
go m k v = case lookup k b of
1924-
Nothing -> unsafeInsert k v m
1925-
Just w -> maybe m (\y -> unsafeInsert k y m) (f v w)
1921+
differenceWith :: Eq k => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
1922+
differenceWith f = differenceWithKey (const f)
19261923
{-# INLINABLE differenceWith #-}
19271924

1925+
-- | \(O(n \log m)\) Difference with a combining function. When two equal keys are
1926+
-- encountered, the combining function is applied to the values of these keys.
1927+
-- If it returns 'Nothing', the element is discarded (proper set difference). If
1928+
-- it returns (@'Just' y@), the element is updated with a new value @y@.
1929+
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
1930+
differenceWithKey = go_differenceWithKey 0
1931+
where
1932+
go_differenceWithKey !_s _f Empty _tB = Empty
1933+
go_differenceWithKey _s _f a Empty = a
1934+
go_differenceWithKey s f a@(Leaf hA (L kA vA)) b
1935+
= lookupCont
1936+
(\_ -> a)
1937+
(\vB _ -> case f kA vA vB of
1938+
Nothing -> Empty
1939+
Just v | v `ptrEq` vA -> a
1940+
| otherwise -> Leaf hA (L kA v))
1941+
hA kA s b
1942+
go_differenceWithKey _s f a@(Collision hA aryA) (Leaf hB (L kB vB))
1943+
| hA == hB
1944+
= lookupInArrayCont
1945+
(\_ -> a)
1946+
(\vA i -> case f kB vA vB of
1947+
Nothing | A.length aryA == 2
1948+
, (# l #) <- A.index# aryA (otherOfOneOrZero i)
1949+
-> Leaf hA l
1950+
| otherwise -> Collision hA (A.delete aryA i)
1951+
Just v | v `ptrEq` vA -> a
1952+
| otherwise -> Collision hA (A.update aryA i (L kB v)))
1953+
kB aryA
1954+
| otherwise = a
1955+
go_differenceWithKey s f a@(BitmapIndexed bA aryA) b@(Leaf hB _)
1956+
| bA .&. m == 0 = a
1957+
| otherwise = case A.index# aryA i of
1958+
(# !stA #) -> case go_differenceWithKey (nextShift s) f stA b of
1959+
Empty | A.length aryA == 2
1960+
, (# l #) <- A.index# aryA (otherOfOneOrZero i)
1961+
, isLeafOrCollision l
1962+
-> l
1963+
| otherwise
1964+
-> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
1965+
stA' | isLeafOrCollision stA' && A.length aryA == 1 -> stA'
1966+
| stA `ptrEq` stA' -> a
1967+
| otherwise -> BitmapIndexed bA (A.update aryA i stA')
1968+
where
1969+
m = mask hB s
1970+
i = sparseIndex bA m
1971+
go_differenceWithKey s f a@(BitmapIndexed bA aryA) b@(Collision hB _)
1972+
| bA .&. m == 0 = a
1973+
| otherwise =
1974+
case A.index# aryA i of
1975+
(# !st #) -> case go_differenceWithKey (nextShift s) f st b of
1976+
Empty | A.length aryA == 2
1977+
, (# l #) <- A.index# aryA (otherOfOneOrZero i)
1978+
, isLeafOrCollision l
1979+
-> l
1980+
| otherwise
1981+
-> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
1982+
st' | isLeafOrCollision st' && A.length aryA == 1 -> st'
1983+
| st `ptrEq` st' -> a
1984+
| otherwise -> BitmapIndexed bA (A.update aryA i st')
1985+
where
1986+
m = mask hB s
1987+
i = sparseIndex bA m
1988+
go_differenceWithKey s f a@(Full aryA) b@(Leaf hB _)
1989+
= case A.index# aryA i of
1990+
(# !stA #) -> case go_differenceWithKey (nextShift s) f stA b of
1991+
Empty ->
1992+
let aryA' = A.delete aryA i
1993+
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
1994+
in BitmapIndexed bm aryA'
1995+
stA' | stA `ptrEq` stA' -> a
1996+
| otherwise -> Full (updateFullArray aryA i stA')
1997+
where i = index hB s
1998+
go_differenceWithKey s f a@(Full aryA) b@(Collision hB _)
1999+
= case A.index# aryA i of
2000+
(# !stA #) -> case go_differenceWithKey (nextShift s) f stA b of
2001+
Empty ->
2002+
let aryA' = A.delete aryA i
2003+
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
2004+
in BitmapIndexed bm aryA'
2005+
stA' | stA `ptrEq` stA' -> a
2006+
| otherwise -> Full (updateFullArray aryA i stA')
2007+
where i = index hB s
2008+
go_differenceWithKey s f a@(Collision hA _) (BitmapIndexed bB aryB)
2009+
| bB .&. m == 0 = a
2010+
| otherwise =
2011+
case A.index# aryB (sparseIndex bB m) of
2012+
(# stB #) -> go_differenceWithKey (nextShift s) f a stB
2013+
where m = mask hA s
2014+
go_differenceWithKey s f a@(Collision hA _) (Full aryB)
2015+
= case A.index# aryB (index hA s) of
2016+
(# stB #) -> go_differenceWithKey (nextShift s) f a stB
2017+
go_differenceWithKey s f a@(BitmapIndexed bA aryA) (BitmapIndexed bB aryB)
2018+
= differenceWithKey_Arrays s f bA aryA a bB aryB
2019+
go_differenceWithKey s f a@(Full aryA) (BitmapIndexed bB aryB)
2020+
= differenceWithKey_Arrays s f fullBitmap aryA a bB aryB
2021+
go_differenceWithKey s f a@(BitmapIndexed bA aryA) (Full aryB)
2022+
= differenceWithKey_Arrays s f bA aryA a fullBitmap aryB
2023+
go_differenceWithKey s f a@(Full aryA) (Full aryB)
2024+
= differenceWithKey_Arrays s f fullBitmap aryA a fullBitmap aryB
2025+
go_differenceWithKey _s f a@(Collision hA aryA) (Collision hB aryB)
2026+
= differenceWithKey_Collisions f hA aryA a hB aryB
2027+
2028+
differenceWithKey_Arrays !s f !bA !aryA tA !bB !aryB
2029+
| bA .&. bB == 0 = tA
2030+
| otherwise = runST $ do
2031+
mary <- A.new_ $ A.length aryA
2032+
2033+
-- TODO: i == popCount bResult. Not sure if that would be faster.
2034+
-- Also iA is in some relation with bA'
2035+
let go_dWKA !i !iA !bA' !bResult !nChanges
2036+
| bA' == 0 = pure (bResult, nChanges)
2037+
| otherwise = do
2038+
!stA <- A.indexM aryA iA
2039+
case m .&. bB of
2040+
0 -> do
2041+
A.write mary i stA
2042+
go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges
2043+
_ -> do
2044+
!stB <- A.indexM aryB (sparseIndex bB m)
2045+
case go_differenceWithKey (nextShift s) f stA stB of
2046+
Empty -> go_dWKA i (iA + 1) nextBA' bResult (nChanges + 1)
2047+
st -> do
2048+
A.write mary i st
2049+
let same = I# (Exts.reallyUnsafePtrEquality# st stA)
2050+
let nChanges' = nChanges + (1 - same)
2051+
go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges'
2052+
where
2053+
m = bA' .&. negate bA'
2054+
nextBA' = bA' .&. complement m
2055+
2056+
(bResult, nChanges) <- go_dWKA 0 0 bA 0 0
2057+
if nChanges == 0
2058+
then pure tA
2059+
else case popCount bResult of
2060+
0 -> pure Empty
2061+
1 -> do
2062+
l <- A.read mary 0
2063+
if isLeafOrCollision l
2064+
then pure l
2065+
else BitmapIndexed bResult <$> (A.unsafeFreeze =<< A.shrink mary 1)
2066+
n -> bitmapIndexedOrFull bResult <$> (A.unsafeFreeze =<< A.shrink mary n)
2067+
{-# INLINABLE differenceWithKey #-}
2068+
2069+
-- TODO: This could be faster if we would keep track of which elements of ary2
2070+
-- we've already matched. Those could be skipped when we check the following
2071+
-- elements of ary1.
2072+
-- TODO: Return tA when the array is unchanged.
2073+
differenceWithKey_Collisions :: Eq k => (k -> v -> w -> Maybe v) -> Word -> A.Array (Leaf k v) -> HashMap k v -> Word -> A.Array (Leaf k w) -> HashMap k v
2074+
differenceWithKey_Collisions f !hA !aryA !tA !hB !aryB
2075+
| hA == hB =
2076+
let f' l@(L kA vA) =
2077+
lookupInArrayCont
2078+
(\_ -> Just l)
2079+
(\vB _ -> L kA <$> f kA vA vB)
2080+
kA aryB
2081+
ary = A.mapMaybe f' aryA
2082+
in case A.length ary of
2083+
0 -> Empty
2084+
1 -> case A.index# ary 0 of
2085+
(# l #) -> Leaf hA l
2086+
_ -> Collision hA ary
2087+
| otherwise = tA
2088+
{-# INLINABLE differenceWithKey_Collisions #-}
2089+
19282090
-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
19292091
-- map for keys existing in the second.
19302092
intersection :: Eq k => HashMap k v -> HashMap k w -> HashMap k v

Data/HashMap/Internal/Array.hs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ module Data.HashMap.Internal.Array
7272
, map
7373
, map'
7474
, filter
75+
, mapMaybe
7576
, traverse
7677
, traverse'
7778
, toList
@@ -519,6 +520,25 @@ filter f = \ ary ->
519520
else go_filter ary mary (iAry + 1) iMary n
520521
{-# INLINE filter #-}
521522

523+
mapMaybe :: (a -> Maybe b) -> Array a -> Array b
524+
mapMaybe f = \ ary ->
525+
let !n = length ary
526+
in run $ do
527+
mary <- new_ n
528+
len <- go_mapMaybe ary mary 0 0 n
529+
shrink mary len
530+
where
531+
go_mapMaybe !ary !mary !iAry !iMary !n
532+
| iAry >= n = return iMary
533+
| otherwise = do
534+
x <- indexM ary iAry
535+
case f x of
536+
Nothing -> go_mapMaybe ary mary (iAry + 1) iMary n
537+
Just y -> do
538+
write mary iMary y
539+
go_mapMaybe ary mary (iAry + 1) (iMary + 1) n
540+
{-# INLINE mapMaybe #-}
541+
522542
fromList :: Int -> [a] -> Array a
523543
fromList n xs0 =
524544
CHECK_EQ("fromList", n, Prelude.length xs0)

Data/HashMap/Internal/Strict.hs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ module Data.HashMap.Internal.Strict
9090
-- * Difference and intersection
9191
, HM.difference
9292
, differenceWith
93+
, differenceWithKey
9394
, HM.intersection
9495
, intersectionWith
9596
, intersectionWithKey
@@ -617,13 +618,19 @@ traverseWithKey f = go
617618
-- If it returns 'Nothing', the element is discarded (proper set difference). If
618619
-- it returns (@'Just' y@), the element is updated with a new value @y@.
619620
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
620-
differenceWith f a b = HM.foldlWithKey' go HM.empty a
621-
where
622-
go m k v = case HM.lookup k b of
623-
Nothing -> v `seq` HM.unsafeInsert k v m
624-
Just w -> maybe m (\ !y -> HM.unsafeInsert k y m) (f v w)
621+
differenceWith f = HM.differenceWithKey $
622+
\_k vA vB -> case f vA vB of
623+
Nothing -> Nothing
624+
x@(Just !_v) -> x
625625
{-# INLINABLE differenceWith #-}
626626

627+
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
628+
differenceWithKey f = HM.differenceWithKey $
629+
\k vA vB -> case f k vA vB of
630+
Nothing -> Nothing
631+
x@(Just !_v) -> x
632+
{-# INLINABLE differenceWithKey #-}
633+
627634
-- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps
628635
-- the provided function is used to combine the values from the two
629636
-- maps.

Data/HashMap/Lazy.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ module Data.HashMap.Lazy
7171
-- * Difference and intersection
7272
, difference
7373
, differenceWith
74+
, differenceWithKey
7475
, intersection
7576
, intersectionWith
7677
, intersectionWithKey

Data/HashMap/Strict.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ module Data.HashMap.Strict
7070
-- * Difference and intersection
7171
, difference
7272
, differenceWith
73+
, differenceWithKey
7374
, intersection
7475
, intersectionWith
7576
, intersectionWithKey

benchmarks/FineGrained.hs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ main =
3434
bUnion,
3535
bUnions,
3636
bIntersection,
37-
bDifference
37+
bDifference,
38+
bDifferenceWith
3839
],
3940
bgroup "HashSet" [bSetFromList]
4041
]
@@ -347,6 +348,42 @@ bDifferenceEqual =
347348
where
348349
b size = bench (show size) . whnf (\m -> HM.difference m m)
349350

351+
bDifferenceWith :: Benchmark
352+
bDifferenceWith =
353+
bgroup
354+
"differenceWith"
355+
[ bgroup "disjoint" bDifferenceWithDisjoint,
356+
bgroup "overlap" bDifferenceWithOverlap,
357+
bgroup "equal" bDifferenceWithEqual
358+
]
359+
360+
differenceWithF :: Int -> Int -> Maybe Int
361+
differenceWithF x y = Just (x + y)
362+
363+
bDifferenceWithDisjoint :: [Benchmark]
364+
bDifferenceWithDisjoint =
365+
[ bgroup' "Bytes" genBytesMapsDisjoint b,
366+
bgroup' "Int" genIntMapsDisjoint b
367+
]
368+
where
369+
b size = bench (show size) . whnf (\(xs, ys) -> HM.differenceWith differenceWithF xs ys)
370+
371+
bDifferenceWithOverlap :: [Benchmark]
372+
bDifferenceWithOverlap =
373+
[ bgroup' "Bytes" genBytesMapsOverlap b,
374+
bgroup' "Int" genIntMapsOverlap b
375+
]
376+
where
377+
b size = bench (show size) . whnf (\(xs, ys) -> HM.differenceWith differenceWithF xs ys)
378+
379+
bDifferenceWithEqual :: [Benchmark]
380+
bDifferenceWithEqual =
381+
[ bgroup' "Bytes" genBytesMap b,
382+
bgroup' "Int" genIntMap b
383+
]
384+
where
385+
b size = bench (show size) . whnf (\m -> HM.differenceWith differenceWithF m m)
386+
350387
bSetFromList :: Benchmark
351388
bSetFromList =
352389
bgroup

tests/Properties/HashMapLazy.hs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,21 @@ tests =
310310
toOrdMap (HM.differenceWith f x y) === M.differenceWith f (toOrdMap x) (toOrdMap y)
311311
, testProperty "valid" $
312312
\(Fn2 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWith f x y)
313+
, testProperty "differenceWith (\\x y -> Just $ f x y) xs ys == intersectionWith f xs ys `union` xs" $
314+
\(Fn2 f) (x :: HMK A) (y :: HMK B) ->
315+
HM.differenceWith (\a b -> Just $ f a b) x y
316+
=== HM.intersectionWith f x y `HM.union` x
317+
]
318+
, testGroup "differenceWithKey"
319+
[ testProperty "model" $
320+
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
321+
toOrdMap (HM.differenceWithKey f x y) === M.differenceWithKey f (toOrdMap x) (toOrdMap y)
322+
, testProperty "valid" $
323+
\(Fn3 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWithKey f x y)
324+
, testProperty "differenceWithKey (\\k x y -> Just $ f k x y) xs ys == intersectionWithKey f xs ys `union` xs" $
325+
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
326+
HM.differenceWithKey (\k a b -> Just $ f k a b) x y
327+
=== HM.intersectionWithKey f x y `HM.union` x
313328
]
314329
, testGroup "intersection"
315330
[ testProperty "model" $

0 commit comments

Comments
 (0)