@@ -280,19 +280,28 @@ def _indicator_post_merge(self, result):
280280 return result
281281
282282 def _maybe_add_join_keys (self , result , left_indexer , right_indexer ):
283- # insert group keys
283+
284+ consolidate = False
285+
286+ left_has_missing = None
287+ right_has_missing = None
284288
285289 keys = zip (self .join_names , self .left_on , self .right_on )
286290 for i , (name , lname , rname ) in enumerate (keys ):
287291 if not _should_fill (lname , rname ):
288292 continue
289293
294+ take_left , take_right = None , None
295+
290296 if name in result :
297+ < << << << HEAD
291298 key_indexer = result .columns .get_loc (name )
299+ == == == =
300+ >> >> >> > e79b978 ... Preserve dtype in merge keys when possible
292301
293302 if left_indexer is not None and right_indexer is not None :
294-
295303 if name in self .left :
304+ < << << << HEAD
296305 if len (self .left ) == 0 :
297306 continue
298307
@@ -316,19 +325,71 @@ def _maybe_add_join_keys(self, result, left_indexer, right_indexer):
316325 result .iloc [na_indexer , key_indexer ] = (
317326 algos .take_1d (self .left_join_keys [i ],
318327 left_na_indexer ))
328+ == == == =
329+
330+ if left_has_missing is None :
331+ left_has_missing = any (left_indexer == - 1 )
332+
333+ if left_has_missing :
334+ take_right = self .right_join_keys [i ]
335+
336+ if result [name ].dtype != self .left [name ].dtype :
337+ take_left = self .left [name ].values
338+
339+ elif name in self .right :
340+
341+ if right_has_missing is None :
342+ right_has_missing = any (right_indexer == - 1 )
343+
344+ if right_has_missing :
345+ take_left = self .left_join_keys [i ]
346+
347+ if result [name ].dtype != self .right [name ].dtype :
348+ take_right = self .right [name ].values
349+
350+ > >> >> >> e79b978 ... Preserve dtype in merge keys when possible
319351 elif left_indexer is not None \
320352 and isinstance (self .left_join_keys [i ], np .ndarray ):
321353
322- if name is None :
323- name = 'key_%d' % i
354+ take_left = self .left_join_keys [i ]
355+ take_right = self .right_join_keys [i ]
356+
357+ if take_left is not None or take_right is not None :
358+
359+ if take_left is None :
360+ lvals = result [name ].values
361+ else :
362+ lfill = take_left .dtype .type ()
363+ lvals = com .take_1d (take_left , left_indexer , fill_value = lfill )
364+
365+ if take_right is None :
366+ rvals = result [name ].values
367+ else :
368+ rfill = take_right .dtype .type ()
369+ rvals = com .take_1d (take_right , right_indexer , fill_value = rfill )
370+
371+ key_col = np .where (left_indexer != - 1 , lvals , rvals )
372+
373+ if name in result :
374+ if result [name ].dtype != key_col .dtype :
375+ consolidate = True
376+ result [name ] = key_col
377+ else :
378+ result .insert (i , name or 'key_%d' % i , key_col )
379+ consolidate = True
324380
381+ << << << < HEAD
325382 # a faster way?
326383 key_col = algos .take_1d (self .left_join_keys [i ], left_indexer )
327384 na_indexer = (left_indexer == - 1 ).nonzero ()[0 ]
328385 right_na_indexer = right_indexer .take (na_indexer )
329386 key_col .put (na_indexer , algos .take_1d (self .right_join_keys [i ],
330387 right_na_indexer ))
331388 result .insert (i , name , key_col )
389+ == == == =
390+ if consolidate :
391+ result .consolidate (inplace = True )
392+ > >> >> >> e79b978 ... Preserve dtype in merge keys when possible
332393
333394 def _get_join_info (self ):
334395 left_ax = self .left ._data .axes [self .axis ]
0 commit comments