141 changes: 141 additions & 0 deletions tmp/README.md
@@ -0,0 +1,141 @@
## 1. Summary
This PR is a basic POC of how to reduce Triton kernel launch overhead. The previous description can be found at https://github.com/liboyue/triton/tree/low-latency-jit-function/tmp/README.md

Previously I thought Python was the bottleneck, but new experiments show that a pure Python frontend can still be very efficient. I also noticed the `triton` `main` branch has a launch overhead regression.

## 2. Sources of launching overhead
To be fair, the Triton runtime does not incur particularly bad overhead. However, when inputs are small, e.g. when inferencing with a KV cache, a kernel can finish before the next kernel is launched. What's worse, the overhead is proportional to the number of arguments a kernel has (see benchmarks below).

### 2.1. `JITFunction.run`

One source of overhead is the `JITFunction.run` method. It creates a lot of Python containers and makes some very expensive function calls:
https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L401-L402

https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L416-L419

What's worse, the overhead is proportional to the number of parameters of the kernel.
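
To make this cost concrete, here is a small stand-alone micro-benchmark (not Triton code; `kernel_like` and the argument values are made up for illustration) comparing `inspect.Signature.bind()` with letting CPython bind the same arguments itself:

```
# Illustrative only: how expensive signature binding is compared to a plain call.
import inspect
import timeit

def kernel_like(a, b, c, d, e, NUM_BLOCKS=10):
    pass

sig = inspect.signature(kernel_like)
args = (1, 2.0, 3, "d", "e")

def via_bind():
    bound = sig.bind(*args)   # builds a BoundArguments object in Python
    bound.apply_defaults()    # another pass to fill in default values
    return bound.arguments

def via_plain_call():
    return kernel_like(*args)  # CPython binds parameters in C, far cheaper

print("bind():", timeit.timeit(via_bind, number=100_000))
print("plain :", timeit.timeit(via_plain_call, number=100_000))
```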

### 2.2. `CompiledKernel.runner`
I suspect the other major source of overhead comes from calling `CompiledKernel`.

https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/compiler/compiler.py#L352-L359



## 3. Benchmarks
The benchmarks measure run time for three `JITFunction.run` implementations:
- default: `triton`'s implementation.
- python: an optimized python implementation.
- cpp: an optimized C++ implementation (not updated because Python is good enough).

Two additional functions are also measured:
- noop: an empty op, to measure the resolution of the benchmarks.
- kernel: the actual CUDA kernel with its C wrapper. This is as close as possible to running bare kernels.

Two running modes are measured:
- warmup: calling kernels with `kernel[grid](..., warmup=True)`, which measures only the Python overhead in the `JITFunction.run()` function (a sketch of this measurement follows the list).
- (empty): calling kernels normally.
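
As a rough illustration of the measurement setup (a sketch only, not the actual benchmark script; `add_kernel`, the block size, and the iteration count are made up for the example):

```
# Sketch: time many warmup=True launches so that only the Python-side work
# in JITFunction.run() is measured, not the GPU kernel itself.
import time
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def launch_overhead_us(n=1024, iters=1000):
    x = torch.randn(n, device="cuda")
    y = torch.randn(n, device="cuda")
    out = torch.empty_like(x)
    grid = (triton.cdiv(n, 256),)
    add_kernel[grid](x, y, out, n, BLOCK=256)  # compile once up front
    t0 = time.perf_counter()
    for _ in range(iters):
        add_kernel[grid](x, y, out, n, BLOCK=256, warmup=True)  # Python overhead only
    return (time.perf_counter() - t0) / iters * 1e6

print(f"~{launch_overhead_us():.1f} us per launch (warmup mode)")
```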

Environment:
- CPU: Xeon 6154 @ 3.00GHz
- GPU: RTX 2080 SUPER
- CUDA: 12.1
- PyTorch: 2.2.2

The figures show run time for different input lengths.

**NOTE: the comparison is not fair, because I set the `device_type` to `"cuda"` manually, so my code is definitely much faster.**
### 3.1. Triton 2.0

Kernel launch overhead in µs (data from 33 runs)

| | default_short_warmup | default_long_warmup | python_short_warmup | python_long_warmup | cpp_short_warmup | cpp_long_warmup |
|:------|-----------------------:|----------------------:|----------------------:|---------------------:|-------------------:|------------------:|
| mean | 38.5711 | 61.3383 | 7.01908 | 10.0304 | 7.45927 | 10.2434 |
| std | 0.117177 | 0.636091 | 0.17726 | 0.0643082 | 0.0376316 | 0.0715353 |

Kernel run time vs. input length
![kernel_time_triton_2](https://github.com/openai/triton/assets/5857249/cc4b57aa-de0a-4bd4-9ee4-626beec3f93c)
For some reason the Python implementation is faster than the C++ implementation, but I didn't have time to figure out why -- Python is fast enough.

The total overhead is now around 15 µs: `JITFunction.run()` costs around 7 µs, and `CompiledKernel.runner` probably costs another 7 µs.

### 3.2. Triton 3.0

Kernel launch overhead in µs (data from 33 runs)

| | default_short_warmup | default_long_warmup | python_short_warmup | python_long_warmup |
|:------|-----------------------:|----------------------:|----------------------:|---------------------:|
| mean | 101.073 | 158.28 | 98.1783 | 154.857 |
| std | 0.596246 | 1.00237 | 0.480793 | 1.11324 |

Kernel run time vs. input length
![kernel_time_triton_3](https://github.com/openai/triton/assets/5857249/73c53ca2-8aeb-431d-93f1-d7b260dc2a64)
The launch overhead regression is clearly visible here.

## 4. Proposed solutions
I believe the kernel launching overhead can be further reduced with the following simple optimizations.

### 4.1. Stronger assumptions on devices

It is very expensive to figure out what the `device_type` should be.

For example, I guess it is OK for `triton` to assume no one will have NVIDIA and AMD GPUs on the same machine. Then the `device_type` can be cached at `triton`'s initialization time: if there is an NVIDIA GPU, use `"cuda"`, else ... (I don't know which other types are supported). Although this is not a future-proof solution, I believe it is reasonable to make some strong assumptions for now.
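
A minimal sketch of what this could look like, assuming PyTorch is available for the probe; the function name and the returned strings other than `"cuda"` are illustrative, not Triton's actual API:

```
# Sketch: resolve the device type once at initialization time instead of on
# every kernel launch. Assumes a machine has either NVIDIA or AMD GPUs, not both.
import functools
import torch

@functools.lru_cache(maxsize=None)
def cached_device_type() -> str:
    if torch.cuda.is_available() and torch.version.hip is not None:
        return "hip"   # illustrative name for the AMD backend
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

# Every call after the first is just a cache lookup, not a driver query.
device_type = cached_device_type()
```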

### 4.2. Dynamically generate `run()` and `runner()`
It is very expensive to call `signature.bind()` and pack call args. Generating these functions at `jit` time eliminates these expensive calls, which saves a good amount of time.

For example, define a kernel as
```
@triton.jit
def kernel(
    a,
    b: float,
    c: tl.int,
    d: tl.tensor,
    e: tl.tensor[tl.float32],
    NUM_BLOCKS: tl.constexpr[tl.int] = 10
):
    ...
```

The generated `run()` function's signature can be as follows (type hints are useless here, so they are omitted):
```
def run(
    self,
    a,
    b,
    c,
    d,
    e,
    NUM_BLOCKS=10,
    *,
    grid=None,
    num_warps=None,
    num_ctas=1,
    num_stages=None,
    enable_warp_specialization=False,
    enable_fp_fusion=True,
    extern_libs=None,
    stream=None,
    warmup=False,
    device=None,
    device_type=None
):
    assert type(b) == float
    assert torch.is_tensor(d)
    assert torch.is_tensor(e)
    assert e.dtype == torch.float32
```

In this way, Python itself parses the parameters and sets default values, which is much faster than `signature.bind()`.

Furthermore, `sig_key`, `constexpr_key`, `spec_key`, etc., can all be written explicitly as tuples of `run()`'s arguments. `c_wrapper`'s args can also be "hard-coded" in the same way. A rough sketch of this code generation follows.
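
The sketch below assumes a hypothetical `make_run` helper invoked at `@jit` time; it only shows how the specialized `run()` source could be generated and `exec`'d, with compilation and launching omitted:

```
# Sketch: generate a specialized run() from the kernel's signature so that
# CPython's own parameter binding replaces signature.bind() at call time.
import inspect
import textwrap

def make_run(fn):
    params = list(inspect.signature(fn).parameters.values())
    # Positional parameters mirror the kernel exactly, defaults included.
    arg_list = ", ".join(
        p.name if p.default is inspect.Parameter.empty else f"{p.name}={p.default!r}"
        for p in params
    )
    names = ", ".join(p.name for p in params)
    src = textwrap.dedent(f"""
        def run(self, {arg_list}, *, grid=None, warmup=False, stream=None):
            # sig_key / spec_key etc. become explicit tuples over the named arguments
            sig_key = ({names},)
            ...  # looking up / compiling the kernel for this key and launching it is omitted
        """)
    namespace = {}
    exec(src, namespace)  # compiled once, at decoration time
    return namespace["run"]

def toy_kernel(a, b, c, NUM_BLOCKS=10):
    pass

run = make_run(toy_kernel)  # later calls bind arguments via CPython itself
```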

### 4.3. Improving type hints

With improved type hints, kernel definitions are more informative, so the generated `run()` functions can rely less on the Python runtime and can even perform type checks. This reduces overhead and provides some additional safety (maybe?).

(see previous subsection for the example)
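
As a sketch of how such hints could drive the generated checks (the annotation spellings are the proposed ones from the example above, and the mapping table is an assumption for illustration):

```
# Sketch: map proposed annotation strings to assert statements that get pasted
# into the generated run() body. Not an existing Triton API.
CHECKS = {
    "float":                 "assert type({name}) is float",
    "tl.tensor":             "assert torch.is_tensor({name})",
    "tl.tensor[tl.float32]": "assert torch.is_tensor({name}) and {name}.dtype == torch.float32",
}

def checks_for(param_name: str, annotation: str) -> str:
    template = CHECKS.get(annotation, "")
    return template.format(name=param_name)

# checks_for("e", "tl.tensor[tl.float32]")
# -> "assert torch.is_tensor(e) and e.dtype == torch.float32"
```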

12 changes: 12 additions & 0 deletions tmp/data/kernel_overhead_triton_2.csv
@@ -0,0 +1,12 @@
m,default_short_warmup,default_long_warmup,python_generated_short_warmup,python_generated_long_warmup,python_generated_with_type_short_warmup,python_generated_with_type_long_warmup,python_short_warmup,python_long_warmup
128,40.61387939453125,63.137384033203126,4.265983963012696,6.051993560791016,3.542204666137695,3.8160350799560545,14.097975158691407,24.697056579589844
256,40.33582458496094,62.78548583984375,4.227916717529297,6.16087989807129,3.5885665893554686,3.8557502746582033,13.756002807617188,24.189677429199218
384,40.42800598144531,62.96868286132813,4.2608798980712885,6.093471908569335,3.5788864135742187,3.8517887115478517,13.77624053955078,24.302566528320312
512,40.319558715820314,62.63114013671875,4.218121719360352,6.069235229492188,3.5364864349365237,3.8576126098632812,13.743682861328125,24.280221557617185
640,40.34354858398437,62.828546142578126,4.210825729370117,6.043830490112305,3.5340225219726564,3.7785503387451174,13.935618591308595,24.303897094726562
768,40.27259521484375,62.70989379882813,4.242265701293945,6.065497589111327,3.5292800903320316,3.849836730957031,13.805737304687499,24.286805725097658
896,40.40821533203125,63.01042480468751,4.191846466064454,6.110412979125976,3.5674304962158203,3.8795265197753905,13.795376586914061,24.30095062255859
1024,40.24483947753906,62.84979248046875,4.285379028320313,6.076883316040039,3.5235008239746093,3.8618785858154294,13.817843627929687,24.261372375488282
1152,40.29364929199219,62.60881958007812,4.243644714355469,6.028982543945313,3.540726470947266,3.8547393798828127,13.789152526855467,24.203875732421878
1280,40.27312316894531,62.62414550781251,4.172364807128906,6.022323226928711,3.52608642578125,3.786547088623047,13.714617919921874,24.344767761230468
1408,40.19969177246094,62.551519775390624,4.190035247802734,6.114352035522461,3.5422145843505857,3.8580223083496095,13.654296875,24.32489929199219
16 changes: 16 additions & 0 deletions tmp/data/kernel_overhead_triton_3.csv
@@ -0,0 +1,16 @@
m,default_short_warmup,default_long_warmup,old_short_warmup,old_long_warmup
128,13.172122192382812,16.26031036376953,99.68270874023438,157.00479736328123
256,13.035801696777344,15.866677856445314,99.156787109375,155.09566650390624
384,12.93981475830078,15.805488586425783,99.23234863281249,155.44112548828127
512,12.936399841308594,15.860694885253906,99.5135986328125,156.2333984375
640,12.829327392578124,15.808515930175782,98.91352539062501,155.63693847656248
768,12.976307678222657,15.908677673339845,98.74061889648438,155.3266357421875
896,12.938278198242188,15.859251403808592,98.16371459960938,156.2601318359375
1024,12.947705078125,15.82288055419922,98.5917236328125,154.57935791015626
1152,12.923493957519533,15.821667480468749,98.53797607421876,156.34381103515625
1280,13.184144592285156,15.847593688964844,99.61932373046875,155.9857421875
1408,13.018521118164061,15.985638427734376,99.23565673828125,156.74732666015626
1536,12.88250274658203,15.921778869628904,99.33762817382814,155.8266357421875
1664,12.968377685546875,16.023724365234372,99.28067626953126,155.62784423828126
1792,13.044995117187499,15.856655883789063,99.46033935546875,156.589599609375
1920,12.909158325195312,15.746665954589844,99.31958618164062,156.37071533203127
12 changes: 12 additions & 0 deletions tmp/data/kernel_time_triton_2.csv
@@ -0,0 +1,12 @@
m,default_short_warmup,default_long_warmup,default_short,default_long,python_generated_short_warmup,python_generated_long_warmup,python_generated_short,python_generated_long,python_generated_with_type_short_warmup,python_generated_with_type_long_warmup,python_generated_with_type_short,python_generated_with_type_long,python_short_warmup,python_long_warmup,python_short,python_long,noop,kernel
128,40.61387939453125,63.137384033203126,54.927435302734374,77.38821411132812,4.265983963012696,6.051993560791016,10.115277099609376,12.132486724853516,3.542204666137695,3.8160350799560545,9.327410888671876,10.027362823486328,14.097975158691407,24.697056579589844,24.70702362060547,36.27850341796875,0.05698239803314209,3.1782463073730467
256,40.33582458496094,62.78548583984375,54.483917236328125,77.15471191406249,4.227916717529297,6.16087989807129,9.912300872802735,11.982947540283202,3.5885665893554686,3.8557502746582033,9.161325073242189,9.78163833618164,13.756002807617188,24.189677429199218,24.968605041503906,35.854541015624996,0.058771198987960814,3.21673583984375
384,40.42800598144531,62.96868286132813,54.4669677734375,77.26203002929688,4.2608798980712885,6.093471908569335,9.960066986083984,12.073372650146483,3.5788864135742187,3.8517887115478517,9.253887939453126,9.95214385986328,13.77624053955078,24.302566528320312,24.748480224609377,35.99032287597656,0.057344001531600956,3.201350402832031
512,40.319558715820314,62.63114013671875,54.4920166015625,77.23946533203126,4.218121719360352,6.069235229492188,9.974784088134765,12.000460815429689,3.5364864349365237,3.8576126098632812,9.256352233886718,9.796607971191406,13.743682861328125,24.280221557617185,24.514111328125,35.96059875488281,0.0572704017162323,3.9306816101074222
640,40.34354858398437,62.828546142578126,54.381927490234375,77.16312866210937,4.210825729370117,6.043830490112305,10.794041442871093,12.035718536376953,3.5340225219726564,3.7785503387451174,10.69161605834961,10.705635070800781,13.935618591308595,24.303897094726562,24.250778198242188,35.83456726074219,0.057481598854064946,10.695871734619141
768,40.27259521484375,62.70989379882813,54.606054687500006,77.13404541015626,4.242265701293945,6.065497589111327,14.76259765625,14.332141113281251,3.5292800903320316,3.849836730957031,14.69055633544922,14.325389099121095,13.805737304687499,24.286805725097658,24.941328430175783,35.80436401367187,0.05707839727401733,14.695382690429687
896,40.40821533203125,63.01042480468751,54.44873657226563,76.72427978515626,4.191846466064454,6.110412979125976,17.730262756347656,17.71026611328125,3.5674304962158203,3.8795265197753905,17.688175964355466,17.687843322753906,13.795376586914061,24.30095062255859,24.684214782714843,36.17515258789063,0.058447998762130735,17.68880310058594
1024,40.24483947753906,62.84979248046875,54.23981323242188,76.81536865234375,4.285379028320313,6.076883316040039,20.30786895751953,20.221775817871094,3.5235008239746093,3.8618785858154294,20.27622375488281,20.217868041992187,13.817843627929687,24.261372375488282,24.50240020751953,36.02930603027344,0.057158398628234866,20.272335815429688
1152,40.29364929199219,62.60881958007812,54.22767944335938,76.7946044921875,4.243644714355469,6.028982543945313,22.737142944335936,22.74895324707031,3.540726470947266,3.8547393798828127,22.740963745117185,22.747526550292967,13.789152526855467,24.203875732421878,24.639546203613282,36.1968017578125,0.05793920159339905,22.74132843017578
1280,40.27312316894531,62.62414550781251,54.20135498046875,76.72073974609376,4.172364807128906,6.022323226928711,25.10726776123047,25.15968017578125,3.52608642578125,3.786547088623047,25.109837341308594,25.16084136962891,13.714617919921874,24.344767761230468,25.112130737304685,35.54142761230469,0.056934398412704465,25.107891845703126
1408,40.19969177246094,62.551519775390624,54.347991943359375,76.9533935546875,4.190035247802734,6.114352035522461,27.56083984375,27.57908935546875,3.5422145843505857,3.8580223083496095,27.555944824218752,27.582470703125,13.654296875,24.32489929199219,27.557598876953122,35.71769714355469,0.05751680135726929,27.554611206054688
16 changes: 16 additions & 0 deletions tmp/data/kernel_time_triton_3.csv
@@ -0,0 +1,16 @@
m,default_short_warmup,default_long_warmup,default_short,default_long,old_short_warmup,old_long_warmup,old_short,old_long,noop,kernel
128,13.172122192382812,16.26031036376953,21.247415161132814,24.403453063964843,99.68270874023438,157.00479736328123,113.9671630859375,172.33173828125,0.0573248028755188,3.508854293823242
256,13.035801696777344,15.866677856445314,21.176025390625,24.40224304199219,99.156787109375,155.09566650390624,112.9134033203125,172.02591552734376,0.05734720230102539,3.4825889587402346
384,12.93981475830078,15.805488586425783,21.049932861328124,24.30316162109375,99.23234863281249,155.44112548828127,112.85330810546874,171.34481201171874,0.05726079940795898,3.515891265869141
512,12.936399841308594,15.860694885253906,21.131674194335936,24.26081237792969,99.5135986328125,156.2333984375,111.773486328125,172.2840087890625,0.05712000131607056,3.947305679321289
640,12.829327392578124,15.808515930175782,21.12370910644531,24.38298187255859,98.91352539062501,155.63693847656248,112.80179443359376,171.37191162109374,0.05823680162429809,10.794204711914062
768,12.976307678222657,15.908677673339845,21.187242126464845,24.334713745117188,98.74061889648438,155.3266357421875,112.91352539062501,171.557666015625,0.05753600001335144,14.792601013183594
896,12.938278198242188,15.859251403808592,21.020480346679687,24.331706237792968,98.16371459960938,156.2601318359375,112.18370361328125,171.243310546875,0.057385599613189696,17.704345703125
1024,12.947705078125,15.82288055419922,21.352742004394532,24.318768310546876,98.5917236328125,154.57935791015626,111.21656494140625,171.61931152343752,0.057548797130584715,20.324163818359377
1152,12.923493957519533,15.821667480468749,22.77170257568359,24.31747131347656,98.53797607421876,156.34381103515625,113.10631103515625,172.35596923828126,0.05775679945945739,22.779043579101565
1280,13.184144592285156,15.847593688964844,25.121151733398438,25.190867614746093,99.61932373046875,155.9857421875,114.1717041015625,172.7505859375,0.0579584002494812,25.10643157958984
1408,13.018521118164061,15.985638427734376,27.587356567382812,27.610552978515628,99.23565673828125,156.74732666015626,113.649462890625,174.3118408203125,0.05784959793090821,27.536587524414063
1536,12.88250274658203,15.921778869628904,30.04534912109375,30.07674255371094,99.33762817382814,155.8266357421875,113.1193359375,172.77442626953123,0.057555198669433594,30.0052490234375
1664,12.968377685546875,16.023724365234372,32.46521606445313,32.48762817382812,99.28067626953126,155.62784423828126,112.13239746093751,172.67386474609376,0.057417601346969604,32.404275512695314
1792,13.044995117187499,15.856655883789063,34.92947082519531,34.96343994140625,99.46033935546875,156.589599609375,113.07142333984376,172.374169921875,0.057344001531600956,34.86159362792969
1920,12.909158325195312,15.746665954589844,37.38260192871093,37.40835876464844,99.31958618164062,156.37071533203127,112.30145263671875,172.97775878906248,0.057625597715377806,37.27654113769531
Binary file added tmp/figures/kernel_time_triton_2.png
Binary file added tmp/figures/kernel_time_triton_3.png