From 5ffa1b66c8752aeaa4bbfc0c395bf90f54da6e35 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 18 Dec 2024 16:04:31 -0500 Subject: [PATCH] update torchao pin: optimized shaders --- install/.pins/torchao-pin.txt | 2 +- torchchat/utils/quantize.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt index c6161e78f..2da70769c 100644 --- a/install/.pins/torchao-pin.txt +++ b/install/.pins/torchao-pin.txt @@ -1 +1 @@ -7d7c14e898eca3fe66138d2a9445755a9270b800 \ No newline at end of file +2e032c6b0de960dee554dcb08126ace718b14c6d diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 6ac2410d0..b63d42d6c 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -932,6 +932,7 @@ def quantized_model(self) -> nn.Module: libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*") libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) torch.ops.load_library(libs[0]) + print("Loaded torchao cpu ops.") except Exception as e: print("Unabled to load torchao cpu ops library. Slow fallback kernels will be used.") @@ -939,6 +940,7 @@ def quantized_model(self) -> nn.Module: libname = "libtorchao_ops_mps_aten.dylib" libpath = f"{torchao_build_path}/cmake-out/lib/{libname}" torch.ops.load_library(libpath) + print("Loaded torchao mps ops.") except Exception as e: print("Unabled to load torchao mps ops library.")