Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 181 additions & 6 deletions Data/HashMap/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ module Data.HashMap.Internal
-- * Difference and intersection
, difference
, differenceWith
, differenceWithKey
, intersection
, intersectionWith
, intersectionWithKey
Expand Down Expand Up @@ -1917,13 +1918,187 @@ differenceCollisions !h1 !ary1 t1 !h2 !ary2
-- encountered, the combining function is applied to the values of these keys.
-- If it returns 'Nothing', the element is discarded (proper set difference). If
-- it returns (@'Just' y@), the element is updated with a new value @y@.
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWith f a b = foldlWithKey' go empty a
differenceWith :: Eq k => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
-- Implemented via 'differenceWithKey': the key handed to the combining
-- function is simply dropped before calling @f@.
differenceWith f = differenceWithKey (\_key vA vB -> f vA vB)
{-# INLINE differenceWith #-}

-- | \(O(n \log m)\) Difference with a combining function. When two equal keys are
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are m and n here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n is the size of the first map. m is the size of the second map. This is a convention this package uses for many functions. I suspect it was adopted from containers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, but I don't think that's necessarily what this implementation does; it was left unchanged from the old one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not obviously wrong to me at least. If the first map is small, the second is a relatively large superset, and lookup[Cont] takes log(m), we still do n lookups in the larger map. To be fair, we don't start these lookups at the root, so maybe O(n log (m/n)) would be more accurate?!

IMHO these log(size)s are not very useful anyways, since on 64-bit systems we have a maximum tree height of 13, and on 32-bit systems the maximum tree height is 8; and you can still have a map with two entries and full tree height…

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bounds are given assuming sufficiently uniform hashing, but that's not at all the case for important instances like Int. It's ... a problem. I can't say if n log (m/n) is accurate or not, but it should be something symmetrical!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I take that back. Maybe not symmetrical. But ... I dunno...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not symmetrical actually. I have no idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have opened #543 to track this.

-- encountered, the combining function is applied to the values of these keys.
-- If it returns 'Nothing', the element is discarded (proper set difference). If
-- it returns (@'Just' y@), the element is updated with a new value @y@.
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
-- The traversal walks both tries in lockstep, starting at shift 0 (the root
-- level). Wherever a subtree of the first map turns out to be unchanged, the
-- original node is returned instead of a rebuilt copy.
differenceWithKey f = go_differenceWithKey 0
where
-- NOTE(review): the next four lines (`go` and the INLINABLE pragma naming
-- 'differenceWith') look like leftovers of the previous fold-based
-- implementation: they refer to a `b` that is not bound here and are not
-- used by 'go_differenceWithKey' -- confirm and remove.
go m k v = case lookup k b of
Nothing -> unsafeInsert k v m
Just w -> maybe m (\y -> unsafeInsert k y m) (f v w)
{-# INLINABLE differenceWith #-}
-- Arguments: current shift, subtree of the first map, subtree of the second map.
-- An empty first map yields Empty; an empty second map leaves @a@ untouched.
go_differenceWithKey !_s Empty _tB = Empty
go_differenceWithKey _s a Empty = a
-- Leaf in the first map: look its key up in the second map, continuing from
-- the current shift @s@.
go_differenceWithKey s a@(Leaf hA (L kA vA)) b
= lookupCont
-- Key not found in @b@: keep the leaf.
(\_ -> a)
-- Key found: let @f@ decide. A result that is pointer-equal to the old value
-- re-uses the existing leaf.
(\vB _ -> case f kA vA vB of
Nothing -> Empty
Just v | v `ptrEq` vA -> a
| otherwise -> Leaf hA (L kA v))
hA kA s b
-- Collision node versus a single leaf: only relevant when the hashes agree.
-- NOTE(review): here @f@ receives @kB@ (the key from the second map), and
-- 'updateCollision' also stores that key in the rebuilt leaf, whereas the
-- Leaf case above and 'differenceWithKey_Collisions' use the first map's key
-- @kA@ -- confirm this is intended.
go_differenceWithKey _s a@(Collision hA aryA) (Leaf hB (L kB vB))
| hA == hB = updateCollision (\vA -> f kB vA vB) hA kB aryA a
| otherwise = a
-- Bitmap-indexed node versus a leaf: only the child in the slot selected by
-- the leaf's hash can be affected.
go_differenceWithKey s a@(BitmapIndexed bA aryA) b@(Leaf hB _)
-- No child in that slot: nothing to remove.
| bA .&. m == 0 = a
| otherwise = case A.index# aryA i of
(# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
-- The child vanished. If exactly one sibling remains and it is a leaf or
-- collision, return that sibling directly instead of a one-element
-- BitmapIndexed node.
Empty | A.length aryA == 2
, (# l #) <- A.index# aryA (otherOfOneOrZero i)
, isLeafOrCollision l
-> l
-- Otherwise drop the slot from both the bitmap and the array.
| otherwise
-> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
-- An only child that is a leaf or collision is returned directly.
stA' | isLeafOrCollision stA' && A.length aryA == 1 -> stA'
-- Child unchanged (pointer-equal): re-use the original node.
| stA `ptrEq` stA' -> a
| otherwise -> BitmapIndexed bA (A.update aryA i stA')
where
m = mask hB s
i = sparseIndex bA m
-- Bitmap-indexed node versus a collision node: same procedure as for a leaf,
-- using the collision's hash to select the slot.
go_differenceWithKey s a@(BitmapIndexed bA aryA) b@(Collision hB _)
| bA .&. m == 0 = a
| otherwise =
case A.index# aryA i of
(# !st #) -> case go_differenceWithKey (nextShift s) st b of
Empty | A.length aryA == 2
, (# l #) <- A.index# aryA (otherOfOneOrZero i)
, isLeafOrCollision l
-> l
| otherwise
-> BitmapIndexed (bA .&. complement m) (A.delete aryA i)
st' | isLeafOrCollision st' && A.length aryA == 1 -> st'
| st `ptrEq` st' -> a
| otherwise -> BitmapIndexed bA (A.update aryA i st')
where
m = mask hB s
i = sparseIndex bA m
-- Full node versus a leaf: recurse into the child selected by the leaf's hash.
go_differenceWithKey s a@(Full aryA) b@(Leaf hB _)
= case A.index# aryA i of
(# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
-- Losing a child turns the node into a BitmapIndexed node whose bitmap has
-- the bit for slot @i@ cleared.
Empty ->
let aryA' = A.delete aryA i
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
in BitmapIndexed bm aryA'
stA' | stA `ptrEq` stA' -> a
| otherwise -> Full (updateFullArray aryA i stA')
where i = index hB s
-- Full node versus a collision node: same procedure as for a leaf.
go_differenceWithKey s a@(Full aryA) b@(Collision hB _)
= case A.index# aryA i of
(# !stA #) -> case go_differenceWithKey (nextShift s) stA b of
Empty ->
let aryA' = A.delete aryA i
bm = fullBitmap .&. complement (1 `unsafeShiftL` i)
in BitmapIndexed bm aryA'
stA' | stA `ptrEq` stA' -> a
| otherwise -> Full (updateFullArray aryA i stA')
where i = index hB s
-- Collision in the first map versus an inner node of the second map: descend
-- into the second map along the collision's hash. If the second map has no
-- child in that slot, @a@ is returned unchanged.
go_differenceWithKey s a@(Collision hA _) (BitmapIndexed bB aryB)
| bB .&. m == 0 = a
| otherwise =
case A.index# aryB (sparseIndex bB m) of
(# stB #) -> go_differenceWithKey (nextShift s) a stB
where m = mask hA s
go_differenceWithKey s a@(Collision hA _) (Full aryB)
= case A.index# aryB (index hA s) of
(# stB #) -> go_differenceWithKey (nextShift s) a stB
-- Two inner nodes: handled child-by-child by 'differenceWithKey_Arrays'.
-- A Full node is passed along with 'fullBitmap' as its bitmap.
go_differenceWithKey s a@(BitmapIndexed bA aryA) (BitmapIndexed bB aryB)
= differenceWithKey_Arrays s bA aryA a bB aryB
go_differenceWithKey s a@(Full aryA) (BitmapIndexed bB aryB)
= differenceWithKey_Arrays s fullBitmap aryA a bB aryB
go_differenceWithKey s a@(BitmapIndexed bA aryA) (Full aryB)
= differenceWithKey_Arrays s bA aryA a fullBitmap aryB
go_differenceWithKey s a@(Full aryA) (Full aryB)
= differenceWithKey_Arrays s fullBitmap aryA a fullBitmap aryB
-- Two collision nodes: delegate to 'differenceWithKey_Collisions'.
go_differenceWithKey _s a@(Collision hA aryA) (Collision hB aryB)
= differenceWithKey_Collisions f hA aryA a hB aryB

-- Difference of the children of two inner nodes, given as bitmap + array.
-- @tA@ is the original node of the first map; it is returned as-is when no
-- child changed.
differenceWithKey_Arrays !s !bA !aryA tA !bB !aryB
-- The nodes share no occupied slot, so nothing can be removed.
| bA .&. bB == 0 = tA
| otherwise = runST $ do
-- The result cannot have more children than the first node has.
mary <- A.new_ $ A.length aryA

-- TODO: i == popCount bResult. Not sure if that would be faster.
-- Also iA is in some relation with bA'
-- Loop state: @i@ is the next write position in @mary@, @iA@ the next read
-- position in @aryA@, @bA'@ the bits of @bA@ still to visit, @bResult@ the
-- bitmap of the children written so far, and @nChanges@ the number of
-- children that were removed or replaced.
let go_dWKA !i !iA !bA' !bResult !nChanges
| bA' == 0 = pure (bResult, nChanges)
| otherwise = do
!stA <- A.indexM aryA iA
case m .&. bB of
-- The second node has no child in this slot: copy ours unchanged.
0 -> do
A.write mary i stA
go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges
_ -> do
!stB <- A.indexM aryB (sparseIndex bB m)
case go_differenceWithKey (nextShift s) stA stB of
-- The child disappeared: write nothing, but count a change.
Empty -> go_dWKA i (iA + 1) nextBA' bResult (nChanges + 1)
st -> do
A.write mary i st
-- @same@ is the Int produced by the pointer-equality primop, so a
-- change is counted only when the child was actually replaced.
let same = I# (Exts.reallyUnsafePtrEquality# st stA)
let nChanges' = nChanges + (1 - same)
go_dWKA (i + 1) (iA + 1) nextBA' (bResult .|. m) nChanges'
where
-- @m@ isolates the lowest set bit of @bA'@; @nextBA'@ clears it.
m = bA' .&. negate bA'
nextBA' = bA' .&. complement m

(bResult, nChanges) <- go_dWKA 0 0 bA 0 0
-- Nothing changed: hand back the original node and discard @mary@.
if nChanges == 0
then pure tA
else case popCount bResult of
0 -> pure Empty
-- A single remaining leaf or collision is returned directly; any other
-- single child stays wrapped in a one-element BitmapIndexed node.
1 -> do
l <- A.read mary 0
if isLeafOrCollision l
then pure l
else BitmapIndexed bResult <$> (A.unsafeFreeze =<< A.shrink mary 1)
-- Otherwise trim the buffer to the number of surviving children.
n -> bitmapIndexedOrFull bResult <$> (A.unsafeFreeze =<< A.shrink mary n)
{-# INLINE differenceWithKey #-}

-- | A version of 'update' that operates directly on the leaf array of a
-- 'Collision' node.
--
-- If the key is not in the array, or the new value is pointer-equal to the old
-- one, the original node is returned. If the entry is removed and only one
-- other leaf remains, that leaf is returned as a 'Leaf' node.
updateCollision
  :: Eq k
  => (v -> Maybe v)
  -> Hash
  -> k
  -> A.Array (Leaf k v)
  -> HashMap k v
  -- ^ The original Collision node which will be re-used if the array is unchanged.
  -> HashMap k v
updateCollision f !h k !ary orig =
    lookupInArrayCont (\_ -> orig) onHit k ary
  where
    -- Invoked with the current value and its position when the key is found.
    onHit v i = case f v of
        Just v'
            | v' `ptrEq` v -> orig
            | otherwise    -> Collision h (A.update ary i (L k v'))
        Nothing
            -- Two leaves, one removed: the survivor no longer needs a
            -- Collision node around it.
            | A.length ary == 2
            , (# l #) <- A.index# ary (otherOfOneOrZero i)
                        -> Leaf h l
            | otherwise -> Collision h (A.delete ary i)
{-# INLINABLE updateCollision #-}

-- TODO: This could be faster if we would keep track of which elements of ary2
-- we've already matched. Those could be skipped when we check the following
-- elements of ary1.
-- TODO: Return tA when the array is unchanged.
differenceWithKey_Collisions :: Eq k => (k -> v -> w -> Maybe v) -> Word -> A.Array (Leaf k v) -> HashMap k v -> Word -> A.Array (Leaf k w) -> HashMap k v
-- Difference of two collision nodes. With different hashes the nodes cannot
-- share a key, so the first node @tA@ is returned untouched. Otherwise every
-- leaf of the first array is either kept (key absent from the second array)
-- or passed through @f@, and the result is shaped according to how many
-- leaves survive.
differenceWithKey_Collisions f !hA !aryA !tA !hB !aryB
    | hA /= hB = tA
    | otherwise = case A.length survivors of
        0 -> Empty
        1 -> case A.index# survivors 0 of
            (# l #) -> Leaf hA l
        _ -> Collision hA survivors
  where
    survivors = A.mapMaybe resolve aryA
    -- Decide the fate of a single leaf of the first array.
    resolve leafA@(L kA vA) =
        lookupInArrayCont
            (\_ -> Just leafA)
            (\vB _ -> L kA <$> f kA vA vB)
            kA aryB
{-# INLINABLE differenceWithKey_Collisions #-}

-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
-- map for keys existing in the second.
Expand Down
20 changes: 20 additions & 0 deletions Data/HashMap/Internal/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ module Data.HashMap.Internal.Array
, map
, map'
, filter
, mapMaybe
, traverse
, traverse'
, toList
Expand Down Expand Up @@ -519,6 +520,25 @@ filter f = \ ary ->
else go_filter ary mary (iAry + 1) iMary n
{-# INLINE filter #-}

-- Apply @f@ to every element of the array and collect the 'Just' results, in
-- their original order, into a new array.
mapMaybe :: (a -> Maybe b) -> Array a -> Array b
mapMaybe f = \ src -> run $ do
    let !srcLen = length src
    -- The output can never be longer than the input; it is trimmed to the
    -- number of elements actually written once the loop is done.
    dst <- new_ srcLen
    kept <- go_mapMaybe src dst srcLen 0 0
    shrink dst kept
  where
    -- @from@ is the read position in @src@, @to@ the write position in @dst@.
    go_mapMaybe !src !dst !srcLen !from !to
        | from < srcLen = do
            x <- indexM src from
            case f x of
                Just y -> do
                    write dst to y
                    go_mapMaybe src dst srcLen (from + 1) (to + 1)
                Nothing -> go_mapMaybe src dst srcLen (from + 1) to
        | otherwise = return to
{-# INLINE mapMaybe #-}

fromList :: Int -> [a] -> Array a
fromList n xs0 =
CHECK_EQ("fromList", n, Prelude.length xs0)
Expand Down
23 changes: 17 additions & 6 deletions Data/HashMap/Internal/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ module Data.HashMap.Internal.Strict
-- * Difference and intersection
, HM.difference
, differenceWith
, differenceWithKey
, HM.intersection
, intersectionWith
, intersectionWithKey
Expand Down Expand Up @@ -617,12 +618,22 @@ traverseWithKey f = go
-- If it returns 'Nothing', the element is discarded (proper set difference). If
-- it returns (@'Just' y@), the element is updated with a new value @y@.
differenceWith :: (Eq k, Hashable k) => (v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWith f a b = HM.foldlWithKey' go HM.empty a
where
go m k v = case HM.lookup k b of
Nothing -> v `seq` HM.unsafeInsert k v m
Just w -> maybe m (\ !y -> HM.unsafeInsert k y m) (f v w)
{-# INLINABLE differenceWith #-}
differenceWith f = HM.differenceWithKey strictF
  where
    -- Ignore the key and evaluate the new value to WHNF before handing the
    -- 'Just' back; the existing 'Just' cell is re-used rather than rebuilt.
    strictF _k vA vB = case f vA vB of
        r@(Just v) -> v `seq` r
        Nothing    -> Nothing
{-# INLINE differenceWith #-}

-- | \(O(n \log m)\) Difference with a combining function that also receives
-- the key. Whenever a key occurs in both maps, the function is applied to the
-- key and the two values: a 'Nothing' result drops the element (proper set
-- difference), while a @'Just' y@ result replaces the value with @y@, which is
-- evaluated to WHNF first.
differenceWithKey :: Eq k => (k -> v -> w -> Maybe v) -> HashMap k v -> HashMap k w -> HashMap k v
differenceWithKey f = HM.differenceWithKey strictF
  where
    -- Evaluate the new value to WHNF before handing the 'Just' back; the
    -- existing 'Just' cell is re-used rather than rebuilt.
    strictF k vA vB = case f k vA vB of
        r@(Just v) -> v `seq` r
        Nothing    -> Nothing
{-# INLINE differenceWithKey #-}

-- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps
-- the provided function is used to combine the values from the two
Expand Down
1 change: 1 addition & 0 deletions Data/HashMap/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ module Data.HashMap.Lazy
-- * Difference and intersection
, difference
, differenceWith
, differenceWithKey
, intersection
, intersectionWith
, intersectionWithKey
Expand Down
1 change: 1 addition & 0 deletions Data/HashMap/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ module Data.HashMap.Strict
-- * Difference and intersection
, difference
, differenceWith
, differenceWithKey
, intersection
, intersectionWith
, intersectionWithKey
Expand Down
39 changes: 38 additions & 1 deletion benchmarks/FineGrained.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ main =
bUnion,
bUnions,
bIntersection,
bDifference
bDifference,
bDifferenceWith
],
bgroup "HashSet" [bSetFromList]
]
Expand Down Expand Up @@ -347,6 +348,42 @@ bDifferenceEqual =
where
b size = bench (show size) . whnf (\m -> HM.difference m m)

-- | All 'HM.differenceWith' benchmarks, grouped by how the two input maps
-- relate to each other.
bDifferenceWith :: Benchmark
bDifferenceWith = bgroup "differenceWith" [disjoint, overlap, equal]
  where
    disjoint = bgroup "disjoint" bDifferenceWithDisjoint
    overlap = bgroup "overlap" bDifferenceWithOverlap
    equal = bgroup "equal" bDifferenceWithEqual

-- | Combining function used by the 'HM.differenceWith' benchmarks. It always
-- returns a 'Just', so no element is ever dropped.
differenceWithF :: Int -> Int -> Maybe Int
differenceWithF x y = Just total
  where
    total = x + y

-- | 'HM.differenceWith' on the map pairs produced by the "disjoint"
-- generators, for both key types.
bDifferenceWithDisjoint :: [Benchmark]
bDifferenceWithDisjoint =
  [ bgroup' "Bytes" genBytesMapsDisjoint b,
    bgroup' "Int" genIntMapsDisjoint b
  ]
  where
    b size = bench (show size) . whnf diffPair
    -- The generators deliver both maps as a pair.
    diffPair (xs, ys) = HM.differenceWith differenceWithF xs ys

-- | 'HM.differenceWith' on the map pairs produced by the "overlap"
-- generators, for both key types.
bDifferenceWithOverlap :: [Benchmark]
bDifferenceWithOverlap =
  [ bgroup' "Bytes" genBytesMapsOverlap b,
    bgroup' "Int" genIntMapsOverlap b
  ]
  where
    b size = bench (show size) . whnf diffPair
    -- The generators deliver both maps as a pair.
    diffPair (xs, ys) = HM.differenceWith differenceWithF xs ys

-- | 'HM.differenceWith' of a map with itself, for both key types.
bDifferenceWithEqual :: [Benchmark]
bDifferenceWithEqual =
  [ bgroup' "Bytes" genBytesMap b,
    bgroup' "Int" genIntMap b
  ]
  where
    b size = bench (show size) . whnf selfDiff
    -- Both arguments are the very same map.
    selfDiff m = HM.differenceWith differenceWithF m m

bSetFromList :: Benchmark
bSetFromList =
bgroup
Expand Down
15 changes: 15 additions & 0 deletions tests/Properties/HashMapLazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,21 @@ tests =
toOrdMap (HM.differenceWith f x y) === M.differenceWith f (toOrdMap x) (toOrdMap y)
, testProperty "valid" $
\(Fn2 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWith f x y)
, testProperty "differenceWith (\\x y -> Just $ f x y) xs ys == intersectionWith f xs ys `union` xs" $
\(Fn2 f) (x :: HMK A) (y :: HMK B) ->
HM.differenceWith (\a b -> Just $ f a b) x y
=== HM.intersectionWith f x y `HM.union` x
]
, testGroup "differenceWithKey"
[ testProperty "model" $
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
toOrdMap (HM.differenceWithKey f x y) === M.differenceWithKey f (toOrdMap x) (toOrdMap y)
, testProperty "valid" $
\(Fn3 f) (x :: HMK A) (y :: HMK B) -> isValid (HM.differenceWithKey f x y)
, testProperty "differenceWithKey (\\k x y -> Just $ f k x y) xs ys == intersectionWithKey f xs ys `union` xs" $
\(Fn3 f) (x :: HMK A) (y :: HMK B) ->
HM.differenceWithKey (\k a b -> Just $ f k a b) x y
=== HM.intersectionWithKey f x y `HM.union` x
]
, testGroup "intersection"
[ testProperty "model" $
Expand Down