Skip to content

Conversation

@iotamudelta
Copy link

@iotamudelta iotamudelta commented Aug 12, 2024

  • remove volatile keywords (they have a different meaning on CUDA than HIP)
  • optimize and scope locking
  • increase cutoff for custom all reduce to 16 MB (based on perf data)
  • tune block number for MI300X
  • roughly 10% performance increase for non-latency bound sizes, cutoff vis-a-vis RCCL at 16 MB
  • while there, add missing MPI_Finalize() to test

While there, add missing finalize.
Increase sampling area to capture crossover.
@iotamudelta iotamudelta requested a review from gshtras August 12, 2024 21:38
@gshtras
Copy link
Collaborator

gshtras commented Aug 12, 2024

Looks good. Just need to fix the linters

@gshtras gshtras merged commit 636ff01 into ROCm:main Aug 14, 2024
@dllehr-amd
Copy link
Collaborator

Hey Folks..wanted to make a quick comment. This CAR will require rocm 6.2 to compile. the scoping intrinsics are introduced in LLVM in the 6.2 release. So we may see compile errors on __MEMORY_SCOPE_DEVICE etc.

}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this line is removed, who will update self_sg->_flag?

}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this line is removed, who will update self_sg->_flag?

if (threadIdx.x < ngpus) {
// reset flag for next time
__scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original implementation using resetting flag which is prone to race condition. Thus we have seen occasional hang during long running workload. self_sg->_flag was introduced to make the flag incrementing. I would prefer keep this new mechanism for stability.

// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
#ifdef USE _ROCM

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this is fixed in #137, but this should easily not pass review. Let's spend more time ensuring code quality and that tests pass before merging.

@mawong-amd
Copy link

Hey Folks..wanted to make a quick comment. This CAR will require rocm 6.2 to compile. the scoping intrinsics are introduced in LLVM in the 6.2 release. So we may see compile errors on __MEMORY_SCOPE_DEVICE etc.

We should not push out releases where the default settings (ROCm 6.1) do not compile. Again, I understand this is already fixed by #137 but hotfixes should be kept to a minimum. Especially when this issue is so readily detectable.

gshtras added a commit that referenced this pull request Aug 15, 2024
…stood

Revert "Optimize custom all reduce (#130)"

This reverts commit 636ff01.
gshtras added a commit that referenced this pull request Aug 15, 2024
* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Make CAR ROCm 6.1 compatible. (#137)"

This reverts commit 4d2dda6.

* Per @iotamudelta suggestion until the deadlocks issue is better understood
Revert "Optimize custom all reduce (#130)"

This reverts commit 636ff01.
gshtras pushed a commit that referenced this pull request Aug 22, 2024
* First version

* Revert error.

While there, add missing finalize.

* Use the correct defaults for ROCm.

Increase sampling area to capture crossover.

* Scope end_sync as well.

* Guard only volatile keyword for ifndef USE_ROCM

* Document crossover
kkHuang-amd pushed a commit that referenced this pull request Aug 23, 2024
* First version

* Revert error.

While there, add missing finalize.

* Use the correct defaults for ROCm.

Increase sampling area to capture crossover.

* Scope end_sync as well.

* Guard only volatile keyword for ifndef USE_ROCM

* Document crossover
sogalin pushed a commit to sogalin/vllm that referenced this pull request Sep 3, 2024
* First version

* Revert error.

While there, add missing finalize.

* Use the correct defaults for ROCm.

Increase sampling area to capture crossover.

* Scope end_sync as well.

* Guard only volatile keyword for ifndef USE_ROCM

* Document crossover
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants