2025-09-06 22:58:58 +08:00

204 lines
6.8 KiB
Python

import sys
from typing import Optional
from loguru import logger
from rich.console import Console, JustifyMethod
from rich.highlighter import Highlighter
from rich.logging import RichHandler
from rich.progress import Task, TextColumn
from rich.style import StyleType
from rich.table import Column
from rich.text import Text
from rich.traceback import Traceback, install
console = Console(stderr=False)
install(console=console)
def loguru_format(record):
level = record["level"].name
color = {
"DEBUG": "green",
"INFO": "blue",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bright_red",
}.get(level, "white")
return f"[bold {color}][{level}][/bold {color}] " + "{message}"
handler_with_locals = RichHandler(
console=console,
show_time=False,
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=True,
show_level=False,
markup=True,
)
handler_without_locals = RichHandler(
console=console,
show_time=False,
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=False,
show_level=False,
markup=True,
)
def local_filter(r):
return r["extra"].get("show_locals", True)
logger.remove()
logger.add(handler_with_locals, format=loguru_format, filter=local_filter)
logger.add(handler_without_locals, format=loguru_format, filter=lambda x: not local_filter(x))
class SpeedColumnToken(TextColumn):
"""Show task progress as a percentage.
Args:
text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
style (StyleType, optional): Style of output. Defaults to "none".
justify (JustifyMethod, optional): Text justification. Defaults to "left".
markup (bool, optional): Enable markup. Defaults to True.
highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
table_column (Optional[Column], optional): Table Column to use. Defaults to None.
show_speed (bool, optional): Show speed if total is unknown. Defaults to False.
"""
def __init__(
self,
text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
text_format_no_percentage: str = "",
style: StyleType = "none",
justify: JustifyMethod = "left",
markup: bool = True,
highlighter: Optional[Highlighter] = None,
table_column: Optional[Column] = None,
show_speed: bool = True,
) -> None:
self.text_format_no_percentage = text_format_no_percentage
self.show_speed = show_speed
super().__init__(
text_format=text_format,
style=style,
justify=justify,
markup=markup,
highlighter=highlighter,
table_column=table_column,
)
@classmethod
def render_speed(cls, speed: Optional[float]) -> Text:
"""Render the speed in iterations per second.
Args:
task (Task): A Task object.
Returns:
Text: Text object containing the task speed.
"""
if speed is None:
return Text("", style="progress.percentage")
return Text(f"{speed:.1f} token/s", style="progress.percentage")
def render(self, task: Task) -> Text:
if self.show_speed:
return self.render_speed(task.finished_speed or task.speed)
text_format = self.text_format_no_percentage if task.total is None else self.text_format
_text = text_format.format(task=task)
if self.markup:
text = Text.from_markup(_text, style=self.style, justify=self.justify)
else:
text = Text(_text, style=self.style, justify=self.justify)
if self.highlighter:
self.highlighter.highlight(text)
return text
class SpeedColumnIteration(TextColumn):
"""Show task progress as a percentage.
Args:
text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
style (StyleType, optional): Style of output. Defaults to "none".
justify (JustifyMethod, optional): Text justification. Defaults to "left".
markup (bool, optional): Enable markup. Defaults to True.
highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
table_column (Optional[Column], optional): Table Column to use. Defaults to None.
show_speed (bool, optional): Show speed if total is unknown. Defaults to False.
"""
def __init__(
self,
text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
text_format_no_percentage: str = "",
style: StyleType = "none",
justify: JustifyMethod = "left",
markup: bool = True,
highlighter: Optional[Highlighter] = None,
table_column: Optional[Column] = None,
show_speed: bool = True,
) -> None:
self.text_format_no_percentage = text_format_no_percentage
self.show_speed = show_speed
super().__init__(
text_format=text_format,
style=style,
justify=justify,
markup=markup,
highlighter=highlighter,
table_column=table_column,
)
@classmethod
def render_speed(cls, speed: Optional[float]) -> Text:
"""Render the speed in iterations per second.
Args:
task (Task): A Task object.
Returns:
Text: Text object containing the task speed.
"""
if speed is None:
return Text("", style="progress.percentage")
return Text(f"{speed:.1f} it/s", style="progress.percentage")
def render(self, task: Task) -> Text:
if self.show_speed:
return self.render_speed(task.finished_speed or task.speed)
text_format = self.text_format_no_percentage if task.total is None else self.text_format
_text = text_format.format(task=task)
if self.markup:
text = Text.from_markup(_text, style=self.style, justify=self.justify)
else:
text = Text(_text, style=self.style, justify=self.justify)
if self.highlighter:
self.highlighter.highlight(text)
return text
def tb(show_locals: bool = True):
exc_type, exc_value, exc_tb = sys.exc_info()
assert exc_type
assert exc_value
tb = Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)
return tb
__all__ = ["logger", "console", "tb", "SpeedColumnToken", "SpeedColumnIteration"]
if __name__ == "__main__":
try:
raise RuntimeError()
except Exception:
logger.bind(show_locals=False).exception("TEST")