diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 50fb8e70bd47..bd818ef1d01a 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -1,3 +1,4 @@ +import os import random from pathlib import Path @@ -13,6 +14,8 @@ rtol = 1e-5 atol = 1 +CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37")) + class Gemm: @@ -24,7 +27,7 @@ def __init__(self, m, n, k, indtype, outdtype, rocblas_decode=False): self.outdtype = outdtype self.use_rocblas = (indtype == outdtype and indtype is not torch.float8_e4m3fnuz) - self.nb = 37 + self.nb = CACHE_INVALIDATE_BUFFERS self.inp = torch.randn((self.n, self.k), device='cuda').to(self.indtype) self.weights = torch.randn((self.m, self.k), @@ -283,6 +286,9 @@ def find_best_sols(self): soldf.loc[i, 'libtype'] = gemmobj.best_libtype soldf.loc[i, 'solidx'] = gemmobj.best_solidx soldf.loc[i, 'soltimems'] = gemmobj.best_soltime + del gemmobj + torch.cuda.empty_cache() + soldf['indtype'] = self.indtype soldf['outdtype'] = self.outdtype finaldf = pd.concat([self.gemm_problems, soldf], axis=1)