XXXXRT666 c1a4ff476c .
2025-10-19 21:51:54 +01:00

250 lines
8.3 KiB
Python

import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from typing import Optional
from loguru import logger
from rich.console import Console, JustifyMethod
from rich.highlighter import Highlighter
from rich.logging import RichHandler
from rich.progress import Task, TextColumn
from rich.style import StyleType
from rich.table import Column
from rich.text import Text
from rich.traceback import Traceback, install
# Shared Rich console (stderr=False, i.e. not the stderr stream); both logging
# handlers defined below render through it.
console = Console(stderr=False)
# Install Rich's traceback hook so uncaught exceptions render on the same console.
install(console=console)
# Rich colour used for the level tag of each known level name.
_LEVEL_COLORS = {
    "DEBUG": "green",
    "INFO": "blue",
    "WARNING": "yellow",
    "ERROR": "red",
    "CRITICAL": "bright_red",
}


def loguru_format(record):
    """Build the format template for one log record.

    Args:
        record: Mapping whose ``"level"`` entry exposes a ``name`` attribute.

    Returns:
        A template string made of a bold, coloured ``[LEVEL]`` tag in Rich
        markup followed by the literal ``{message}`` placeholder. Level names
        without a colour entry are rendered in white.
    """
    level_name = record["level"].name
    tag_color = _LEVEL_COLORS.get(level_name, "white")
    opening = f"[bold {tag_color}]"
    closing = f"[/bold {tag_color}]"
    # "{message}" is kept as a plain (non f-string) literal so the placeholder
    # survives into the returned template.
    return f"{opening}[{level_name}]{closing} " + "{message}"
# Keyword options shared by both Rich logging handlers. The handler's own
# level column is disabled and markup is enabled for the rendered text.
_COMMON_HANDLER_OPTIONS = {
    "console": console,
    "show_time": False,
    "show_path": False,
    "rich_tracebacks": True,
    "show_level": False,
    "markup": True,
}

# Handler whose Rich tracebacks include local variables.
handler_with_locals = RichHandler(tracebacks_show_locals=True, **_COMMON_HANDLER_OPTIONS)
# Handler with the same configuration, but tracebacks omit local variables.
handler_without_locals = RichHandler(tracebacks_show_locals=False, **_COMMON_HANDLER_OPTIONS)
def local_filter(r):
    """Return the record's ``show_locals`` extra, defaulting to ``True``.

    Args:
        r: Log record mapping with an ``"extra"`` mapping.

    Returns:
        The value stored under ``extra["show_locals"]`` (returned as is, not
        coerced to ``bool``), or ``True`` when the key is absent.
    """
    extra = r["extra"]
    if "show_locals" in extra:
        return extra["show_locals"]
    return True
def _without_locals(record):
    """Return the negation of ``local_filter`` for *record*."""
    return not local_filter(record)


