Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions gptqmodel/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from __future__ import annotations

import os
from ..utils.logger import setup_logger

logger = setup_logger()

if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True'
print("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.")
logger.info("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.")

if not os.environ.get("CUDA_DEVICE_ORDER", None):
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
print("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for correctness.")
logger.info("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for correctness.")

import sys # noqa: E402

Expand All @@ -50,7 +53,6 @@
from ..quantization.gptq import CPU # noqa: E402
from ..utils import BACKEND # noqa: E402
from ..utils.eval import EVAL # noqa: E402
from ..utils.logger import setup_logger # noqa: E402
from ..utils.model import check_and_get_model_type, find_modules # noqa: E402
from ..utils.torch import torch_empty_cache # noqa: E402
from .base import BaseGPTQModel, QuantizeConfig # noqa: E402
Expand Down Expand Up @@ -109,8 +111,6 @@
random.seed(787)
numpy.random.seed(787)

logger = setup_logger()

MODEL_MAP = {
"bloom": BloomGPTQ,
"gpt_neox": GPTNeoXGPTQ,
Expand Down
47 changes: 32 additions & 15 deletions gptqmodel/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import logging
import sys
from enum import Enum
from typing import Callable

from colorlog import ColoredFormatter

from gptqmodel.utils.terminal import terminal_size

# global static/shared logger instance
logger = None
last_logging_src = 1 # one for logger, 2 for progressbar
last_pb_instance = None # one for logger, 2 for progressbar

def update_logging_src(src: int):
global last_logging_src
last_logging_src = src
def update_last_pb_instance(src) -> None:
global last_pb_instance
last_pb_instance = src

class LEVEL(str, Enum):
CRITICAL = "CRITICAL"
Expand Down Expand Up @@ -127,21 +127,38 @@ def __init__(self, name):
self.error = self.error_cls(logger=self)

def _process(self, level: LEVEL, msg, *args, **kwargs):
global last_logging_src
if last_logging_src == 2:
print(" ", flush=True)
last_logging_src = 1
from gptqmodel.utils.progress import ProgressBar # hack: circular import

columns, _ = terminal_size()
columns -= 10 # minus level and spaces
str_msg = str(msg)
columns -= len(str_msg)

global last_pb_instance
if isinstance(last_pb_instance, ProgressBar) and not last_pb_instance.closed:
buf = f'\r'
if columns > 0:
str_msg += " " * columns

print(buf,end='',flush=True)

if level == LEVEL.INFO:
self._info(msg, *args, **kwargs)
self._info(str_msg, *args, **kwargs)
elif level == LEVEL.WARN:
self._warning(msg, *args, **kwargs)
self._warning(str_msg, *args, **kwargs)
elif level == LEVEL.ERROR:
self._error(msg, *args, **kwargs)
self._error(str_msg, *args, **kwargs)
elif level == LEVEL.DEBUG:
self._debug(msg, *args, **kwargs)
self._debug(str_msg, *args, **kwargs)
elif level == LEVEL.CRITICAL:
self._critical(msg, *args, **kwargs)
self._critical(str_msg, *args, **kwargs)

if isinstance(last_pb_instance, ProgressBar):
if not last_pb_instance.closed:
last_pb_instance.progress()
else:
last_pb_instance = None


logging.setLoggerClass(CustomLogger)

Expand Down
57 changes: 7 additions & 50 deletions gptqmodel/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
# limitations under the License.

import datetime
import os
import sys
import time
from typing import Iterable
from warnings import warn

from gptqmodel.utils.logger import setup_logger, update_logging_src
from gptqmodel.utils.logger import setup_logger, update_last_pb_instance
from gptqmodel.utils.terminal import terminal_size, terminal_size

logger = setup_logger()

Expand All @@ -41,6 +40,8 @@ def __init__(self,
fill:str = '█',
info:str = ""):

self.closed = False # active state

# max info length over the life ot the pb
self.max_info_length = len(info)

Expand Down Expand Up @@ -112,7 +113,7 @@ def log(self, bar:str, log:str, padding:str = "", end: str = ""):
else:
print(f'\r{self.info_text}{padding} |{bar}| {log}', end=end, flush=True)

update_logging_src(src=2) # let logger now we logged
update_last_pb_instance(src=self) # let logger now we logged

def __bool__(self):
if self.total is not None:
Expand Down Expand Up @@ -177,55 +178,11 @@ def __iter__(self):
yield obj

self.progress()
self.close()
return

def close(self):
pass
self.closed = True
#self.log(f"{self.fill * self.bar_length}", "100.0%", end="\n")

# copied from github.com/onsim/shutils
def terminal_size(fallback=(80, 24)):
"""Get the size of the terminal window.

For each of the two dimensions, the environment variable, COLUMNS
and LINES respectively, is checked. If the variable is defined and
the value is a positive integer, it is used.

When COLUMNS or LINES is not defined, which is the common case,
the terminal connected to sys.__stdout__ is queried
by invoking os.get_terminal_size.

If the terminal size cannot be successfully queried, either because
the system doesn't support querying, or because we are not
connected to a terminal, the value given in fallback parameter
is used. Fallback defaults to (80, 24) which is the default
size used by many terminal emulators.

The value returned is a named tuple of type os.terminal_size.
"""
# columns, lines are the working values
try:
columns = int(os.environ['COLUMNS'])
except (KeyError, ValueError):
columns = 0

try:
lines = int(os.environ['LINES'])
except (KeyError, ValueError):
lines = 0

# only query if necessary
if columns <= 0 or lines <= 0:
try:
size = os.get_terminal_size(sys.__stdout__.fileno())
except (AttributeError, ValueError, OSError):
# stdout is None, closed, detached, or not a terminal, or
# os.get_terminal_size() is unsupported
size = os.terminal_size(fallback)
if columns <= 0:
columns = size.columns or fallback[0]
if lines <= 0:
lines = size.lines or fallback[1]

return (columns, lines)

49 changes: 49 additions & 0 deletions gptqmodel/utils/terminal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# copied from github.com/onsim/shutils
import os
import sys


def terminal_size(fallback=(80, 24)):
"""Get the size of the terminal window.

For each of the two dimensions, the environment variable, COLUMNS
and LINES respectively, is checked. If the variable is defined and
the value is a positive integer, it is used.

When COLUMNS or LINES is not defined, which is the common case,
the terminal connected to sys.__stdout__ is queried
by invoking os.get_terminal_size.

If the terminal size cannot be successfully queried, either because
the system doesn't support querying, or because we are not
connected to a terminal, the value given in fallback parameter
is used. Fallback defaults to (80, 24) which is the default
size used by many terminal emulators.

The value returned is a named tuple of type os.terminal_size.
"""
# columns, lines are the working values
try:
columns = int(os.environ['COLUMNS'])
except (KeyError, ValueError):
columns = 0

try:
lines = int(os.environ['LINES'])
except (KeyError, ValueError):
lines = 0

# only query if necessary
if columns <= 0 or lines <= 0:
try:
size = os.get_terminal_size(sys.__stdout__.fileno())
except (AttributeError, ValueError, OSError):
# stdout is None, closed, detached, or not a terminal, or
# os.get_terminal_size() is unsupported
size = os.terminal_size(fallback)
if columns <= 0:
columns = size.columns or fallback[0]
if lines <= 0:
lines = size.lines or fallback[1]

return (columns, lines)
66 changes: 0 additions & 66 deletions test_prepare_dataset.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_quant_and_eora.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_quant_and_eora(self):
torch_empty_cache()

# BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
base_bench = bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)

Expand Down