-
Notifications
You must be signed in to change notification settings - Fork 23
Implements latency test #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
astroC86
wants to merge
15
commits into
ROCm:main
Choose a base branch
from
astroC86:astroC86/load-store-latency
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 7 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
5fabfef
inital latency test
astroC86 a187495
Apply Ruff auto-fixes
github-actions[bot] ad03093
initial impl of latency
astroC86 e72704a
Apply Ruff auto-fixes
github-actions[bot] f4adf5f
got rid of atomic timestamp store
astroC86 4b8bc7a
increase warmup time
astroC86 5eff862
cleanup
astroC86 606853e
addressing comments
astroC86 a3e9023
Fix deadlock
astroC86 003b273
Apply Ruff auto-fixes
github-actions[bot] f537a20
Merge branch 'main' into astroC86/load-store-latency
astroC86 56ad603
Rewrote latency test
astroC86 b301385
Fix latency time measurement
astroC86 8620fa3
Apply Ruff auto-fixes
github-actions[bot] c21cd96
Merge branch 'main' into astroC86/load-store-latency
mawad-amd File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| #!/usr/bin/env python3 | ||
| # SPDX-License-Identifier: MIT | ||
| # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| import pytest | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
| import numpy as np | ||
| import iris | ||
| from iris._mpi_helpers import mpi_allgather | ||
| from examples.common.utils import read_realtime | ||
|
|
||
|
|
||
| @triton.jit() | ||
| def ping_pong( | ||
| data, | ||
| n_elements, | ||
| skip, | ||
| niter, | ||
| flag, | ||
| curr_rank, | ||
| peer_rank, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| heap_bases: tl.tensor, | ||
| mm_begin_timestamp_ptr: tl.tensor = None, | ||
| mm_end_timestamp_ptr: tl.tensor = None, | ||
| ): | ||
| pid = tl.program_id(0) | ||
| block_start = pid * BLOCK_SIZE | ||
| offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
|
|
||
| data_mask = offsets < n_elements | ||
| flag_mask = offsets < 1 | ||
| time_stmp_mask = offsets < 1 | ||
|
|
||
| for i in range(niter + skip): | ||
| if i == skip: | ||
| start = read_realtime() | ||
| tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask) | ||
| first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank) | ||
| token_first_done = i + 1 | ||
| token_second_done = i + 2 | ||
| if curr_rank == first_rank: | ||
| iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask) | ||
| iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask) | ||
| while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done: | ||
| pass | ||
| else: | ||
| while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done: | ||
| pass | ||
| iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask) | ||
| iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask) | ||
|
|
||
| stop = read_realtime() | ||
| tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
astroC86 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| dtype = torch.int32 | ||
| heap_size = 1 << 32 | ||
| shmem = iris.iris(heap_size) | ||
| num_ranks = shmem.get_num_ranks() | ||
| heap_bases = shmem.get_heap_bases() | ||
| cur_rank = shmem.get_rank() | ||
|
|
||
| BLOCK_SIZE = 1 | ||
| BUFFER_LEN = 1 | ||
|
|
||
| iter = 100 | ||
| skip = 10 | ||
| mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda") | ||
| mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda") | ||
|
|
||
| local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda") | ||
|
|
||
| source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype) | ||
| flag = shmem.ones(1, dtype=dtype) | ||
|
|
||
| grid = lambda meta: (1,) | ||
| for source_rank in range(num_ranks): | ||
| for destination_rank in range(num_ranks): | ||
| if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]: | ||
| peer_for_me = destination_rank if cur_rank == source_rank else source_rank | ||
| ping_pong[grid]( | ||
| source_buffer, | ||
| BUFFER_LEN, | ||
| skip, | ||
| iter, | ||
| flag, | ||
| cur_rank, | ||
| peer_for_me, | ||
| BLOCK_SIZE, | ||
| heap_bases, | ||
| mm_begin_timestamp, | ||
| mm_end_timestamp, | ||
| ) | ||
| shmem.barrier() | ||
|
|
||
| for destination_rank in range(num_ranks): | ||
| local_latency[destination_rank] = ( | ||
| mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank] | ||
| ) / iter | ||
|
|
||
| latency_matrix = mpi_allgather(local_latency.cpu()) | ||
|
|
||
| if cur_rank == 0: | ||
astroC86 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| with open("latency.txt", "w") as f: | ||
| f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n") | ||
| for i in range(num_ranks): | ||
| row_entries = [] | ||
| for j in range(num_ranks): | ||
| val = float(latency_matrix[i, j]) | ||
| row_entries.append(f"{val:0.6f}") | ||
| line = f"R{i}," + ", ".join(row_entries) + "\n" | ||
| f.write(line) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am curious about the results you are getting. Also, do you think the load and store API would be better here to avoid the local load and store?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i am unable to allocate 8 mi300x's at the moment on amd cloud. I agree with you on the load store part 👍
this maybe relevent, using nvbandwidth on H100 i get the following :
the output from the trition code is:
%error ((triton - nvband)*100/nvband):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very interesting. How does it compare after using the load/store?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, how did that even work?
read_realtimeis using AMD GCN assembly. Did you change that to the equivalent PTX?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I just tested this on MI300X and it seems to deadlock.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this was what I initially had in mind for the microbenchmark. I am surprised you didn’t need to accumulate the result here. I remember everything was getting optimized away when we wrote similar code for the all load benchmark.
We wanted to add the cache modifiers and volatile arguments for a while but we haven’t yet. Let me think about this a bit more.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea without the cache modifier and volatile it gets optimized away
No worries
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These numbers are not in nanoseconds, these are in clock cycles, yes?

See ISA.
We have been using this functions to find the clock.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also push your PTX to the CUDA port branch as well. You can just comment out the CDNA assembly over there for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok will do
this is the ptx i have been using: