
Commit 9aeec99

jasonjk-park authored and facebook-github-bot committed
Benchmark for fbgemm HBC (#791)
Summary: Pull Request resolved: #791

CPU/GPU implementation benchmarking for fbgemm HBC.

Reviewed By: jianyuh

Differential Revision: D32770579

fbshipit-source-id: e18b12c26afce7eca3eb3836306acc9467e7e9d1
1 parent 810e28a commit 9aeec99
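
In short, the benchmark times each op call with CUDA events on the GPU and with wall-clock time on the CPU, then averages over --iters iterations after --warmup-runs warmup passes. A condensed sketch of that timing pattern, distilled from the full diff below (the helper name time_op is illustrative, not from the commit):

import time
from typing import Callable, Tuple

import torch
from torch import Tensor


def time_op(fn: Callable[[Tensor], Tuple[Tensor, Tensor]], x: Tensor) -> float:
    if x.is_cuda:
        # GPU path: CUDA events bracket the call; synchronize before reading.
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(x)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) * 1.0e-3  # elapsed_time() is in ms
    # CPU path: plain wall-clock timing.
    t0 = time.time()
    fn(x)
    return time.time() - t0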

File tree

1 file changed: 115 additions, 0 deletions
@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import time
from typing import Callable, Tuple

import click
import torch
from torch import Tensor

logging.basicConfig(level=logging.DEBUG)

try:
    # pyre-ignore[21]
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


def benchmark_fbgemm_function(
    func: Callable[[Tensor], Tuple[Tensor, Tensor]],
    input: Tensor,
) -> Tuple[float, Tensor]:
    if input.is_cuda:
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        # Time a single invocation; the caller accumulates over iterations.
        output, _ = func(input)
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds; convert to seconds.
        elapsed_time = start_event.elapsed_time(end_event) * 1.0e-3
    else:
        start_time = time.time()
        output, _ = func(input)
        elapsed_time = time.time() - start_time
    return float(elapsed_time), output


@click.command()
@click.option("--iters", default=100)
@click.option("--warmup-runs", default=2)
def main(
    iters: int,
    warmup_runs: int,
) -> None:

    total_time = {
        "fbgemm_cpu_half": 0.0,
        "fbgemm_cpu_float": 0.0,
        "fbgemm_gpu_half": 0.0,
        "fbgemm_gpu_float": 0.0,
    }

    input_data_cpu = torch.rand(5000, dtype=torch.float)

    bin_num_examples: Tensor = torch.empty([5000], dtype=torch.float64).fill_(0.0)
    bin_num_positives: Tensor = torch.empty([5000], dtype=torch.float64).fill_(0.0)
    lower_bound: float = 0.0
    upper_bound: float = 1.0

    def fbgemm_hbc_cpu(input: Tensor) -> Tuple[Tensor, Tensor]:
        return torch.ops.fbgemm.histogram_binning_calibration(
            input, bin_num_examples, bin_num_positives, 0.4, lower_bound,
            upper_bound, 0, 0.9995)

    for step in range(iters + warmup_runs):
        elapsed, _ = benchmark_fbgemm_function(
            fbgemm_hbc_cpu,
            input_data_cpu.half(),
        )
        if step >= warmup_runs:
            total_time["fbgemm_cpu_half"] += elapsed

        elapsed, _ = benchmark_fbgemm_function(
            fbgemm_hbc_cpu,
            input_data_cpu.float(),
        )
        if step >= warmup_runs:
            total_time["fbgemm_cpu_float"] += elapsed

        if torch.cuda.is_available():
            bin_num_examples_gpu: Tensor = bin_num_examples.cuda()
            bin_num_positives_gpu: Tensor = bin_num_positives.cuda()

            def fbgemm_hbc_gpu(input: Tensor) -> Tuple[Tensor, Tensor]:
                return torch.ops.fbgemm.histogram_binning_calibration(
                    input, bin_num_examples_gpu, bin_num_positives_gpu,
                    0.4, lower_bound, upper_bound, 0, 0.9995)

            elapsed, _ = benchmark_fbgemm_function(
                fbgemm_hbc_gpu,
                input_data_cpu.cuda().half(),
            )
            if step >= warmup_runs:
                total_time["fbgemm_gpu_half"] += elapsed

            elapsed, _ = benchmark_fbgemm_function(
                fbgemm_hbc_gpu,
                input_data_cpu.cuda().float(),
            )
            if step >= warmup_runs:
                total_time["fbgemm_gpu_float"] += elapsed

    for k, t_time in total_time.items():
        logging.info(
            f"{k} time per iter: {t_time / iters * 1.0e6:.0f}us"
        )


if __name__ == "__main__":
    main()
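
As a usage note, the op under test can also be exercised directly once the fbgemm_gpu ops are registered. A minimal sketch reusing the same positional argument values as the benchmark above; importing fbgemm_gpu to register torch.ops.fbgemm.* is an assumption about the open-source install, and only the histogram_binning_calibration call itself appears in the commit:

import torch
import fbgemm_gpu  # noqa: F401  # assumed: importing registers torch.ops.fbgemm.*

logits = torch.rand(8, dtype=torch.float)
bin_num_examples = torch.zeros(5000, dtype=torch.float64)
bin_num_positives = torch.zeros(5000, dtype=torch.float64)

# Same positional values as the benchmark:
# (input, bin_num_examples, bin_num_positives, 0.4, lower_bound, upper_bound, 0, 0.9995)
calibrated, bin_ids = torch.ops.fbgemm.histogram_binning_calibration(
    logits, bin_num_examples, bin_num_positives, 0.4, 0.0, 1.0, 0, 0.9995
)
print(calibrated.shape, bin_ids.shape)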
