Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 76e820d

Browse files
authored
Merge pull request #83 from chrisburr/fix-bugs
Remove deprecated use of pandas and ensure column order is correct
2 parents 57991a4 + ea2ec6b commit 76e820d

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

root_pandas/readwrite.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""
33
A module that extends pandas to support the ROOT data format.
44
"""
5+
from collections import Counter
56

67
import numpy as np
78
from numpy.lib.recfunctions import append_fields
@@ -95,11 +96,11 @@ def get_nonscalar_columns(array):
9596
def get_matching_variables(branches, patterns, fail=True):
9697
# Convert branches to a set to make x "in branches" O(1) on average
9798
branches = set(branches)
98-
patterns = set(patterns)
9999
# Find any trivial matches
100-
selected = list(branches.intersection(patterns))
100+
selected = sorted(branches.intersection(patterns),
101+
key=lambda s: patterns.index(s))
101102
# Any matches that weren't trivial need to be looped over...
102-
for pattern in patterns.difference(selected):
103+
for pattern in set(patterns).difference(selected):
103104
found = False
104105
# Avoid using fnmatch on the pattern if possible
105106
if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern):
@@ -317,7 +318,7 @@ def convert_to_dataframe(array, start_index=None):
317318
# Filter to remove __index__ columns
318319
columns = [c for c in array.dtype.names if c in df.columns]
319320
assert len(columns) == len(df.columns), (columns, df.columns)
320-
df = df.reindex_axis(columns, axis=1, copy=False)
321+
df = df.reindex(columns, axis=1, copy=False)
321322

322323
# Convert categorical columns back to categories
323324
for c in df.columns:
@@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
366367
else:
367368
raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode))
368369

370+
column_name_counts = Counter(df.columns)
371+
if max(column_name_counts.values()) > 1:
372+
raise ValueError('DataFrame contains duplicated column names: ' +
373+
' '.join({k for k, v in column_name_counts.items() if v > 1}))
374+
369375
from root_numpy import array2tree
370376
# We don't want to modify the user's DataFrame here, so we make a shallow copy
371377
df_ = df.copy(deep=False)

root_pandas/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
'version_info',
55
]
66

7-
__version__ = '0.6.1'
7+
__version__ = '0.7.0'
88
version = __version__
99
version_info = tuple(__version__.split('.'))

tests/test_issues.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,20 @@ def test_issue_63():
2424
assert all(len(df) == 1 for df in result)
2525
os.remove('tmp_1.root')
2626
os.remove('tmp_2.root')
27+
28+
29+
def test_issue_80():
30+
df = pd.DataFrame({'a': [1, 2], 'b': [4, 5]})
31+
df.columns = ['a', 'a']
32+
try:
33+
root_pandas.to_root(df, '/tmp/example.root')
34+
except ValueError as e:
35+
assert 'DataFrame contains duplicated column names' in e.args[0]
36+
else:
37+
raise Exception('ValueError is expected')
38+
39+
40+
def test_issue_82():
41+
variables = ['MET_px', 'MET_py', 'EventWeight']
42+
df = root_pandas.read_root('http://scikit-hep.org/uproot/examples/HZZ.root', 'events', columns=variables)
43+
assert list(df.columns) == variables

0 commit comments

Comments
 (0)