Skip to content

Commit d45dfea

Browse files
maleksan85maleksan85
authored andcommitted
[FIX] Gradlib OOM on Navi and sometimes on MI (#124)
* add memory clean up after every shape and parameter to reduce cache invalidation buffers * small typo * syntax change --------- Co-authored-by: maleksan85 <[email protected]>
1 parent d26ee9b commit d45dfea

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

gradlib/gradlib/GemmTuner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import random
23
from pathlib import Path
34

@@ -13,6 +14,8 @@
1314
rtol = 1e-5
1415
atol = 1
1516

17+
CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37"))
18+
1619

1720
class Gemm:
1821

@@ -24,7 +27,7 @@ def __init__(self, m, n, k, indtype, outdtype, rocblas_decode=False):
2427
self.outdtype = outdtype
2528
self.use_rocblas = (indtype == outdtype
2629
and indtype is not torch.float8_e4m3fnuz)
27-
self.nb = 37
30+
self.nb = CACHE_INVALIDATE_BUFFERS
2831
self.inp = torch.randn((self.n, self.k),
2932
device='cuda').to(self.indtype)
3033
self.weights = torch.randn((self.m, self.k),
@@ -283,6 +286,9 @@ def find_best_sols(self):
283286
soldf.loc[i, 'libtype'] = gemmobj.best_libtype
284287
soldf.loc[i, 'solidx'] = gemmobj.best_solidx
285288
soldf.loc[i, 'soltimems'] = gemmobj.best_soltime
289+
del gemmobj
290+
torch.cuda.empty_cache()
291+
286292
soldf['indtype'] = self.indtype
287293
soldf['outdtype'] = self.outdtype
288294
finaldf = pd.concat([self.gemm_problems, soldf], axis=1)

0 commit comments

Comments
 (0)