Skip to content

Commit 55baa9b

Browse files
committed
Add tests and fix for Datetime64TZFormatter
1 parent 9c3312c commit 55baa9b

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

pandas/io/formats/format.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,18 +1526,17 @@ def _format_strings(self) -> List[str]:
15261526
flat_values = DatetimeIndex(flat_values)
15271527

15281528
if self.formatter is not None and callable(self.formatter):
1529-
flat_str_values = np.array([self.formatter(x) for x in flat_values])
1530-
fmt_values = flat_str_values
1529+
fmt_values = [self.formatter(x) for x in flat_values]
15311530
else:
15321531
fmt_values = flat_values._data._format_native_types(
15331532
na_rep=self.nat_rep, date_format=self.date_format
15341533
)
1535-
fmt_values = fmt_values.reshape(values.shape)
15361534

1537-
if len(fmt_values.shape) > 1:
1538-
nested_string_formatter = GenericArrayFormatter(fmt_values)
1539-
fmt_values = nested_string_formatter.get_result()
1540-
else:
1535+
if len(values.shape) > 1:
1536+
fmt_values = np.asarray(fmt_values).reshape(values.shape)
1537+
nested_formatter = GenericArrayFormatter(fmt_values)
1538+
fmt_values = nested_formatter.get_result()
1539+
elif isinstance(fmt_values, np.ndarray):
15411540
fmt_values = fmt_values.tolist()
15421541

15431542
return fmt_values
@@ -1718,18 +1717,16 @@ def _format_strings(self) -> List[str]:
17181717
flat_values = values.ravel()
17191718

17201719
ido = is_dates_only(flat_values)
1721-
17221720
formatter = self.formatter or get_format_datetime64(
17231721
ido, date_format=self.date_format
17241722
)
17251723

1726-
fmt_values = np.array([formatter(x) for x in flat_values]).reshape(values.shape)
1724+
fmt_values = [formatter(x) for x in flat_values]
17271725

1728-
if len(fmt_values.shape) > 1:
1729-
nested_string_formatter = GenericArrayFormatter(fmt_values)
1730-
fmt_values = nested_string_formatter.get_result()
1731-
else:
1732-
fmt_values = fmt_values.tolist()
1726+
if len(values.shape) > 1:
1727+
fmt_values = np.asarray(fmt_values).reshape(values.shape)
1728+
nested_formatter = GenericArrayFormatter(fmt_values)
1729+
fmt_values = nested_formatter.get_result()
17331730

17341731
return fmt_values
17351732

pandas/tests/io/formats/test_format.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3146,7 +3146,13 @@ def format_func(t):
31463146
class TestDatetime64TZFormatter:
31473147
def test_mixed(self):
31483148
utc = dateutil.tz.tzutc()
3149-
x = Series([datetime(2013, 1, 1, tzinfo=utc), datetime(2013, 1, 1, 12, tzinfo=utc), pd.NaT])
3149+
x = Series(
3150+
[
3151+
datetime(2013, 1, 1, tzinfo=utc),
3152+
datetime(2013, 1, 1, 12, tzinfo=utc),
3153+
pd.NaT,
3154+
]
3155+
)
31503156
result = fmt.Datetime64TZFormatter(x).get_result()
31513157
assert len(result) == 3
31523158
assert result[0].strip() == "2013-01-01 00:00:00+00:00"
@@ -3163,20 +3169,26 @@ def test_datetime64formatter_1d_array(self):
31633169
assert result[2].strip() == "2018-01-01 02:00:00-08:00"
31643170

31653171
def test_datetime64formatter_2d_array(self):
3166-
x = pd.date_range("2018-01-01", periods=10, freq="H", tz="US/Pacific").to_numpy()
3172+
x = pd.date_range(
3173+
"2018-01-01", periods=10, freq="H", tz="US/Pacific"
3174+
).to_numpy()
31673175
formatter = fmt.Datetime64TZFormatter(x.reshape((5, 2)))
31683176
result = formatter.get_result()
31693177
assert len(result) == 5
31703178
assert result[0].strip() == "[2018-01-01 00:00:00-08:00, 2018-01-01 01:00:0..."
31713179
assert result[4].strip() == "[2018-01-01 08:00:00-08:00, 2018-01-01 09:00:0..."
31723180

31733181
def test_datetime64formatter_2d_array_format_func(self):
3174-
x = pd.date_range("2018-01-01", periods=16, freq="H", tz="US/Pacific").to_numpy()
3182+
x = pd.date_range(
3183+
"2018-01-01", periods=16, freq="H", tz="US/Pacific"
3184+
).to_numpy()
31753185

31763186
def format_func(t):
31773187
return t.strftime("%H-%m %Z")
31783188

3179-
formatter = fmt.Datetime64TZFormatter(x.reshape((4, 2, 2)), formatter=format_func)
3189+
formatter = fmt.Datetime64TZFormatter(
3190+
x.reshape((4, 2, 2)), formatter=format_func
3191+
)
31803192
result = formatter.get_result()
31813193
assert len(result) == 4
31823194
assert result[0].strip() == "[[00-01 PST, 01-01 PST], [02-01 PST, 03-01 PST]]"

0 commit comments

Comments
 (0)