Merge pull request #402 from klei22/add_inference_colorization
Add inference colorization
gkielian authored Feb 24, 2025
2 parents fd56124 + 366e7e0 commit 54e4748
Showing 1 changed file with 100 additions and 3 deletions.
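With this change, colorized sampling is enabled via the two new flags, e.g. (a hypothetical invocation; all other options are left at their defaults):

    python sample.py --colorize_output --colorize_mode softmax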
103 changes: 100 additions & 3 deletions sample.py
@@ -12,6 +12,7 @@
import torch
import tiktoken
from rich import print
from rich.console import Console
from torch.nn import functional as F
from collections import OrderedDict

@@ -36,7 +37,6 @@ def parse_args():
parser.add_argument('--sample_file', type=str, default=None, help="Output file for inference")
parser.add_argument('--interactive', action=argparse.BooleanOptionalAction, help="Enable interactive generation")
parser.add_argument('--stop_string', type=str, default='~W', help="String to stop generation and allow user input")
parser.add_argument('--show_heatmaps', action=argparse.BooleanOptionalAction, help="Show heatmaps of top-k choices for each token")
parser.add_argument('--last_k_tokens', type=int, default=10, help="Number of last tokens to display in heatmaps")
parser.add_argument('--chart_type', type=str, default='heatmap', choices=['heatmap', 'barchart'], help="Type of chart to display: 'heatmap' or 'barchart'")
parser.add_argument('--block_size', type=int, default=None, help="Block size for context length, default is model's block size")
@@ -45,6 +45,18 @@ def parse_args():
    parser.add_argument('--token_boundary', type=str, default=None, help="Optional separator between emitted tokens")
    parser.add_argument('--print_model_info', default=True, action=argparse.BooleanOptionalAction, help="Print info about model before inference")

# Output Confidence
    parser.add_argument('--colorize_mode', type=str, default='minmax', choices=['minmax', 'softmax', 'softmax_top_k'],
                        help="Mode to colorize text: 'minmax' (default), 'softmax', or 'softmax_top_k' for softmax over only the top-k values. "
                             "Requires --colorize_output (off by default).")
    parser.add_argument('--colorize_output', default=False, action=argparse.BooleanOptionalAction,
                        help="Colorize tokens based on their predicted probabilities. Off by default; enable with --colorize_output.")

# Visualizations
parser.add_argument('--show_heatmaps', action=argparse.BooleanOptionalAction, help="Show heatmaps of top-k choices for each token")


# Steering Vector Related
parser.add_argument('--save_avg_vector', type=str, default=None, help="Path to save the average vector of the start text to an .npy file")
parser.add_argument('--apply_vector_file1', type=str, default=None, help="First .npy file to load the vector for subtraction")
@@ -64,6 +76,52 @@ def parse_args():

return parser.parse_args()

def colorize_text(tokens, raw_logits, decode, colorize_mode='minmax'):
"""
Colorizes each token according to one of two modes:
- 'minmax': raw_logits is a 1D list/array of chosen-token logits.
We min-max normalize them across time, then map to R->G colors.
- 'softmax': raw_logits is a 2D list/array (T, vocab_size) containing
the *full* distribution at each step. We extract the chosen
token's probability for each step, then min-max normalize.
"""
from rich.text import Text
text = Text()

norm_values = None

    if colorize_mode == 'softmax' or colorize_mode == 'softmax_top_k':
        # raw_logits holds one (vocab_size,) row of logits per step; gather the
        # chosen token's probability at each step and use it directly as the
        # color value (probabilities are already in [0, 1])
        dist_tensor = torch.stack(raw_logits, dim=0)  # shape (T, vocab_size)
        chosen_probs = []
        for i, dist_row in enumerate(dist_tensor):
            prob_dist = F.softmax(dist_row, dim=-1)
            chosen_probs.append(prob_dist[tokens[i]])
        norm_values = torch.stack(chosen_probs)

    if colorize_mode == 'minmax':
        # raw_logits is shape (T,) with each chosen-token logit;
        # min-max normalize the logits to [0..1]
        values = torch.tensor(raw_logits, dtype=torch.float32)
        norm_values = (values - values.min()) / (values.max() - values.min() + 1e-6)
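        # e.g. logits [2.0, 3.5, 5.0] -> [0.0, 0.5, 1.0] (up to the 1e-6 epsilon)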

for i, token_id in enumerate(tokens):
token_str = decode([token_id])
color_val = norm_values[i].item() # 0..1
r = int((1 - color_val) * 255)
g = int(color_val * 255)
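        # e.g. color_val = 0.75 -> r = 63 (0x3f), g = 191 (0xbf) -> "#3fbf00"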
text.append(token_str, style=f"bold #{r:02x}{g:02x}00")
return text
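For reference, the helper can be smoke-tested on its own. The snippet below is not part of the commit; it assumes a toy five-token vocabulary and random logits, and the names toy_vocab and toy_decode are illustrative only:

import torch
from rich.console import Console

toy_vocab = ["the", " cat", " sat", " down", "."]

def toy_decode(ids):
    return "".join(toy_vocab[i] for i in ids)

chosen_tokens = [0, 1, 2, 3, 4]                        # one chosen id per step
full_rows = [torch.randn(len(toy_vocab)) for _ in chosen_tokens]

Console().print(colorize_text(chosen_tokens, full_rows, toy_decode,
                              colorize_mode="softmax"))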



def save_chart(probs, idx, decode, step, out_dir, last_k_tokens, chart_type, selected_token):
top_k_probs, top_k_indices = torch.topk(probs, k=probs.size(-1))
@@ -337,6 +395,12 @@ def main():
print(f"Validation Loss: {val_loss:.4f}")
return

# Prepare to store tokens/logits for optional colorization
tokens_for_color = []
logits_for_color = []
all_logits_for_softmax = []
saved_logits = None

x = torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...]
# Obtain vector from the specified layer and save it to a file if required
if args.save_avg_vector:
Expand Down Expand Up @@ -368,12 +432,16 @@ def main():
model.set_lsv_mixture(args.lsv_mixture)
else:
model.set_lsv_mode(1)
if args.colorize_output:
            tokens_for_color.clear()
            logits_for_color.clear()
            all_logits_for_softmax.clear()
x = torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...]
block_size = args.block_size if args.block_size else model.config.block_size
for step in range(args.max_new_tokens):
idx_cond = x if x.size(1) <= block_size else x[:, -block_size:]
logits, _ = model(idx_cond)
logits = logits[:, -1, :] / args.temperature
if args.colorize_mode == 'softmax':
saved_logits = logits.clone()
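                # (cloned before the top-k filter below, so 'softmax'
                # colorization reflects the full-vocabulary distribution)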
if args.top_k is not None:
v, _ = torch.topk(logits, min(args.top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
@@ -385,11 +453,40 @@
selected_token = decode([idx_next[0].item()])
save_chart(probs, x, decode, step, out_dir, args.last_k_tokens, args.chart_type, selected_token)

# Collect data for colorization:
if args.colorize_output:
if args.colorize_mode == 'softmax':
# softmax over entire vocab
tokens_for_color.append(idx_next.item())
logits_for_color.append(saved_logits[0].clone())
elif args.colorize_mode == 'softmax_top_k':
# softmax over only top k vocab
tokens_for_color.append(idx_next.item())
logits_for_color.append(logits[0].clone())
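                    # when --top_k is set, non-top-k logits are already -inf
                    # here, so the softmax in colorize_text renormalizes over
                    # only the surviving k candidates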
elif args.colorize_mode == 'minmax':
# We'll do min-max normalization over chosen-token logits
tokens_for_color.append(idx_next.item())
logits_for_color.append(logits[0, idx_next.item()])

        raw_output = decode(x[0].tolist())
        output_line = raw_output.replace(separator_token, " ") if separator_token else raw_output
if args.apply_vector_file1:
print(f"Scaling factor: {args.steering_vector_scaling_factor}")
print("[bold green]" + output_line)
print('---------------')
print('[bold blue]---------------')
print(f"[bold orange] Sample [bold orange]{k+1}")
print('[bold blue]---------------')
# Perform colorized printing if requested
if args.colorize_output:
console = Console()
colored_text = colorize_text(
tokens_for_color,
                logits_for_color,  # list of per-step tensors; do not wrap in torch.tensor(...)
decode,
colorize_mode=args.colorize_mode
)
console.print(colored_text)
else:
print("[bold green]" + output_line)

if args.sample_file:
with open(args.sample_file, "w") as file:
file.write(output_line)
