- 
                Notifications
    You must be signed in to change notification settings 
- Fork 50
Optimize custom all reduce #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
While there, add missing finalize.
Increase sampling area to capture crossover.
| Looks good. Just need to fix the linters | 
| Hey Folks..wanted to make a quick comment. This CAR will require rocm 6.2 to compile.  the scoping intrinsics are introduced in LLVM in the 6.2 release. So we may see compile errors on  | 
| } | ||
| __syncthreads(); | ||
| // use one thread to update flag | ||
| if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this line is removed, who will update self_sg->_flag?
| } | ||
| __syncthreads(); | ||
| // use one thread to update flag | ||
| if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this line is removed, who will update self_sg->_flag?
| if (threadIdx.x < ngpus) { | ||
| // reset flag for next time | ||
| __scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0, | ||
| __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Original implementation using resetting flag which is prone to race condition. Thus we have seen occasional hang during long running workload. self_sg->_flag was introduced to make the flag incrementing. I would prefer keep this new mechanism for stability.
| // run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); | ||
| // } | ||
| // } | ||
| #ifdef USE _ROCM | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand this is fixed in #137, but this should easily not pass review. Let's spend more time ensuring code quality and that tests pass before merging.
| 
 We should not push out releases where the default settings (ROCm 6.1) do not compile. Again, I understand this is already fixed by #137 but hotfixes should be kept to a minimum. Especially when this issue is so readily detectable. | 
* Per @iotamudelta suggestion until the deadlocks issue is better understood Revert "Make CAR ROCm 6.1 compatible. (#137)" This reverts commit 4d2dda6. * Per @iotamudelta suggestion until the deadlocks issue is better understood Revert "Optimize custom all reduce (#130)" This reverts commit 636ff01.
* First version * Revert error. While there, add missing finalize. * Use the correct defaults for ROCm. Increase sampling area to capture crossover. * Scope end_sync as well. * Guard only volatile keyword for ifndef USE_ROCM * Document crossover
* First version * Revert error. While there, add missing finalize. * Use the correct defaults for ROCm. Increase sampling area to capture crossover. * Scope end_sync as well. * Guard only volatile keyword for ifndef USE_ROCM * Document crossover
* First version * Revert error. While there, add missing finalize. * Use the correct defaults for ROCm. Increase sampling area to capture crossover. * Scope end_sync as well. * Guard only volatile keyword for ifndef USE_ROCM * Document crossover
Uh oh!
There was an error while loading. Please reload this page.