2 changes: 1 addition & 1 deletion gptqmodel/models/loader.py
@@ -435,7 +435,7 @@ def skip(*args, **kwargs):
elif is_flash_attn_2_available() and not has_attn_implementation:
    args = {USE_FLASH_ATTENTION_2: True}

-   logger.info("Auto enabling flash attention2")
+   logger.info("Optimize: Auto enabling flash attention2")

model = cls.loader.from_config(
    config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, **args
3 changes: 3 additions & 0 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -26,7 +26,9 @@
from gptqmodel.utils.backend import BACKEND

from ...models._const import DEVICE, PLATFORM
+from ...utils.logger import setup_logger

+logger = setup_logger()

class BaseQuantLinear(nn.Module):
    SUPPORTS_BITS: List[int] = None
@@ -344,6 +346,7 @@ def validate_device(cls, device: DEVICE):
    # override me, to perform any torch.compile logic on the kernel pre forward
    def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False):
        self.optimized = True
+       logger.info.once(f"Optimize: `{self.__class__.__name__}` compilation triggered.")
        pass

class PackableQuantLinear(BaseQuantLinear):
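The new `logger.info.once(...)` call relies on the package logger exposing a deduplicating `once` helper on each level. That helper lives in `gptqmodel/utils/logger.py` (`CustomLogger`), which this diff does not show; the sketch below only illustrates the pattern, and the `OnceLevel` wrapper is a hypothetical stand-in, not the project's actual implementation.

```python
import logging

class OnceLevel:
    """Hypothetical stand-in: wraps a logging method and adds a .once() variant
    that emits a given message only the first time it is seen."""
    def __init__(self, log_fn):
        self._log_fn = log_fn
        self._seen = set()

    def __call__(self, msg, *args, **kwargs):
        # plain call behaves like the wrapped method
        self._log_fn(msg, *args, **kwargs)

    def once(self, msg, *args, **kwargs):
        # deduplicate by message text so repeated kernel inits log only once
        if msg not in self._seen:
            self._seen.add(msg)
            self._log_fn(msg, *args, **kwargs)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gptqmodel")
logger.info = OnceLevel(logger.info)

for _ in range(3):
    # "MarlinQuantLinear" stands in for self.__class__.__name__; printed once
    logger.info.once("Optimize: `MarlinQuantLinear` compilation triggered.")
```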
12 changes: 12 additions & 0 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -328,6 +328,18 @@ def __init__(
        if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None:
            self.is_lm_head = kwargs["name"] == kwargs["lm_head_name"]

+       # auto-optimize on post init
+       # self.optimize()
+
+   # def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False):
+   #     if self.optimized:
+   #         return
+   #
+   #     # compile dequantize
+   #     self.forward = torch_compile(self.forward, backend=backend, mode=mode, fullgraph=fullgraph)
+   #
+   #     super().optimize()
+
    @classmethod
    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
        if marlin_import_exception is not None:
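The commented-out block keeps a `torch.compile`-based `optimize` override around for later. Assuming `torch_compile` is a thin wrapper over `torch.compile` (its body is not part of this diff), enabling it would amount to something like this sketch:

```python
import torch

class TinyLinear(torch.nn.Module):
    """Toy module standing in for a quantized linear kernel."""
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 8))

    def forward(self, x):
        return x @ self.weight

module = TinyLinear()
# Rebind forward to its compiled version; the same backend/mode/fullgraph
# parameters appear in the commented-out Marlin override above.
module.forward = torch.compile(module.forward, backend="inductor", mode=None, fullgraph=False)
out = module(torch.randn(2, 8))  # first call triggers compilation
```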
6 changes: 4 additions & 2 deletions gptqmodel/utils/logger.py
@@ -159,10 +159,12 @@ def _process(self, level: LEVEL, msg, *args, **kwargs):
        else:
            last_pb_instance = None


original_logger_cls = logging.getLoggerClass()
logging.setLoggerClass(CustomLogger)

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("gptqmodel")
logging.setLoggerClass(original_logger_cls)

logger.propagate = False
logger.setLevel(logging.DEBUG)

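Switching from `logging.getLogger(__name__)` to `logging.getLogger("gptqmodel")` gives the whole package one fixed, shareable logger name instead of a name tied to this module's import path. A minimal sketch of the effect, assuming `setup_logger()` simply returns that named logger (its real body is not shown in this diff):

```python
import logging

def setup_logger() -> logging.Logger:
    # Assumed behaviour: every caller gets the single "gptqmodel" logger,
    # configured exactly once.
    log = logging.getLogger("gptqmodel")
    if not log.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
        log.addHandler(handler)
        log.setLevel(logging.DEBUG)
        log.propagate = False
    return log

a = setup_logger()
b = setup_logger()
assert a is b  # one shared logger, regardless of which module asked for it
```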
5 changes: 4 additions & 1 deletion gptqmodel/utils/progress.py
@@ -169,11 +169,14 @@ def _comparable(self):
    def __hash__(self):
        return id(self)

+   def iter(self):
+       self.current_iteration += 1
+
    def __iter__(self):
        iterable = self.iterable

        for obj in iterable:
-           self.current_iteration+=1
+           self.iter()
            self.progress()
            yield obj

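Moving the counter bump into a dedicated `iter()` method keeps `__iter__` and any manual drivers of the bar on one code path. A reduced sketch of the idea (field and method names are taken from the diff; the rendering inside `progress()` is made up for the example):

```python
class ProgressBar:
    """Reduced model of the counter refactor, not the full gptqmodel class."""
    def __init__(self, iterable):
        self.iterable = iterable
        self.current_iteration = 0

    def iter(self):
        # single place where the position advances
        self.current_iteration += 1

    def progress(self):
        print(f"{self.current_iteration}/{len(self.iterable)}")

    def __iter__(self):
        for obj in self.iterable:
            self.iter()
            self.progress()
            yield obj

for _ in ProgressBar(range(3)):
    pass  # prints 1/3, 2/3, 3/3
```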
3 changes: 2 additions & 1 deletion tests/inference_speed.py
@@ -70,7 +70,8 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True,
        if warmup_runs > 0:
            pb = ProgressBar(range(warmup_runs))
            for i in pb:
-               pb.info(f"warmup run index {i} of {self.NUM_RUNS - 1}")
+               pb.info(f"warmup run index {i} of {warmup_runs - 1}")
+               pb.progress()
                start_time = time.time()
                result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOEKNS, pad_token_id=tokenizer.pad_token_id)
                end_time = time.time()