11
11
12
12
#include < c10/util/CallOnce.h>
13
13
#include < c10/util/StringUtil.h>
14
+ #include < c10/util/env.h>
14
15
15
16
#include < fstream>
16
17
#include < functional>
22
23
#include < unordered_set>
23
24
#include < utility>
24
25
25
- namespace at ::cuda::tunable {
26
-
27
- namespace detail {
28
-
29
- struct MaybeDelete {
30
- bool owns_pointer;
31
- void operator ()(std::ostream* os) const { if (owns_pointer) delete os; }
32
- };
33
-
34
- using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;
35
-
36
- inline OstreamPtr get_stream (const std::string& filename) {
37
- if (filename == " out" ) {
38
- return OstreamPtr { &std::cout, MaybeDelete {false } };
39
- }
40
- else if (filename == " err" ) {
41
- return OstreamPtr { &std::cerr, MaybeDelete {false } };
42
- }
43
- else {
44
- return OstreamPtr { new std::ofstream {filename.c_str ()}, MaybeDelete {true } };
45
- }
46
- }
47
-
48
- }
49
-
50
- template <class ... Types>
51
- static void TunableLog (int level, Types... args) {
52
- static const char *env_file = getenv (" PYTORCH_TUNABLEOP_VERBOSE_FILENAME" );
53
- static const char *env_verbose = getenv (" PYTORCH_TUNABLEOP_VERBOSE" );
54
- static int level_user = env_verbose ? atoi (env_verbose) : 0 ;
55
- static auto streamptr = detail::get_stream (env_file ? env_file : " err" );
56
- if (level_user >= level) {
57
- (*streamptr) << c10::str (args...) << std::endl;
58
- }
59
- }
60
- #define TUNABLE_LOGV (LEVEL, ...) TunableLog(LEVEL, __VA_ARGS__)
26
+ #define TUNABLE_LOGV (LEVEL, ...) getTuningContext()->Log (LEVEL, __VA_ARGS__)
61
27
#define TUNABLE_LOG1 (...) TUNABLE_LOGV(1 , __VA_ARGS__)
62
28
#define TUNABLE_LOG2 (...) TUNABLE_LOGV(2 , __VA_ARGS__)
63
29
#define TUNABLE_LOG3 (...) TUNABLE_LOGV(3 , __VA_ARGS__)
64
30
31
+ namespace at ::cuda::tunable {
32
+
65
33
enum TORCH_CUDA_CPP_API TuningStatus {
66
34
OK = 0 ,
67
35
FAIL = 1 ,
@@ -219,7 +187,19 @@ class TORCH_CUDA_CPP_API TuningContext {
219
187
bool ReadFile (const std::string& filename={});
220
188
bool WriteFile (const std::string& filename={});
221
189
190
+ template <class ... Types>
191
+ void Log (int level, Types... args) {
192
+ if (GetLogOkay () && GetLogLevel () >= level) {
193
+ GetLog () << c10::str (args...) << std::endl;
194
+ }
195
+ }
196
+
222
197
private:
198
+ std::string GetLogFilename () const ;
199
+ int GetLogLevel () const ;
200
+ bool GetLogOkay () const ;
201
+ std::ostream& GetLog () const ;
202
+
223
203
bool enable_;
224
204
bool tuning_enable_;
225
205
bool record_untuned_enable_;
@@ -238,6 +218,7 @@ class TORCH_CUDA_CPP_API TuningContext {
238
218
std::string filename_;
239
219
std::ofstream untuned_file_;
240
220
size_t results_count_from_input_file_;
221
+ bool is_shutting_down_;
241
222
};
242
223
243
224
TORCH_CUDA_CPP_API TuningContext* getTuningContext ();
0 commit comments