|
2 | 2 | """
|
3 | 3 | A module that extends pandas to support the ROOT data format.
|
4 | 4 | """
|
| 5 | +from collections import Counter |
5 | 6 |
|
6 | 7 | import numpy as np
|
7 | 8 | from numpy.lib.recfunctions import append_fields
|
@@ -95,11 +96,11 @@ def get_nonscalar_columns(array):
|
95 | 96 | def get_matching_variables(branches, patterns, fail=True):
|
96 | 97 | # Convert branches to a set to make x "in branches" O(1) on average
|
97 | 98 | branches = set(branches)
|
98 |
| - patterns = set(patterns) |
99 | 99 | # Find any trivial matches
|
100 |
| - selected = list(branches.intersection(patterns)) |
| 100 | + selected = sorted(branches.intersection(patterns), |
| 101 | + key=lambda s: patterns.index(s)) |
101 | 102 | # Any matches that weren't trivial need to be looped over...
|
102 |
| - for pattern in patterns.difference(selected): |
| 103 | + for pattern in set(patterns).difference(selected): |
103 | 104 | found = False
|
104 | 105 | # Avoid using fnmatch if the pattern if possible
|
105 | 106 | if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern):
|
@@ -317,7 +318,7 @@ def convert_to_dataframe(array, start_index=None):
|
317 | 318 | # Filter to remove __index__ columns
|
318 | 319 | columns = [c for c in array.dtype.names if c in df.columns]
|
319 | 320 | assert len(columns) == len(df.columns), (columns, df.columns)
|
320 |
| - df = df.reindex_axis(columns, axis=1, copy=False) |
| 321 | + df = df.reindex(columns, axis=1, copy=False) |
321 | 322 |
|
322 | 323 | # Convert categorical columns back to categories
|
323 | 324 | for c in df.columns:
|
@@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
|
366 | 367 | else:
|
367 | 368 | raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode))
|
368 | 369 |
|
| 370 | + column_name_counts = Counter(df.columns) |
| 371 | + if max(column_name_counts.values()) > 1: |
| 372 | + raise ValueError('DataFrame contains duplicated column names: ' + |
| 373 | + ' '.join({k for k, v in column_name_counts.items() if v > 1})) |
| 374 | + |
369 | 375 | from root_numpy import array2tree
|
370 | 376 | # We don't want to modify the user's DataFrame here, so we make a shallow copy
|
371 | 377 | df_ = df.copy(deep=False)
|
|
0 commit comments