5
5
6
6
from distutils .version import LooseVersion
7
7
import numpy as np
8
+ import pytz
8
9
import pandas as pd
9
10
10
11
from xray import Variable , Dataset , DataArray
@@ -153,7 +154,7 @@ def test_0d_time_data(self):
153
154
def test_datetime64_conversion (self ):
154
155
times = pd .date_range ('2000-01-01' , periods = 3 )
155
156
for values , preserve_source in [
156
- (times , False ),
157
+ (times , True ),
157
158
(times .values , True ),
158
159
(times .values .astype ('datetime64[s]' ), False ),
159
160
(times .to_pydatetime (), False ),
@@ -163,15 +164,12 @@ def test_datetime64_conversion(self):
163
164
self .assertArrayEqual (v .values , times .values )
164
165
self .assertEqual (v .values .dtype , np .dtype ('datetime64[ns]' ))
165
166
same_source = source_ndarray (v .values ) is source_ndarray (values )
166
- if preserve_source and self .cls is Variable :
167
- self .assertTrue (same_source )
168
- else :
169
- self .assertFalse (same_source )
167
+ assert preserve_source == same_source
170
168
171
169
def test_timedelta64_conversion (self ):
172
170
times = pd .timedelta_range (start = 0 , periods = 3 )
173
171
for values , preserve_source in [
174
- (times , False ),
172
+ (times , True ),
175
173
(times .values , True ),
176
174
(times .values .astype ('timedelta64[s]' ), False ),
177
175
(times .to_pytimedelta (), False ),
@@ -181,10 +179,7 @@ def test_timedelta64_conversion(self):
181
179
self .assertArrayEqual (v .values , times .values )
182
180
self .assertEqual (v .values .dtype , np .dtype ('timedelta64[ns]' ))
183
181
same_source = source_ndarray (v .values ) is source_ndarray (values )
184
- if preserve_source and self .cls is Variable :
185
- self .assertTrue (same_source )
186
- else :
187
- self .assertFalse (same_source )
182
+ assert preserve_source == same_source
188
183
189
184
def test_object_conversion (self ):
190
185
data = np .arange (5 ).astype (str ).astype (object )
@@ -405,6 +400,22 @@ def test_aggregate_complex(self):
405
400
expected = Variable ((), 0.5 + 1j )
406
401
self .assertVariableAllClose (v .mean (), expected )
407
402
403
+ def test_pandas_cateogrical_dtype (self ):
404
+ data = pd .Categorical (np .arange (10 , dtype = 'int64' ))
405
+ v = self .cls ('x' , data )
406
+ print (v ) # should not error
407
+ assert v .dtype == 'int64'
408
+
409
+ def test_pandas_datetime64_with_tz (self ):
410
+ data = pd .date_range (start = '2000-01-01' ,
411
+ tz = pytz .timezone ('America/New_York' ),
412
+ periods = 10 , freq = '1h' )
413
+ v = self .cls ('x' , data )
414
+ print (v ) # should not error
415
+ if 'America/New_York' in data .dtype :
416
+ # pandas is new enough that it has datetime64 with timezone dtype
417
+ assert v .dtype == 'object'
418
+
408
419
409
420
class TestVariable (TestCase , VariableSubclassTestCases ):
410
421
cls = staticmethod (Variable )
0 commit comments