# Remove the sinks registered so far, then register the two Rich handlers with
# complementary filters so each record is emitted by exactly one of them.
logger.remove()
logger.add(handler_with_locals, filter=local_filter, format=loguru_format)
logger.add(handler_without_locals, filter=_without_locals, format=loguru_format)
class SpeedColumnToken(TextColumn):
    """Progress column that shows the task speed in tokens per second.

    With ``show_speed`` enabled (the default) the column renders the task's
    speed followed by ``speed_unit``. With it disabled, the column formats
    ``text_format`` instead, or ``text_format_no_percentage`` when the task
    has no total.

    Args:
        text_format (str, optional): Format used when ``show_speed`` is False and the task has a total.
            Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
        text_format_no_percentage (str, optional): Format used when ``show_speed`` is False and the
            task total is None. Defaults to "".
        style (StyleType, optional): Style of output. Defaults to "none".
        justify (JustifyMethod, optional): Text justification. Defaults to "left".
        markup (bool, optional): Enable markup. Defaults to True.
        highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
        table_column (Optional[Column], optional): Table Column to use. Defaults to None.
        show_speed (bool, optional): Render the speed instead of the text format. Defaults to True.
    """

    # Unit suffix appended to the rendered speed; a subclass may override it.
    speed_unit: str = "token/s"

    def __init__(
        self,
        text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
        text_format_no_percentage: str = "",
        style: StyleType = "none",
        justify: JustifyMethod = "left",
        markup: bool = True,
        highlighter: Optional[Highlighter] = None,
        table_column: Optional[Column] = None,
        show_speed: bool = True,
    ) -> None:
        self.text_format_no_percentage = text_format_no_percentage
        self.show_speed = show_speed
        super().__init__(
            text_format=text_format,
            style=style,
            justify=justify,
            markup=markup,
            highlighter=highlighter,
            table_column=table_column,
        )

    @classmethod
    def render_speed(cls, speed: Optional[float]) -> Text:
        """Render a speed value with the column's unit.

        Args:
            speed (Optional[float]): Speed to display, or None when unknown.

        Returns:
            Text: Empty text when ``speed`` is None, otherwise the speed with
            one decimal place followed by ``speed_unit``.
        """
        if speed is None:
            return Text("", style="progress.percentage")
        return Text(f"{speed:.1f} {cls.speed_unit}", style="progress.percentage")

    def render(self, task: Task) -> Text:
        """Render the column for *task*.

        Args:
            task (Task): Task being displayed.

        Returns:
            Text: The speed text when ``show_speed`` is set; otherwise the
            formatted ``text_format`` / ``text_format_no_percentage``.
        """
        if self.show_speed:
            # Use the speed recorded at completion when it is set (truthy),
            # otherwise the task's current speed.
            return self.render_speed(task.finished_speed or task.speed)

        text_format = self.text_format_no_percentage if task.total is None else self.text_format
        _text = text_format.format(task=task)
        if self.markup:
            text = Text.from_markup(_text, style=self.style, justify=self.justify)
        else:
            text = Text(_text, style=self.style, justify=self.justify)
        if self.highlighter:
            self.highlighter.highlight(text)
        return text
class SpeedColumnIteration(TextColumn):
    """Progress column that shows the task speed in iterations per second.

    With ``show_speed`` enabled (the default) the column renders the task's
    speed followed by ``speed_unit``. With it disabled, the column formats
    ``text_format`` instead, or ``text_format_no_percentage`` when the task
    has no total.

    Args:
        text_format (str, optional): Format used when ``show_speed`` is False and the task has a total.
            Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
        text_format_no_percentage (str, optional): Format used when ``show_speed`` is False and the
            task total is None. Defaults to "".
        style (StyleType, optional): Style of output. Defaults to "none".
        justify (JustifyMethod, optional): Text justification. Defaults to "left".
        markup (bool, optional): Enable markup. Defaults to True.
        highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
        table_column (Optional[Column], optional): Table Column to use. Defaults to None.
        show_speed (bool, optional): Render the speed instead of the text format. Defaults to True.
    """

    # Unit suffix appended to the rendered speed; a subclass may override it.
    speed_unit: str = "it/s"

    def __init__(
        self,
        text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
        text_format_no_percentage: str = "",
        style: StyleType = "none",
        justify: JustifyMethod = "left",
        markup: bool = True,
        highlighter: Optional[Highlighter] = None,
        table_column: Optional[Column] = None,
        show_speed: bool = True,
    ) -> None:
        self.text_format_no_percentage = text_format_no_percentage
        self.show_speed = show_speed
        super().__init__(
            text_format=text_format,
            style=style,
            justify=justify,
            markup=markup,
            highlighter=highlighter,
            table_column=table_column,
        )

    @classmethod
    def render_speed(cls, speed: Optional[float]) -> Text:
        """Render a speed value with the column's unit.

        Args:
            speed (Optional[float]): Speed to display, or None when unknown.

        Returns:
            Text: Empty text when ``speed`` is None, otherwise the speed with
            one decimal place followed by ``speed_unit``.
        """
        if speed is None:
            return Text("", style="progress.percentage")
        return Text(f"{speed:.1f} {cls.speed_unit}", style="progress.percentage")

    def render(self, task: Task) -> Text:
        """Render the column for *task*.

        Args:
            task (Task): Task being displayed.

        Returns:
            Text: The speed text when ``show_speed`` is set; otherwise the
            formatted ``text_format`` / ``text_format_no_percentage``.
        """
        if self.show_speed:
            # Use the speed recorded at completion when it is set (truthy),
            # otherwise the task's current speed.
            return self.render_speed(task.finished_speed or task.speed)

        text_format = self.text_format_no_percentage if task.total is None else self.text_format
        _text = text_format.format(task=task)
        if self.markup:
            text = Text.from_markup(_text, style=self.style, justify=self.justify)
        else:
            text = Text(_text, style=self.style, justify=self.justify)
        if self.highlighter:
            self.highlighter.highlight(text)
        return text
