4
4
5
5
from pytorch_lightning .profiler import Profiler , AdvancedProfiler
6
6
7
# Upper bound (in seconds) on the acceptable per-iteration overhead that
# profiling a no-op action may introduce; the overhead tests assert against it.
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.001
8
8
9
9
10
10
@pytest.fixture
def simple_profiler():
    """Yield a fresh ``Profiler`` to any test that takes ``simple_profiler`` as an arg."""
    return Profiler()
14
15
15
16
16
17
@pytest.fixture
def advanced_profiler():
    """Yield a fresh ``AdvancedProfiler`` to any test that takes ``advanced_profiler`` as an arg."""
    return AdvancedProfiler()
20
22
21
23
22
24
@pytest .mark .parametrize ("action,expected" , [("a" , [3 , 1 ]), ("b" , [2 ]), ("c" , [1 ])])
23
25
def test_simple_profiler_durations (simple_profiler , action , expected ):
24
- """
25
- ensure the reported durations are reasonably accurate
26
- """
26
+ """Ensure the reported durations are reasonably accurate."""
27
27
28
28
for duration in expected :
29
29
with simple_profiler .profile (action ):
@@ -37,9 +37,7 @@ def test_simple_profiler_durations(simple_profiler, action, expected):
37
37
38
38
39
39
def test_simple_profiler_overhead (simple_profiler , n_iter = 5 ):
40
- """
41
- ensure that the profiler doesn't introduce too much overhead during training
42
- """
40
+ """Ensure that the profiler doesn't introduce too much overhead during training."""
43
41
for _ in range (n_iter ):
44
42
with simple_profiler .profile ("no-op" ):
45
43
pass
@@ -49,24 +47,25 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5):
49
47
50
48
51
49
def test_simple_profiler_describe (simple_profiler ):
52
- """
53
- ensure the profiler won't fail when reporting the summary
54
- """
50
+ """Ensure the profiler won't fail when reporting the summary."""
55
51
simple_profiler .describe ()
56
52
57
53
54
+ def _get_total_cprofile_duration (profile ):
55
+ return sum ([x .totaltime for x in profile .getstats ()])
56
+
57
+
58
58
@pytest .mark .parametrize ("action,expected" , [("a" , [3 , 1 ]), ("b" , [2 ]), ("c" , [1 ])])
59
59
def test_advanced_profiler_durations (advanced_profiler , action , expected ):
60
- def _get_total_duration (profile ):
61
- return sum ([x .totaltime for x in profile .getstats ()])
60
+ """Ensure the reported durations are reasonably accurate."""
62
61
63
62
for duration in expected :
64
63
with advanced_profiler .profile (action ):
65
64
time .sleep (duration )
66
65
67
66
# different environments have different precision when it comes to time.sleep()
68
67
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
69
- recored_total_duration = _get_total_duration (
68
+ recored_total_duration = _get_total_cprofile_duration (
70
69
advanced_profiler .profiled_actions [action ]
71
70
)
72
71
expected_total_duration = np .sum (expected )
@@ -76,21 +75,17 @@ def _get_total_duration(profile):
76
75
77
76
78
77
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
    """Ensure that the profiler doesn't introduce too much overhead during training."""
    # Profile a do-nothing action n_iter times; any time recorded is pure overhead.
    for _ in range(n_iter):
        with advanced_profiler.profile("no-op"):
            pass

    no_op_stats = advanced_profiler.profiled_actions["no-op"]
    mean_overhead = _get_total_cprofile_duration(no_op_stats) / n_iter
    assert mean_overhead < PROFILER_OVERHEAD_MAX_TOLERANCE
90
87
91
88
92
89
def test_advanced_profiler_describe (advanced_profiler ):
93
- """
94
- ensure the profiler won't fail when reporting the summary
95
- """
90
+ """Ensure the profiler won't fail when reporting the summary."""
96
91
advanced_profiler .describe ()
0 commit comments