From e40beba8747184422fdf12dbb4b70886b5ddd050 Mon Sep 17 00:00:00 2001 From: lakshith Date: Thu, 17 Oct 2024 13:56:18 +0530 Subject: [PATCH] format token viz --- python/inspectus/token_viz.py | 45 ++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/python/inspectus/token_viz.py b/python/inspectus/token_viz.py index 58da382..87a0731 100644 --- a/python/inspectus/token_viz.py +++ b/python/inspectus/token_viz.py @@ -2,30 +2,37 @@ from inspectus.utils import init_inline_viz import numpy as np -def visualize_tokens(tokens: List[str], values: Dict[str, List[float]], - token_info: Optional[List[str]], - remove_padding: bool, color: str, theme: str): - if token_info is None: - token_info = [{} for _ in range(len(tokens))] - value_names = list(values.keys()) - value_names.sort() +def visualize_tokens( + tokens: List[str], + values: Dict[str, List[float]], + token_info: Optional[List[str]], + remove_padding: bool, + color: str, + theme: str, +): + if token_info is None: + token_info = [{} for _ in range(len(tokens))] - values = np.stack([values[name] for name in value_names]) - - normalized_values = (values - np.min(values, axis=1, keepdims=True)) / (np.max(values, axis=1, keepdims=True) - np.min(values, axis=1, keepdims=True)) + value_names = list(values.keys()) + value_names.sort() - from uuid import uuid1 - import json + values = np.stack([values[name] for name in value_names]) - elem_id = 'id_' + uuid1().hex + normalized_values = (values - np.min(values, axis=1, keepdims=True)) / ( + np.max(values, axis=1, keepdims=True) - np.min(values, axis=1, keepdims=True) + ) - html = f'
' + from uuid import uuid1 + import json - script = f'' + elem_id = "id_" + uuid1().hex - from IPython.display import display, HTML - init_inline_viz() - display(HTML(html + script)) + html = f'
' - \ No newline at end of file + script = f"" + + from IPython.display import display, HTML + + init_inline_viz() + display(HTML(html + script))