Merge pull request #2 from afeldman-nm/enc_dec_t5
Small PR for debug print statements
js8544 authored Mar 2, 2024
2 parents db726e6 + 37fcf99 commit 4bf056b
Showing 5 changed files with 15 additions and 48 deletions.
8 changes: 6 additions & 2 deletions examples/offline_inference_enc_dec.py
@@ -8,11 +8,15 @@
Output: for several prompts, compare native PyTorch & vLLM prompt completions
'''

import warnings
import torch
from vllm import LLM, SamplingParams
from transformers import T5Tokenizer, T5ForConditionalGeneration

warnings.filterwarnings("ignore",
category=UserWarning,
module="transformers.generation.utils.*")

hf_model_id = "t5-small"
dtype = "bfloat16"
prompts = [
@@ -27,7 +31,7 @@
# Native PyTorch test

# - Model and tokenizer initialization
tokenizer = T5Tokenizer.from_pretrained(hf_model_id)
tokenizer = T5Tokenizer.from_pretrained(hf_model_id, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(hf_model_id).to(
dtype=dtype_obj)

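The two example changes above work together: the filterwarnings call suppresses the UserWarning that transformers.generation.utils emits while the example generates, and legacy=False opts the T5 tokenizer into the fixed (non-legacy) SentencePiece handling, which also avoids the legacy-behaviour notice newer transformers versions print when the flag is left unset. A minimal standalone sketch of the same setup (not part of the commit; "t5-small" matches the hf_model_id used in the example):

    import warnings

    from transformers import T5Tokenizer

    # Suppress the UserWarning raised from transformers.generation.utils,
    # mirroring the example script above.
    warnings.filterwarnings("ignore",
                            category=UserWarning,
                            module="transformers.generation.utils.*")

    # legacy=False selects the non-legacy SentencePiece behaviour.
    tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
    ids = tokenizer("translate English to German: How are you?",
                    return_tensors="pt").input_ids
    print(ids.shape)
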
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -317,7 +317,7 @@ def __init__(
self.num_cpu_blocks = None

def metrics_info(self):
# convert cache_config to dict(key: str, value:str) for prometheus metrics info
# convert cache_config to dict(key: str, value: str) for prometheus metrics info
return {key: str(value) for key, value in self.__dict__.items()}

def _verify_args(self) -> None:
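For context, metrics_info() exists to flatten the cache config into a str-to-str mapping, which is the only shape a Prometheus Info metric accepts. A minimal sketch with a made-up DemoConfig (not vLLM's actual CacheConfig) shows the resulting output:

    class DemoConfig:
        """Hypothetical stand-in for a config object with a few attributes."""

        def __init__(self):
            self.block_size = 16
            self.gpu_memory_utilization = 0.9
            self.num_gpu_blocks = None  # not profiled yet

        def metrics_info(self):
            # convert the config to dict(key: str, value: str) for prometheus
            return {key: str(value) for key, value in self.__dict__.items()}


    print(DemoConfig().metrics_info())
    # {'block_size': '16', 'gpu_memory_utilization': '0.9', 'num_gpu_blocks': 'None'}
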
1 change: 1 addition & 0 deletions vllm/engine/metrics.py
@@ -23,6 +23,7 @@ def __init__(self, labelnames: List[str]):
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)

# Config Information
self.info_cache_config = Info(
name='vllm:cache_config',
documentation='information of cache_config')
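The new "# Config Information" comment labels the block that publishes that dictionary through a prometheus_client Info metric, after first unregistering any "vllm" collectors left over from a previous engine in the same process. A hedged standalone sketch of the same pattern follows; enumerating collectors via REGISTRY._collector_to_names is a private prometheus_client detail, used here only to reconstruct the loop whose body is shown above:

    from prometheus_client import REGISTRY, Info

    # Drop previously registered "vllm" collectors so re-initialisation does
    # not raise a duplicate-registration error.
    for collector in list(REGISTRY._collector_to_names):
        if hasattr(collector, "_name") and "vllm" in collector._name:
            REGISTRY.unregister(collector)

    # Config Information
    info_cache_config = Info(name='vllm:cache_config',
                             documentation='information of cache_config')
    info_cache_config.info({"block_size": "16", "num_gpu_blocks": "None"})
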
8 changes: 7 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -62,6 +62,12 @@ def parse_args():
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
@@ -245,7 +251,7 @@ async def authentication(request: Request, call_next):
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile)
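
Together, the api_server.py changes make uvicorn's verbosity a launch-time choice instead of a hard-coded "info". A trimmed-down sketch of the same wiring (demo FastAPI app, not vLLM's actual server); running it with --uvicorn-log-level debug turns on verbose request logging, while the default stays at info:

    import argparse

    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()


    @app.get("/health")
    async def health():
        return {"status": "ok"}


    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="demo RESTful API server")
        parser.add_argument("--port", type=int, default=8000, help="port number")
        parser.add_argument(
            "--uvicorn-log-level",
            type=str,
            default="info",
            choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
            help="log level for uvicorn")
        args = parser.parse_args()

        # argparse exposes --uvicorn-log-level as args.uvicorn_log_level
        uvicorn.run(app, port=args.port, log_level=args.uvicorn_log_level)
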
44 changes: 0 additions & 44 deletions vllm/model_executor/models/t5.py
@@ -258,26 +258,17 @@ def forward(
input_metadata: InputMetadata,
encoder_hidden_states: Optional[torch.Tensor],
) -> torch.Tensor:
# print("hidden_states shape", hidden_states.shape)
# print("hidden_states", hidden_states)
q, _ = self.q(hidden_states)

# print("q shape", q.shape)
# print("q", q)
batch_size = hidden_states.shape[0]
seq_len = hidden_states.shape[1]
prompt_len = input_metadata.prompt_lens.max().item()
context_len = input_metadata.context_lens.max().item()
context_len = max(context_len, 1)
# print("batch_size", batch_size)
# print("seq_len", seq_len)
# print("prompt_len", prompt_len)
# print("context_len", context_len)

block_size = 16

if not self.is_decoder:
# print("encoder self attention!")
assert kv_cache is None
# Encoder self attention, no cache operations
k, _ = self.k(hidden_states)
@@ -293,12 +284,9 @@ def forward(
input_metadata.prompt_lens[i]:, ] = torch.finfo(
input_metadata.attn_bias.dtype).min

# print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape)
# print("input_metadata.attn_bias", input_metadata.attn_bias)
attn_output = self.attn(q, k, v, input_metadata)

elif not self.is_cross:
# print("decoder self attention!")
# Decoder self attention
k, _ = self.k(hidden_states)
v, _ = self.v(hidden_states)
@@ -308,32 +296,23 @@ def forward(
1 if input_metadata.is_prompt else context_len,
(context_len + block_size - 1) // block_size *
block_size).repeat(batch_size, 1, 1, 1)
# print("position_bias shape", position_bias.shape)
# print("position_bias", position_bias)
input_metadata.attn_bias = position_bias[:, :,
-seq_len:, :].contiguous(
)
# print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape)
# print("input_metadata.attn_bias", input_metadata.attn_bias)

key_cache, value_cache = kv_cache

attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)

else:
# print("cross attention!")
# Cross attention

key_cache, value_cache = kv_cache
if input_metadata.is_prompt:
assert encoder_hidden_states is not None
k, _ = self.k(encoder_hidden_states)
v, _ = self.v(encoder_hidden_states)
# print("k shape", k.shape)
# for i in range(k.shape[0]):
# for j in range(k.shape[1]):
# print(f"key at batch {i} and pos {j}: ", k[i, j, :].reshape(1, 8, 64))
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata)
else:
@@ -369,16 +348,12 @@ def forward(
input_metadata: InputMetadata,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
print("self attention input shape: ", normed_hidden_states.shape)
print("self_attention input: ", normed_hidden_states)
attention_output = self.SelfAttention(
hidden_states=normed_hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
encoder_hidden_states=None,
)
print("self attention output shape: ", attention_output.shape)
print("self_attention output: ", attention_output)
hidden_states = hidden_states + attention_output
return hidden_states

@@ -408,16 +383,12 @@ def forward(
encoder_hidden_states: Optional[torch.Tensor],
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
print("cross attention input shape: ", normed_hidden_states.shape)
print("cross_attention input: ", normed_hidden_states)
attention_output = self.EncDecAttention(
hidden_states=normed_hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
encoder_hidden_states=encoder_hidden_states,
)
print("cross attention output shape: ", attention_output.shape)
print("cross_attention output: ", attention_output)
hidden_states = hidden_states + attention_output
return hidden_states

@@ -521,11 +492,8 @@ def forward(
input_metadata: InputMetadata,
encoder_hidden_states: Optional[torch.Tensor],
) -> torch.Tensor:
# print("input_ids: ", input_ids)
hidden_states = self.embed_tokens(input_ids)

# print("hidden_states shape: ", hidden_states.shape)
# print("hidden_states: ", hidden_states)
for i, layer_module in enumerate(self.block):
kv_cache = kv_caches[i] if self.is_decoder else None

@@ -539,15 +507,6 @@
hidden_states = layer_outputs

hidden_states = self.final_layer_norm(hidden_states)
# if encoder_hidden_states is not None:
# print("hidden_states shape:" , hidden_states.shape)
# print("encoder_hidden_states shape:" , encoder_hidden_states.shape)
# # Attach encoder hidden states
# hidden_states = torch.cat(
# [encoder_hidden_states, hidden_states], dim=1
# )
print("final_hidden_states shape: ", hidden_states.shape)
print("final_hidden_states: ", hidden_states)
return hidden_states


@@ -579,9 +538,6 @@ def forward(
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
# print("input_ids shape: ", input_ids.shape)
# print("input_ids: ", input_ids)
# print("input_metadata: ", input_metadata)
if input_metadata.is_prompt:
# prompt run, need to run encoder once
hidden_states = self.encoder(input_ids, kv_caches, input_metadata,
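
With the debug prints removed, what remains in these branches is the attention-bias and cache bookkeeping. The standalone sketch below (simplified shapes, no attention heads, made-up prompt_lens) isolates two pieces visible in the diff: the round-up-to-block-size arithmetic used when sizing the decoder's position bias, and the encoder-side padding mask that sets key positions past each prompt's length to the dtype's minimum so softmax gives them essentially zero weight:

    import torch

    BLOCK_SIZE = 16  # matches the hard-coded block_size in the diff


    def round_up_to_block(context_len: int) -> int:
        # (context_len + block_size - 1) // block_size * block_size
        return (context_len + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE


    assert round_up_to_block(1) == 16
    assert round_up_to_block(16) == 16
    assert round_up_to_block(17) == 32

    # Padding mask, simplified from the encoder self-attention branch:
    # keys beyond each prompt's true length get the most negative
    # representable value. prompt_lens here is illustrative.
    prompt_lens = torch.tensor([3, 5])
    max_prompt_len = int(prompt_lens.max().item())

    attn_bias = torch.zeros(len(prompt_lens), max_prompt_len, max_prompt_len)
    for i, plen in enumerate(prompt_lens.tolist()):
        attn_bias[i, :, plen:] = torch.finfo(attn_bias.dtype).min

    print(torch.softmax(attn_bias[0, 0], dim=-1))  # masked keys get ~0 weight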