def tb(show_locals: bool = True) -> "Traceback":
    """Build a Rich ``Traceback`` for the exception currently being handled.

    Args:
        show_locals: Forwarded to ``Traceback.from_exception``.

    Returns:
        The ``Traceback`` built from ``sys.exc_info()``.

    Raises:
        RuntimeError: If no exception is currently being handled.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    # Explicit check instead of ``assert`` so the guard still runs under
    # ``python -O``, which strips assert statements.
    if exc_type is None or exc_value is None:
        raise RuntimeError("tb() must be called while an exception is being handled")
    return Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)
# Names exported by a star-import of this module.
__all__ = [
    "logger",
    "console",
    "tb",
    "SpeedColumnToken",
    "SpeedColumnIteration",
]

if __name__ == "__main__":
    # Manual smoke test: log a handled exception through a logger bound with
    # show_locals=False.
    try:
        raise RuntimeError()
    except Exception:
        no_locals_logger = logger.bind(show_locals=False)
        no_locals_logger.exception("TEST")
class _TimerContext:
    """Context manager that times one ``with`` block on behalf of a ``Timer``."""

    __slots__ = ("_timer", "_category")

    def __init__(self, timer: "Timer", category: str) -> None:
        self._timer = timer
        self._category = category

    def __enter__(self) -> "Timer":
        self._timer._stack.append((self._category, time.perf_counter_ns()))
        # Return the timer so its methods can be called inside the with block.
        return self._timer

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        # Read the clock first so the bookkeeping below is not measured.
        end = time.perf_counter_ns()
        stack = self._timer._stack
        if not stack:
            raise RuntimeError("Timer stack underflow: __exit__ without matching __enter__")
        cat, start = stack.pop()
        if cat != self._category:
            raise RuntimeError(f"Mismatched timer context: expected '{cat}', got '{self._category}'")
        # perf_counter_ns() is in nanoseconds; records are kept in seconds.
        self._timer.records[cat].append((end - start) / 1e9)
        # Never suppress an exception raised inside the with block.
        return False


class Timer:
    """Record elapsed times of ``with`` blocks, grouped by category.

    Calling the instance returns a context manager::

        with timer("decode", debug=True):
            ...

    A measurement is taken only when ``debug`` is true; otherwise a no-op
    ``nullcontext`` is returned and nothing is recorded.
    """

    def __init__(self):
        # category -> elapsed seconds of each completed measurement
        self.records: dict[str, list[float]] = defaultdict(list)
        # (category, start time in ns) for measurements still in progress
        self._stack: list[tuple[str, int]] = []

    def __call__(self, category: str, debug: bool = False):
        """Return a context manager for timing a block under *category*.

        Args:
            category: Key under which the elapsed time is stored in ``records``.
            debug: When false (the default) a ``nullcontext`` is returned and
                nothing is measured.

        Returns:
            A context manager. When ``debug`` is true, entering it yields this
            timer and exiting it appends the elapsed seconds to
            ``records[category]``; its ``__exit__`` raises ``RuntimeError`` if
            the start-time stack is empty or its top entry belongs to another
            category.
        """
        if not debug:
            return nullcontext()
        return _TimerContext(self, category)

    def clear(self):
        """Discard all recorded times and any in-progress measurements."""
        self.records.clear()
        self._stack.clear()

    def summary(self):
        """Print count, total and average seconds for every recorded category."""
        for cat, times in self.records.items():
            total = sum(times)
            avg = total / len(times) if times else 0.0
            print(f"{cat}: count={len(times)}, total={total:.6f}s, avg={avg:.6f}s")
# Module-level Timer instance, created when the module is imported.
timer = Timer()