116 changes: 116 additions & 0 deletions tests/examples/test_load_latency.py
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import pytest
import torch
import triton
import triton.language as tl
import numpy as np
import iris
from iris._mpi_helpers import mpi_allgather
from examples.common.utils import read_realtime


@triton.jit()
def ping_pong(
    data,
    n_elements,
    skip,
    niter,
    flag,
    curr_rank,
    peer_rank,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    data_mask = offsets < n_elements
    flag_mask = offsets < 1
    time_stmp_mask = offsets < 1

    for i in range(niter + skip):
        # Start timing once the warm-up (skip) iterations are done.
        if i == skip:
            start = read_realtime()
            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
        # The two ranks alternate which one sends first on each iteration.
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            # Send the payload, signal the peer, then spin on the local flag until the peer replies.
            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                pass
        else:
            # Wait for the peer's signal, then send the payload back and acknowledge.
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                pass
            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
Collaborator:

I am curious about the results you are getting. Also, do you think the load and store API would be better here to avoid the local load and store?

Contributor Author (@astroC86, Aug 31, 2025):

I am unable to allocate 8 MI300Xs on AMD cloud at the moment. I agree with you on the load/store part 👍

This may be relevant: using nvbandwidth on an H100 system I get the following:

Running device_to_device_latency_sm.
Device to Device Latency SM GPU(row) <-> GPU(column) (ns)
           0         1         2         3
 0       N/A    549.34    545.77    545.35
 1    550.28       N/A    660.95    659.08
 2    548.35    547.22       N/A    544.26
 3    545.49    831.96    543.86       N/A

The output from the Triton code is:

      R0          R1          R2          R3
R0    0.000000    722.239990  701.119995  683.200012
R1    722.559998  0.000000    736.000000  727.679993
R2    701.440002  735.679993  0.000000    712.640015
R3    683.200012  727.679993  712.960022  0.000000

% error ((triton - nvband) * 100 / nvband):

      R0        R1        R2        R3
R0    N/A       31.47%    28.46%    25.28%
R1    31.31%    N/A       11.35%    10.41%
R2    27.92%    34.44%    N/A       30.94%
R3    25.25%    -12.53%   31.09%    N/A

Collaborator:

Very interesting. How does it compare after using the load/store?
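For concreteness, a minimal sketch of what "using the load/store" could look like inside ping_pong's loop, writing the payload straight to the peer with iris.store (the same signature already used for the flag) so no local buffer read is needed. This is a hypothetical variant for discussion, not code from this PR:

# Hypothetical replacement for the if/else exchange inside the loop; the stored
# value (here the iteration index) comes from a register rather than a local
# load of `data`, which is the point of the load/store suggestion above.
if curr_rank == first_rank:
    iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, data_mask)
    iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
    while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
        pass
else:
    while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
        pass
    iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, data_mask)
    iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)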

Collaborator:

Actually, how did that even work? read_realtime is using AMD GCN assembly. Did you change that to the equivalent PTX?

Collaborator:

Also, I just tested this on MI300X and it seems to deadlock.

Collaborator:

Yeah, this was what I initially had in mind for the microbenchmark. I am surprised you didn’t need to accumulate the result here. I remember everything was getting optimized away when we wrote similar code for the all load benchmark.

We wanted to add the cache modifiers and volatile arguments for a while but we haven’t yet. Let me think about this a bit more.

Contributor Author (@astroC86, Sep 4, 2025):

> I am surprised you didn’t need to accumulate the result here. I remember everything was getting optimized away when we wrote similar code for the all load benchmark.

Yeah, without the cache modifier and volatile it gets optimized away.

> We wanted to add the cache modifiers and volatile arguments for a while but we haven’t yet. Let me think about this a bit more.

No worries.

Collaborator:

These numbers are not in nanoseconds; they are in clock cycles, yes?
See the ISA (screenshot of the relevant ISA excerpt omitted).
We have been using this function to find the clock.
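The distinction matters for the (end - begin) / iter computation at the bottom of the file: with the GCN real-time counter the result is in counter ticks, not nanoseconds. A minimal host-side conversion sketch; the wall-clock frequency here is an assumption and should be queried for the actual device rather than hard-coded:

import torch

def ticks_to_ns(begin: torch.Tensor, end: torch.Tensor, niter: int, wall_clock_mhz: float) -> torch.Tensor:
    # One tick of an f-MHz counter lasts 1000 / f nanoseconds.
    ticks_per_iter = (end.cpu() - begin.cpu()).float() / niter
    return ticks_per_iter * (1000.0 / wall_clock_mhz)

# e.g. latency_ns = ticks_to_ns(mm_begin_timestamp, mm_end_timestamp, iter, wall_clock_mhz=100.0)
# (100 MHz is only an assumed rate; on NVIDIA, %globaltimer already reports nanoseconds.)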

Collaborator (@mawad-amd, Sep 5, 2025):

Could you also push your PTX to the CUDA port branch as well? You can just comment out the CDNA assembly over there for now.

Contributor Author:

OK, will do.
This is the PTX I have been using:

@triton.jit
def read_realtime():
    tmp = tl.inline_asm_elementwise(
        asm="mov.u64 $0, %globaltimer;",
        constraints=("=l"),
        args=[],
        dtype=tl.int64,
        is_pure=False,
        pack=1,
    )
    return tmp
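For contrast, a rough sketch of what the CDNA-side helper referenced above might look like, built around s_memrealtime (a constant-rate counter, not a nanosecond clock). The actual read_realtime in examples/common/utils.py may differ in the exact asm, waitcnt, and constraint:

@triton.jit
def read_realtime_cdna():
    # Sketch only: s_memrealtime returns the 64-bit real-time counter, which ticks
    # at a fixed rate; the s_waitcnt ensures the value is ready before use.
    tmp = tl.inline_asm_elementwise(
        asm="s_memrealtime $0\ns_waitcnt lgkmcnt(0)",
        constraints=("=s"),
        args=[],
        dtype=tl.int64,
        is_pure=False,
        pack=1,
    )
    return tmp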

            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)

    stop = read_realtime()
    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)


if __name__ == "__main__":
    dtype = torch.int32
    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    BLOCK_SIZE = 1
    BUFFER_LEN = 1

    iter = 100
    skip = 10
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
                ping_pong[grid](
                    source_buffer,
                    BUFFER_LEN,
                    skip,
                    iter,
                    flag,
                    cur_rank,
                    peer_for_me,
                    BLOCK_SIZE,
                    heap_bases,
                    mm_begin_timestamp,
                    mm_end_timestamp,
                )
            shmem.barrier()

    for destination_rank in range(num_ranks):
        local_latency[destination_rank] = (
            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
        ) / iter

    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
        with open("latency.txt", "w") as f:
            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
            for i in range(num_ranks):
                row_entries = []
                for j in range(num_ranks):
                    val = float(latency_matrix[i, j])
                    row_entries.append(f"{val:0.6f}")
                line = f"R{i}," + ", ".join(row_entries) + "\n"
                f.write(line)
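Not part of the PR: a small helper one could use to read latency.txt back into a per-rank matrix, for example to compare against the nvbandwidth output quoted in the thread above. It relies only on the CSV-like format written by the code:

import csv

def read_latency_matrix(path="latency.txt"):
    # Parses the " ,R0, R1, ..." header and the "R<i>,<val>, <val>, ..." rows written above.
    with open(path) as f:
        rows = list(csv.reader(f))
    header = [c.strip() for c in rows[0][1:]]
    matrix = {row[0].strip(): [float(v) for v in row[1:]] for row in rows[1:]}
    return header, matrix

# Example: header, matrix = read_latency_matrix(); matrix["R0"][1] is R0's measured latency to R1.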