7 files changed, +13 -0, all under gptqmodel/nn_modules/qlinear. Each quantized-linear module now passes an explicit per-module default backend to its base constructor, plus the BACKEND import where it was not already present.
File 1 of 7: the CUDA module (it imports TorchQuantLinear and defaults to BACKEND.CUDA):

@@ -18,6 +18,7 @@
 
 import torch
 
+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...adapter.adapter import Adapter, Lora
 from ...nn_modules.qlinear.torch import TorchQuantLinear
@@ -80,6 +81,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.CUDA),
             adapter=adapter,
             **kwargs)
 
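The pattern is identical in every file: the subclass pins its own default backend and forwards it to the base constructor, while still letting a caller override it. Below is a minimal, self-contained sketch of that pattern; the class names are illustrative, and BaseQuantLinear's real signature is not shown in this diff, so the one here is an assumption.

```python
from enum import Enum


class BACKEND(Enum):
    # Stand-in for gptqmodel.utils.backend.BACKEND (see the fuller
    # sketch further below).
    CUDA = "cuda"


class BaseQuantLinear:
    # Hypothetical base signature: the diff implies it accepts `backend`.
    def __init__(self, backend: BACKEND, **kwargs):
        self.backend = backend


class CudaQuantLinear(BaseQuantLinear):
    def __init__(self, **kwargs):
        # pop() removes "backend" from kwargs before **kwargs is
        # forwarded, so the keyword is never passed twice; absent a
        # caller-supplied value, this module defaults to CUDA.
        super().__init__(backend=kwargs.pop("backend", BACKEND.CUDA), **kwargs)


print(CudaQuantLinear().backend)                      # BACKEND.CUDA (default)
print(CudaQuantLinear(backend=BACKEND.CUDA).backend)  # explicit override
```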
File 2 of 7: the ExLlama v1 module (defaults to BACKEND.EXLLAMA_V1):

@@ -21,6 +21,7 @@
 
 import torch
 
+from ...utils.backend import BACKEND
 from ...adapter.adapter import Adapter, Lora
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
@@ -103,6 +104,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.EXLLAMA_V1),
             adapter=adapter,
             register_buffers=True,
             register_buffers_in_features=in_features,
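Every file except Marlin also gains `from ...utils.backend import BACKEND`. The diff only shows which members the enum must have, not their values, so this reconstruction is an assumption beyond the member names the hunks reference:

```python
from enum import Enum


class BACKEND(Enum):
    # Members referenced somewhere in this diff; the real enum may
    # define more (and use different values).
    CUDA = "cuda"
    EXLLAMA_V1 = "exllama_v1"
    EXLLAMA_V2 = "exllama_v2"
    IPEX = "ipex"
    MARLIN = "marlin"
    TORCH = "torch"
    TRITON = "triton"
```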
File 3 of 7: the ExLlama v2 module (defaults to BACKEND.EXLLAMA_V2):

@@ -20,6 +20,7 @@
 
 import torch
 
+from ...utils.backend import BACKEND
 from ...adapter.adapter import Adapter, Lora
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
@@ -176,6 +177,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.EXLLAMA_V2),
             adapter=adapter,
             register_buffers=True,
             register_buffers_in_features=in_features,
File 4 of 7: the IPEX module (defaults to BACKEND.IPEX):

@@ -18,6 +18,7 @@
 
 import torch
 
+from ...utils.backend import BACKEND
 from ...utils.logger import setup_logger
 from ...utils.torch import torch_compile
 from ...adapter.adapter import Adapter, Lora
@@ -127,6 +128,7 @@ def __init__(
             pack_dtype=pack_dtype,
             adapter=adapter,
             register_buffers=True,
+            backend=kwargs.pop("backend", BACKEND.IPEX),
             **kwargs)
 
         self.weight_dtype = torch.float16
File 5 of 7: the Marlin module (defaults to BACKEND.MARLIN):

@@ -216,6 +216,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.MARLIN),
             adapter=adapter,
             register_buffers=False,
             **kwargs)
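The Marlin hunk adds only the default, not the import, so BACKEND is presumably already in scope in that file; the totals are consistent with that: six new import lines plus seven backend defaults account for exactly the +13 additions. The choice of kwargs.pop over kwargs.get also matters, because each constructor ends by forwarding **kwargs, and a backend left inside kwargs would reach the base class twice. A small demonstration of the failure being avoided:

```python
def base_init(backend="auto", **kwargs):
    return backend


kwargs = {"backend": "marlin", "bias": True}

# kwargs.get() would leave "backend" in kwargs, so forwarding **kwargs
# raises: TypeError: base_init() got multiple values for keyword
# argument 'backend'.

# kwargs.pop() removes it first, so the call is safe:
assert base_init(backend=kwargs.pop("backend", "marlin"), **kwargs) == "marlin"
```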
File 6 of 7: the pure-Torch module (defaults to BACKEND.TORCH):

@@ -19,6 +19,7 @@
 import torch.nn as nn
 from transformers import PreTrainedModel
 
+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...utils.torch import torch_compile
 from ...adapter.adapter import Adapter, Lora
@@ -67,6 +68,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TORCH),
             adapter=adapter,
             register_buffers=True,
             **kwargs)
File 7 of 7: the Triton module (defaults to BACKEND.TRITON):

@@ -19,6 +19,7 @@
 import torch
 from packaging import version
 
+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...utils.logger import setup_logger
 from ...adapter.adapter import Adapter, Lora
@@ -95,6 +96,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TRITON),
             adapter=adapter,
             register_buffers=True,
             **kwargs)
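Seen from a call site, the change is backward compatible: code that never passes backend gets each module's historical behavior as the default, while loader code can now thread one explicit choice through every layer type. A hypothetical call, reusing the illustrative classes from the first sketch; in_features, out_features, bias, and pack_dtype are keyword names visible in the hunks above, but the real constructors take more than shown here:

```python
linear = CudaQuantLinear(
    in_features=4096,       # illustrative sizes
    out_features=4096,
    bias=False,
    backend=BACKEND.CUDA,   # explicit override; omit to use the default
)
assert linear.backend is BACKEND.CUDA
```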