Skip to content

Commit 271a1dd

Browse files
authored
refactor logger and make progress bar stick to the bottom of the CLI (#1322)
Signed-off-by: Qubitium <[email protected]>
1 parent b0fc5c7 commit 271a1dd

File tree

6 files changed

+94
-137
lines changed

6 files changed

+94
-137
lines changed

gptqmodel/models/auto.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
from __future__ import annotations
1818

1919
import os
20+
from ..utils.logger import setup_logger
21+
22+
logger = setup_logger()
2023

2124
if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None):
2225
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True'
23-
print("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.")
26+
logger.info("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.")
2427

2528
if not os.environ.get("CUDA_DEVICE_ORDER", None):
2629
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
27-
print("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for correctness.")
30+
logger.info("ENV: Auto setting CUDA_DEVICE_ORDER=PCI_BUS_ID for correctness.")
2831

2932
import sys # noqa: E402
3033

@@ -50,7 +53,6 @@
5053
from ..quantization.gptq import CPU # noqa: E402
5154
from ..utils import BACKEND # noqa: E402
5255
from ..utils.eval import EVAL # noqa: E402
53-
from ..utils.logger import setup_logger # noqa: E402
5456
from ..utils.model import check_and_get_model_type, find_modules # noqa: E402
5557
from ..utils.torch import torch_empty_cache # noqa: E402
5658
from .base import BaseGPTQModel, QuantizeConfig # noqa: E402
@@ -109,8 +111,6 @@
109111
random.seed(787)
110112
numpy.random.seed(787)
111113

