Skip to content

Commit 07344e6

Browse files
gaoteng-gitfacebook-github-bot
authored andcommitted
add 2 kernel metrics (#185)
Summary: 1. Update "warps per sm" to "blocks per sm". 2. Add "occupancy" per kernel. Pull Request resolved: #185 Reviewed By: chaekit Differential Revision: D28120846 Pulled By: gdankel fbshipit-source-id: c7ce33b1421b60ae4323c66d38bba5eb175b105b
1 parent a409be7 commit 07344e6

File tree

4 files changed

+137
-5
lines changed

4 files changed

+137
-5
lines changed

libkineto/libkineto_defs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def get_libkineto_srcs(with_api = True):
1717
"src/ActivityProfilerProxy.cpp",
1818
"src/Config.cpp",
1919
"src/ConfigLoader.cpp",
20+
"src/CudaDeviceProperties.cpp",
2021
"src/CuptiActivityInterface.cpp",
2122
"src/CuptiEventInterface.cpp",
2223
"src/CuptiMetricInterface.cpp",
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Kineto Contributors
3+
* All rights reserved.
4+
* This source code is licensed under the BSD-style license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "CudaDeviceProperties.h"
9+
10+
#include <vector>
11+
12+
#include <cuda_runtime.h>
13+
#include <cuda_occupancy.h>
14+
15+
#include "Logger.h"
16+
17+
namespace KINETO_NAMESPACE {
18+
19+
std::vector<cudaOccDeviceProp> createOccDeviceProps() {
20+
std::vector<cudaOccDeviceProp> occProps;
21+
int device_count;
22+
cudaError_t error_id = cudaGetDeviceCount(&device_count);
23+
// Return empty vector if error.
24+
if (error_id != cudaSuccess) {
25+
return occProps;
26+
}
27+
for (int i = 0; i < device_count; ++i) {
28+
cudaDeviceProp prop;
29+
error_id = cudaGetDeviceProperties(&prop, i);
30+
// Return empty vector if any device property fail to get.
31+
if (error_id != cudaSuccess) {
32+
return occProps;
33+
}
34+
cudaOccDeviceProp occProp;
35+
occProp = prop;
36+
occProps.push_back(occProp);
37+
}
38+
return occProps;
39+
}
40+
41+
const std::vector<cudaOccDeviceProp>& occDeviceProps() {
42+
static std::vector<cudaOccDeviceProp> occProps = createOccDeviceProps();
43+
return occProps;
44+
}
45+
46+
float kernelOccupancy(
47+
uint32_t deviceId,
48+
uint16_t registersPerThread,
49+
int32_t staticSharedMemory,
50+
int32_t dynamicSharedMemory,
51+
int32_t blockX,
52+
int32_t blockY,
53+
int32_t blockZ,
54+
float blocksPerSm) {
55+
// Calculate occupancy
56+
float occupancy = -1.0;
57+
const std::vector<cudaOccDeviceProp> &occProps = occDeviceProps();
58+
if (deviceId < occProps.size()) {
59+
cudaOccFuncAttributes occFuncAttr;
60+
occFuncAttr.maxThreadsPerBlock = INT_MAX;
61+
occFuncAttr.numRegs = registersPerThread;
62+
occFuncAttr.sharedSizeBytes = staticSharedMemory;
63+
occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF;
64+
occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
65+
occFuncAttr.maxDynamicSharedSizeBytes = 0;
66+
const cudaOccDeviceState occDeviceState = {};
67+
int blockSize = blockX * blockY * blockZ;
68+
size_t dynamicSmemSize = dynamicSharedMemory;
69+
cudaOccResult occ_result;
70+
cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor(
71+
&occ_result, &occProps[deviceId], &occFuncAttr, &occDeviceState,
72+
blockSize, dynamicSmemSize);
73+
if (status == CUDA_OCC_SUCCESS) {
74+
if (occ_result.activeBlocksPerMultiprocessor < blocksPerSm) {
75+
blocksPerSm = occ_result.activeBlocksPerMultiprocessor;
76+
}
77+
occupancy = blocksPerSm * blockSize /
78+
(float) occProps[deviceId].maxThreadsPerMultiprocessor;
79+
} else {
80+
LOG_EVERY_N(ERROR, 1000) << "Failed to calculate occupancy, status = "
81+
<< status;
82+
}
83+
}
84+
return occupancy;
85+
}
86+
87+
} // namespace KINETO_NAMESPACE
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) Kineto Contributors
3+
* All rights reserved.
4+
* This source code is licensed under the BSD-style license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <stdint.h>
11+
12+
namespace KINETO_NAMESPACE {
13+
14+
float kernelOccupancy(
15+
uint32_t deviceId,
16+
uint16_t registersPerThread,
17+
int32_t staticSharedMemory,
18+
int32_t dynamicSharedMemory,
19+
int32_t blockX,
20+
int32_t blockY,
21+
int32_t blockZ,
22+
float blocks_per_sm);
23+
24+
} // namespace KINETO_NAMESPACE

