|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import logging |
| 7 | +import time |
| 8 | +from typing import Callable, Tuple |
| 9 | + |
| 10 | +import click |
| 11 | +import torch |
| 12 | +from torch import Tensor |
| 13 | + |
# Configure the root logger at DEBUG so the INFO-level results logged by
# main() are emitted.
logging.basicConfig(level=logging.DEBUG)

# Try the packaged fbgemm_gpu import first; if it fails for any reason, fall
# back to loading the sparse-ops libraries explicitly via torch.ops.
# NOTE(review): the "//deeplearning/..." strings look like internal build
# targets -- confirm they resolve in the environment this script runs in.
try:
    # pyre-ignore[21]
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
| 22 | + |
def benchmark_fbgemm_function(
    func: Callable[[Tensor], Tuple[Tensor, Tensor]],
    input: Tensor,
) -> Tuple[float, Tensor]:
    """Time a single call of ``func(input)``.

    Args:
        func: Callable taking one tensor and returning a pair of tensors.
        input: Tensor passed to ``func``. Its device selects the timing
            method: CUDA events for CUDA tensors, a host clock otherwise.

    Returns:
        A tuple ``(elapsed_seconds, output)`` where ``output`` is the first
        element of the pair returned by ``func``.
    """
    if input.is_cuda:
        # Drain previously queued GPU work so it is not attributed to func.
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        output, _ = func(input)
        end_event.record()
        # Wait for end_event to complete before reading the elapsed time.
        torch.cuda.synchronize()
        # Event.elapsed_time reports milliseconds; convert to seconds.
        elapsed_time = start_event.elapsed_time(end_event) * 1.0e-3
    else:
        # perf_counter is monotonic and high resolution; time.time() can be
        # adjusted by the system clock and is too coarse for short kernels.
        start_time = time.perf_counter()
        output, _ = func(input)
        elapsed_time = time.perf_counter() - start_time
    return float(elapsed_time), output
| 43 | + |
@click.command()
@click.option("--iters", default=100)
@click.option("--warmup-runs", default=2)
def main(
    iters: int,
    warmup_runs: int,
) -> None:
    """Benchmark ``torch.ops.fbgemm.histogram_binning_calibration``.

    Runs the operator on a 5000-element random input in half and float
    precision on CPU, and additionally on GPU when CUDA is available, then
    logs the mean time per iteration in microseconds for every benchmark
    that was actually run.

    Args:
        iters: Number of timed iterations per benchmark; must be positive.
        warmup_runs: Number of leading iterations that are executed but
            excluded from the totals; must be non-negative.

    Raises:
        click.BadParameter: If ``iters`` or ``warmup_runs`` is out of range.
    """
    # Reject values that would otherwise divide by zero or skew the average.
    if iters <= 0:
        raise click.BadParameter(
            "must be a positive integer", param_hint="--iters"
        )
    if warmup_runs < 0:
        raise click.BadParameter(
            "must be non-negative", param_hint="--warmup-runs"
        )

    input_data_cpu = torch.rand(5000, dtype=torch.float)

    bin_num_examples: Tensor = torch.zeros([5000], dtype=torch.float64)
    bin_num_positives: Tensor = torch.zeros([5000], dtype=torch.float64)
    lower_bound: float = 0.0
    upper_bound: float = 1.0

    def make_hbc(
        examples: Tensor, positives: Tensor
    ) -> Callable[[Tensor], Tuple[Tensor, Tensor]]:
        # Bind the bin tensors for one device so CPU and GPU share one body.
        # NOTE(review): the scalar arguments (0.4, 0, 0.9995) are passed
        # positionally as in the original; see the operator schema for
        # their meaning.
        def hbc(input: Tensor) -> Tuple[Tensor, Tensor]:
            return torch.ops.fbgemm.histogram_binning_calibration(
                input, examples, positives, 0.4, lower_bound,
                upper_bound, 0, 0.9995)

        return hbc

    # Build every (name, function, input) case once: dtype/device conversions
    # and closures are loop-invariant and do not belong in the timing loop.
    fbgemm_hbc_cpu = make_hbc(bin_num_examples, bin_num_positives)
    cases = [
        ("fbgemm_cpu_half", fbgemm_hbc_cpu, input_data_cpu.half()),
        ("fbgemm_cpu_float", fbgemm_hbc_cpu, input_data_cpu.float()),
    ]
    if torch.cuda.is_available():
        fbgemm_hbc_gpu = make_hbc(
            bin_num_examples.cuda(), bin_num_positives.cuda()
        )
        input_data_gpu = input_data_cpu.cuda()
        cases += [
            ("fbgemm_gpu_half", fbgemm_hbc_gpu, input_data_gpu.half()),
            ("fbgemm_gpu_float", fbgemm_hbc_gpu, input_data_gpu.float()),
        ]

    # Only benchmarks that run get an entry, so skipped GPU cases are not
    # reported as a misleading 0us.
    total_time = {name: 0.0 for name, _, _ in cases}

    for step in range(iters + warmup_runs):
        for name, func, data in cases:
            elapsed, _ = benchmark_fbgemm_function(func, data)
            # Warm-up steps run the kernel but are not counted.
            if step >= warmup_runs:
                total_time[name] += elapsed

    for name, t_time in total_time.items():
        logging.info(
            f"{name} time per iter: {t_time / iters * 1.0e6:.0f}us"
        )
| 113 | + |
# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments