Skip to content

Commit ea8878b

Browse files
authored
clean up tests/test_profiler.py (#867)
* cleanup docstrings, _get_total_cprofile_duration in module * relax profiler overhead tolerance
1 parent c58aab0 commit ea8878b

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

tests/test_profiler.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@
44

55
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
66

7-
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
7+
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.001
88

99

1010
@pytest.fixture
1111
def simple_profiler():
12+
"""Creates a new profiler for every test with `simple_profiler` as an arg."""
1213
profiler = Profiler()
1314
return profiler
1415

1516

1617
@pytest.fixture
1718
def advanced_profiler():
19+
"""Creates a new profiler for every test with `advanced_profiler` as an arg."""
1820
profiler = AdvancedProfiler()
1921
return profiler
2022

2123

2224
@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
2325
def test_simple_profiler_durations(simple_profiler, action, expected):
24-
"""
25-
ensure the reported durations are reasonably accurate
26-
"""
26+
"""Ensure the reported durations are reasonably accurate."""
2727

2828
for duration in expected:
2929
with simple_profiler.profile(action):
@@ -37,9 +37,7 @@ def test_simple_profiler_durations(simple_profiler, action, expected):
3737

3838

3939
def test_simple_profiler_overhead(simple_profiler, n_iter=5):
40-
"""
41-
ensure that the profiler doesn't introduce too much overhead during training
42-
"""
40+
"""Ensure that the profiler doesn't introduce too much overhead during training."""
4341
for _ in range(n_iter):
4442
with simple_profiler.profile("no-op"):
4543
pass
@@ -49,24 +47,25 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5):
4947

5048

5149
def test_simple_profiler_describe(simple_profiler):
52-
"""
53-
ensure the profiler won't fail when reporting the summary
54-
"""
50+
"""Ensure the profiler won't fail when reporting the summary."""
5551
simple_profiler.describe()
5652

5753

54+
def _get_total_cprofile_duration(profile):
55+
return sum([x.totaltime for x in profile.getstats()])
56+
57+
5858
@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
5959
def test_advanced_profiler_durations(advanced_profiler, action, expected):
60-
def _get_total_duration(profile):
61-
return sum([x.totaltime for x in profile.getstats()])
60+
"""Ensure the reported durations are reasonably accurate."""
6261

6362
for duration in expected:
6463
with advanced_profiler.profile(action):
6564
time.sleep(duration)
6665

6766
# different environments have different precision when it comes to time.sleep()
6867
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
69-
recored_total_duration = _get_total_duration(
68+
recored_total_duration = _get_total_cprofile_duration(
7069
advanced_profiler.profiled_actions[action]
7170
)
7271
expected_total_duration = np.sum(expected)
@@ -76,21 +75,17 @@ def _get_total_duration(profile):
7675

7776

7877
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
79-
"""
80-
ensure that the profiler doesn't introduce too much overhead during training
81-
"""
78+
"""Ensure that the profiler doesn't introduce too much overhead during training."""
8279
for _ in range(n_iter):
8380
with advanced_profiler.profile("no-op"):
8481
pass
8582

8683
action_profile = advanced_profiler.profiled_actions["no-op"]
87-
total_duration = sum([x.totaltime for x in action_profile.getstats()])
84+
total_duration = _get_total_cprofile_duration(action_profile)
8885
average_duration = total_duration / n_iter
8986
assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
9087

9188

9289
def test_advanced_profiler_describe(advanced_profiler):
93-
"""
94-
ensure the profiler won't fail when reporting the summary
95-
"""
90+
"""Ensure the profiler won't fail when reporting the summary."""
9691
advanced_profiler.describe()

0 commit comments

Comments
 (0)