112-
logger = setup_logger()
113-
114114
MODEL_MAP = {
115115
"bloom": BloomGPTQ,
116116
"gpt_neox": GPTNeoXGPTQ,

gptqmodel/utils/logger.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@
1717
import logging
1818
import sys
1919
from enum import Enum
20-
from typing import Callable
21-
2220
from colorlog import ColoredFormatter
2321

22+
from gptqmodel.utils.terminal import terminal_size
23+
2424
# global static/shared logger instance
2525
logger = None
26-
last_logging_src = 1 # one for logger, 2 for progressbar
26+
last_pb_instance = None # one for logger, 2 for progressbar
2727

28-
def update_logging_src(src: int):
29-
global last_logging_src
30-
last_logging_src = src
28+
def update_last_pb_instance(src) -> None:
29+
global last_pb_instance
30+
last_pb_instance = src
3131

3232
class LEVEL(str, Enum):
3333
CRITICAL = "CRITICAL"
@@ -127,21 +127,38 @@ def __init__(self, name):
127127
self.error = self.error_cls(logger=self)
128128

129129
def _process(self, level: LEVEL, msg, *args, **kwargs):
130-
global last_logging_src
131-
if last_logging_src == 2:
132-
print(" ", flush=True)
133-
last_logging_src = 1
130+
from gptqmodel.utils.progress import ProgressBar # hack: circular import
131+
132+
columns, _ = terminal_size()
133+
columns -= 10 # minus level and spaces
134+
str_msg = str(msg)
135+
columns -= len(str_msg)
136+
137+
global last_pb_instance
138+
if isinstance(last_pb_instance, ProgressBar) and not last_pb_instance.closed:
139+
buf = f'\r'
140+
if columns > 0:
141+
str_msg += " " * columns
142+
143+
print(buf,end='',flush=True)
134144

135145
if level == LEVEL.INFO:
136-
self._info(msg, *args, **kwargs)
146+
self._info(str_msg, *args, **kwargs)
137147
elif level == LEVEL.WARN:
138-
self._warning(msg, *args, **kwargs)
148+
self._warning(str_msg, *args, **kwargs)
139149
elif level == LEVEL.ERROR:
140-
self._error(msg, *args, **kwargs)
150+
self._error(str_msg, *args, **kwargs)
141151
elif level == LEVEL.DEBUG:
142-
self._debug(msg, *args, **kwargs)
152+
self._debug(str_msg, *args, **kwargs)
143153
elif level == LEVEL.CRITICAL:
144-
self._critical(msg, *args, **kwargs)
154+
self._critical(str_msg, *args, **kwargs)
155+
156+
if isinstance(last_pb_instance, ProgressBar):
157+
if not last_pb_instance.closed:
158+
last_pb_instance.progress()
159+
else:
160+
last_pb_instance = None
161+
145162

146163
logging.setLoggerClass(CustomLogger)
147164

gptqmodel/utils/progress.py

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
# limitations under the License.
1616

1717
import datetime
18-
import os
19-
import sys
2018
import time
2119
from typing import Iterable
2220
from warnings import warn
2321

24-
from gptqmodel.utils.logger import setup_logger, update_logging_src
22+
from gptqmodel.utils.logger import setup_logger, update_last_pb_instance
23+
from gptqmodel.utils.terminal import terminal_size, terminal_size
2524

2625
logger = setup_logger()
2726

@@ -41,6 +40,8 @@ def __init__(self,
4140
fill:str = '█',
4241
info:str = ""):
4342

43+
self.closed = False # active state
44+
4445
# max info length over the life ot the pb
4546
self.max_info_length = len(info)
4647

@@ -112,7 +113,7 @@ def log(self, bar:str, log:str, padding:str = "", end: str = ""):
112113
else:
113114
print(f'\r{self.info_text}{padding} |{bar}| {log}', end=end, flush=True)
114115

115-
update_logging_src(src=2) # let logger now we logged
116+
update_last_pb_instance(src=self) # let logger now we logged
116117

117118
def __bool__(self):
118119
if self.total is not None:
@@ -177,55 +178,11 @@ def __iter__(self):
177178
yield obj
178179

179180
self.progress()
181+
self.close()
180182
return
181183

182184
def close(self):
183-
pass
185+
self.closed = True
184186
#self.log(f"{self.fill * self.bar_length}", "100.0%", end="\n")
185187

186-
# copied from github.com/onsim/shutils
187-
def terminal_size(fallback=(80, 24)):
188-
"""Get the size of the terminal window.
189-
190-
For each of the two dimensions, the environment variable, COLUMNS
191-
and LINES respectively, is checked. If the variable is defined and
192-
the value is a positive integer, it is used.
193-
194-
When COLUMNS or LINES is not defined, which is the common case,
195-
the terminal connected to sys.__stdout__ is queried
196-
by invoking os.get_terminal_size.
197-
198-
If the terminal size cannot be successfully queried, either because
199-
the system doesn't support querying, or because we are not
200-
connected to a terminal, the value given in fallback parameter
201-
is used. Fallback defaults to (80, 24) which is the default
202-
size used by many terminal emulators.
203-
204-
The value returned is a named tuple of type os.terminal_size.
205-
"""
206-
# columns, lines are the working values
207-
try:
208-
columns = int(os.environ['COLUMNS'])
209-
except (KeyError, ValueError):
210-
columns = 0
211-
212-
try:
213-
lines = int(os.environ['LINES'])
214-
except (KeyError, ValueError):
215-
lines = 0
216-
217-
# only query if necessary
218-
if columns <= 0 or lines <= 0:
219-
try:
220-
size = os.get_terminal_size(sys.__stdout__.fileno())
221-
except (AttributeError, ValueError, OSError):
222-
# stdout is None, closed, detached, or not a terminal, or
223-
# os.get_terminal_size() is unsupported
224-
size = os.terminal_size(fallback)
225-
if columns <= 0:
226-
columns = size.columns or fallback[0]
227-
if lines <= 0:
228-
lines = size.lines or fallback[1]
229-
230-
return (columns, lines)
231188

gptqmodel/utils/terminal.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# copied from github.com/onsim/shutils
import os
import sys


def terminal_size(fallback=(80, 24)):
    """Return the terminal size as a ``(columns, lines)`` tuple.

    Each dimension is taken from the ``COLUMNS`` / ``LINES`` environment
    variable when that variable holds a positive integer.  For any
    dimension not supplied by the environment, the terminal attached to
    ``sys.__stdout__`` is queried via ``os.get_terminal_size``.  If that
    query fails — stdout is ``None``, closed, detached, not a terminal,
    or the platform does not support querying — the corresponding value
    from *fallback* is used.  The default fallback of ``(80, 24)``
    matches the default size of many terminal emulators.
    """

    def _from_env(var):
        # A missing or non-integer variable counts as "unset" (0),
        # which forces a query of the real terminal below.
        try:
            return int(os.environ[var])
        except (KeyError, ValueError):
            return 0

    cols = _from_env('COLUMNS')
    rows = _from_env('LINES')

    # Query the terminal only when the environment did not provide a
    # usable (positive) value for both dimensions.
    if cols <= 0 or rows <= 0:
        try:
            queried = os.get_terminal_size(sys.__stdout__.fileno())
        except (AttributeError, ValueError, OSError):
            # stdout is None, closed, detached, or not a terminal, or
            # os.get_terminal_size() is unsupported
            queried = os.terminal_size(fallback)
        if cols <= 0:
            cols = queried.columns or fallback[0]
        if rows <= 0:
            rows = queried.lines or fallback[1]

    return (cols, rows)

test_prepare_dataset.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

tests/test_quant_and_eora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_quant_and_eora(self):
139139
torch_empty_cache()
140140

141141
# BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA,
142-
for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
142+
for backend in [ BACKEND.MARLIN ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN
143143
base_bench = bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only
144144
eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora)
145145

0 commit comments

Comments
 (0)