libkineto/src/output_json.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "CuptiActivity.h"
1818
#include "CuptiActivity.tpp"
1919
#include "CuptiActivityInterface.h"
20+
#include "CudaDeviceProperties.h"
2021
#endif // HAS_CUPTI
2122
#include "Demangle.h"
2223
#include "TraceSpan.h"
@@ -307,12 +308,27 @@ void ChromeTraceLogger::handleGpuActivity(
307308
const CUpti_ActivityKernel4* kernel = &activity.raw();
308309
const TraceActivity& ext = *activity.linkedActivity();
309310
constexpr int threads_per_warp = 32;
311+
float blocks_per_sm = -1.0;
310312
float warps_per_sm = -1.0;
311313
if (smCount_) {
312-
warps_per_sm = (kernel->gridX * kernel->gridY * kernel->gridZ) *
313-
(kernel->blockX * kernel->blockY * kernel->blockZ) /
314-
(float) threads_per_warp / smCount_;
314+
blocks_per_sm =
315+
(kernel->gridX * kernel->gridY * kernel->gridZ) / (float) smCount_;
316+
warps_per_sm =
317+
blocks_per_sm * (kernel->blockX * kernel->blockY * kernel->blockZ)
318+
/ threads_per_warp;
315319
}
320+
321+
// Calculate occupancy
322+
float occupancy = KINETO_NAMESPACE::kernelOccupancy(
323+
kernel->deviceId,
324+
kernel->registersPerThread,
325+
kernel->staticSharedMemory,
326+
kernel->dynamicSharedMemory,
327+
kernel->blockX,
328+
kernel->blockY,
329+
kernel->blockZ,
330+
blocks_per_sm);
331+
316332
// clang-format off
317333
traceOf_ << fmt::format(R"JSON(
318334
{{
@@ -322,9 +338,11 @@ void ChromeTraceLogger::handleGpuActivity(
322338
"stream": {}, "correlation": {}, "external id": {},
323339
"registers per thread": {},
324340
"shared memory": {},
341+
"blocks per SM": {},
325342
"warps per SM": {},
326343
"grid": [{}, {}, {}],
327-
"block": [{}, {}, {}]
344+
"block": [{}, {}, {}],
345+
"theoretical occupancy %": {}
328346
}}
329347
}},)JSON",
330348
traceActivityJson(activity, "stream "),
@@ -333,9 +351,11 @@ void ChromeTraceLogger::handleGpuActivity(
333351
kernel->streamId, kernel->correlationId, ext.correlationId(),
334352
kernel->registersPerThread,
335353
kernel->staticSharedMemory + kernel->dynamicSharedMemory,
354+
blocks_per_sm,
336355
warps_per_sm,
337356
kernel->gridX, kernel->gridY, kernel->gridZ,
338-
kernel->blockX, kernel->blockY, kernel->blockZ);
357+
kernel->blockX, kernel->blockY, kernel->blockZ,
358+
(int) (0.5 + occupancy * 100.0));
339359
// clang-format on
340360

341361
handleLinkEnd(activity);

0 commit comments

Comments
 (0)