Skip to content

Commit 81e20d3

Browse files
gdankelfacebook-github-bot
authored andcommitted
Log device properties to trace file
Summary: Device properties are useful in any case, but especially for performing analysis on traces such as occupancy. This patch is a re-implementation of pytorch#209 Reviewed By: ilia-cher Differential Revision: D28337067 fbshipit-source-id: b61588b414a9faa6f697260b48f750f013ba553b
1 parent 1dbf6e8 commit 81e20d3

File tree

4 files changed

+49
-14
lines changed

4 files changed

+49
-14
lines changed

libkineto/src/CudaDeviceProperties.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "CudaDeviceProperties.h"
99

10+
#include <fmt/format.h>
1011
#include <vector>
1112

1213
#include <cuda_runtime.h>
@@ -43,6 +44,36 @@ const std::vector<cudaOccDeviceProp>& occDeviceProps() {
4344
return occProps;
4445
}
4546

47+
static const std::string createComputePropertiesJson(
48+
const cudaOccDeviceProp& props) {
49+
return fmt::format(R"JSON(
50+
{{
51+
"computeMajor": {}, "computeMinor": {},
52+
"maxThreadsPerBlock": {}, "maxThreadsPerMultiprocessor": {},
53+
"regsPerBlock": {}, "regsPerMultiprocessor": {}, "warpSize": {},
54+
"sharedMemPerBlock": {}, "sharedMemPerMultiprocessor": {},
55+
"numSms": {}, "sharedMemPerBlockOptin": {}
56+
}})JSON",
57+
props.computeMajor, props.computeMinor,
58+
props.maxThreadsPerBlock, props.maxThreadsPerMultiprocessor,
59+
props.regsPerBlock, props.regsPerMultiprocessor, props.warpSize,
60+
props.sharedMemPerBlock, props.sharedMemPerMultiprocessor,
61+
props.numSms, props.sharedMemPerBlockOptin);
62+
}
63+
64+
static const std::string createComputePropertiesJson() {
65+
std::vector<std::string> computeProps;
66+
for (const auto& props : occDeviceProps()) {
67+
computeProps.push_back(createComputePropertiesJson(props));
68+
}
69+
return fmt::format("{}", fmt::join(computeProps, ",\n"));
70+
}
71+
72+
const std::string& computePropertiesJson() {
73+
static std::string computePropsJson = createComputePropertiesJson();
74+
return computePropsJson;
75+
}
76+
4677
float kernelOccupancy(
4778
uint32_t deviceId,
4879
uint16_t registersPerThread,

libkineto/src/CudaDeviceProperties.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
#pragma once
99

1010
#include <stdint.h>
11+
#include <string>
1112

1213
namespace KINETO_NAMESPACE {
1314

15+
// Return estimated achieved occupancy for a kernel
1416
float kernelOccupancy(
1517
uint32_t deviceId,
1618
uint16_t registersPerThread,
@@ -21,4 +23,7 @@ float kernelOccupancy(
2123
int32_t blockZ,
2224
float blocks_per_sm);
2325

26+
// Return compute properties for each device as a json string
27+
const std::string& computePropertiesJson();
28+
2429
} // namespace KINETO_NAMESPACE

libkineto/src/Logger.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ Logger::~Logger() {
6161
buf_ << " : " << strerror_r(errnum_, buf, sizeof(buf));
6262
}
6363
#endif
64-
buf_ << std::ends;
65-
out_ << buf_.str() << std::endl;
64+
buf_ << std::endl;
65+
out_ << buf_.str();
6666
}
6767

6868
void Logger::setVerboseLogModules(const std::vector<std::string>& modules) {

libkineto/src/output_json.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,21 @@ void ChromeTraceLogger::handleTraceStart(
3939
)JSON", kSchemaVersion);
4040

4141
if (!metadata.empty()) {
42-
traceOf_ << R"JSON(
43-
"metadata": {
44-
)JSON";
45-
bool first = true;
42+
std::vector<std::string> vals;
4643
for (const auto& kv : metadata) {
47-
if (!first) {
48-
traceOf_ << ",\n";
49-
}
50-
traceOf_ << fmt::format(R"( "{}": "{}")", kv.first, kv.second);
51-
first = false;
44+
vals.push_back(fmt::format(R"( "{}": "{}")", kv.first, kv.second));
5245
}
53-
traceOf_ << R"JSON(
54-
},
55-
)JSON";
46+
traceOf_ << fmt::format(R"JSON(
47+
"metadata": {{
48+
{}
49+
}},)JSON", fmt::join(vals, ",\n"));
5650
}
5751

52+
traceOf_ << fmt::format(R"JSON(
53+
"computeProperties": [
54+
{}
55+
],)JSON", computePropertiesJson());
56+
5857
traceOf_ << R"JSON(
5958
"traceEvents": [
6059
)JSON";

0 commit comments

Comments
 (0)