
Commit

Support HF & Meta models
gordicaleksa committed Aug 4, 2024
1 parent 0749a4a commit 8e55d16
Showing 1 changed file with 79 additions and 6 deletions.
85 changes: 79 additions & 6 deletions train_gpt2.py
@@ -256,8 +256,71 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):

        return checkpoint

+    @staticmethod
+    def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig):
+        checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight')
+
+        # We need to unpermute Q and K because the HF conversion script permuted the original Meta-LLaMA weights
+        # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+        def unpermute(w, n_heads, dim1, dim2):
+            return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+        for i in range(config.n_layer):
+            for name in ['input_layernorm', 'post_attention_layernorm']:
+                old_key = f'model.layers.{i}.{name}.weight'  # e.g. model.layers.x.input_layernorm.weight -> transformer.h.x.ln_1.weight
+                new_key = f'transformer.h.{i}.ln_{1 if name == "input_layernorm" else 2}.weight'
+                checkpoint[new_key] = checkpoint.pop(old_key)
+
+        for i in range(config.n_layer):
+            for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']:
+                old_key = f'model.layers.{i}.{name}.weight'
+                new_key = f'transformer.h.{i}.attn.c_attn.weight'
+                if name == 'self_attn.q_proj':
+                    checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd)
+                else:  # merge the 3 projection weights into transformer.h.x.attn.c_attn.weight
+                    tensor = checkpoint.pop(old_key)
+                    if name == 'self_attn.k_proj':
+                        tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd)
+                    checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0)
+            old_key = f'model.layers.{i}.self_attn.o_proj.weight'
+            new_key = f'transformer.h.{i}.attn.c_proj.weight'
+            checkpoint[new_key] = checkpoint.pop(old_key)
+
+        ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'}
+        for i in range(config.n_layer):
+            for name in ['gate_proj', 'down_proj', 'up_proj']:
+                old_key = f'model.layers.{i}.mlp.{name}.weight'
+                new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight'
+                checkpoint[new_key] = checkpoint.pop(old_key)
+
+        checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight')
+
+        return checkpoint
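Editorial aside, not part of the diff: the unpermute helper above is intended to invert the head-wise permutation that HuggingFace's convert_llama_weights_to_hf.py applies to the Q and K projection weights. A minimal round-trip sketch, with permute paraphrased from that conversion script and the 8B query-projection shape assumed purely for illustration:

import torch

def permute(w, n_heads, dim1, dim2):
    # paraphrased from the HF conversion script (rotary-embedding layout change); treat as an assumption
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def unpermute(w, n_heads, dim1, dim2):
    # same expression as the helper added in this commit
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

n_head, n_embd = 32, 4096  # assumed LLaMA 3.1 8B attention dims, for illustration only
w = torch.randn(n_embd, n_embd)
# unpermute undoes permute, so converting HF weights back recovers the Meta layout
assert torch.equal(unpermute(permute(w, n_head, n_embd, n_embd), n_head, n_embd, n_embd), w)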

+    @classmethod
+    def from_pretrained_llama3_hf(cls, model_id):
+        """Loads pretrained LLaMA model weights from HuggingFace"""
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B base model is supported for now"
+        model_args = LlamaConfig()
+
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)
+
+        original_default_type = torch.get_default_dtype()  # save the default type
+        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)  # much faster loading
+        model = LLaMA(model_args)
+        model.load_state_dict(checkpoint, strict=False)
+        torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type())  # restore default type
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_id = 128004  # pad token id for the LLaMA 3.1 base model; we set it explicitly because our generate function expects it
+        tokenizer.stop_tokens = [tokenizer.eos_token_id]
+        model.tokenizer = tokenizer
+        return model
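For context (also an editorial sketch, not commit content), the new loader would be used roughly as below. The generate call mirrors the sampling code further down in this file; access to the gated HF checkpoint and the importability of this script as a module are assumptions:

from train_gpt2 import LLaMA  # assumes this file can be imported as a module

model = LLaMA.from_pretrained_llama3_hf("meta-llama/Meta-Llama-3.1-8B")
model.eval()
# tokenize with the HF tokenizer attached by the loader, then sample
prompt_tokens = [model.tokenizer(p).input_ids for p in ["Once upon a time"]]
generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6,
                                      top_p=0.9, logprobs=False, echo=False)
print(model.tokenizer.decode(generation_tokens[0]))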

    @classmethod
-    def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path):
+    def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path):
        """Loads pretrained LLaMA model weights from a checkpoint directory"""
        model_args = LlamaConfig()

@@ -272,6 +335,9 @@ def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path):
        torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type())  # restore default type

        tokenizer = Tokenizer(model_path=tokenizer_path)
+        # add <|end_of_text|> as a stop token for the base model - this is an omission in the reference code,
+        # which only adds the instruct model's stop tokens...
+        tokenizer.stop_tokens = tokenizer.stop_tokens + [128001]
        model.tokenizer = tokenizer
        return model

@@ -608,13 +674,14 @@ def print0(*args, **kwargs):
    # default settings will overfit a tiny batch of data
    # and save model weights and debug state to disk on the first iteration
    parser = argparse.ArgumentParser()
+    parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model")
    parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer")
    # file system input / output
    parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
    parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
    parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
-    parser.add_argument("--model", type=str, default="llama3.1", help="llama3.1")
+    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="choose the llama model")
    # token layout for each step of the optimization
    parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
@@ -650,7 +717,7 @@ def print0(*args, **kwargs):
    B, T = args.batch_size, args.sequence_length
    assert 1 <= T <= 8192, "sequence length must be between 1 and 8192"
    assert args.dtype in {"float32", "float16", "bfloat16"}
-    assert args.model in {"llama3.1"}
+    assert args.model in {"meta-llama/Meta-Llama-3.1-8B"}  # only the 8B base model is supported for now

    # create the logging directory if it does not exist
    logfile = None
@@ -725,7 +792,10 @@ def print0(*args, **kwargs):
    # init the model
    assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
    assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
-    model = LLaMA.from_pretrained_llama3(args.ckpt_dir, args.tokenizer_path)
+    if args.use_hf:
+        model = LLaMA.from_pretrained_llama3_hf(args.model)
+    else:  # use Meta's checkpoint
+        model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)
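Illustrative invocations (editorial note, using only the flags defined in the argparse section above). Note that the two asserts above this branch still require --ckpt_dir and --tokenizer_path to point at existing paths even when the HF path is selected:

# HF path (default): weights and tokenizer are pulled from the Hub
python train_gpt2.py --use_hf 1 --model meta-llama/Meta-Llama-3.1-8B --ckpt_dir <local-llama3-dir> --tokenizer_path <tokenizer.model>
# Meta path: weights and tokenizer are read from a local Meta download
python train_gpt2.py --use_hf 0 --ckpt_dir <local-llama3-dir> --tokenizer_path <tokenizer.model>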

    model.train()
    if args.compile:
@@ -753,7 +823,7 @@ def print0(*args, **kwargs):
        logits, loss = model(x, y)
        loss.backward()
        # save model params, in bfloat16
-        model_to_size = {"llama3.1": "8B"}
+        model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"}
        model_size_str = model_to_size[args.model]  # e.g. "8B"
        write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16")
        # save x, y, logits, loss, and parameter gradients, for debugging C
@@ -824,7 +894,10 @@ def get_lr(it):
            and master_process:
            model.eval()
            prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts']
-            prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+            if args.use_hf:
+                prompt_tokens = [model.tokenizer(x).input_ids for x in prompts]
+            else:  # Meta
+                prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
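One more editorial aside: a quick consistency check one might run between the two tokenizer paths. It assumes the HF tokenizer for this model prepends <|begin_of_text|> (id 128000) by default, which is what makes it line up with encode(..., bos=True, eos=False) on the Meta side; hf_tok and meta_tok are hypothetical handles to the two loaded tokenizers:

prompt = "Once upon a time"
hf_ids = hf_tok(prompt).input_ids                        # HF path, as used above (hypothetical handle)
meta_ids = meta_tok.encode(prompt, bos=True, eos=False)  # Meta path (hypothetical handle)
assert hf_ids == meta_ids, "the two tokenizer paths should agree on base-model prompts"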

            generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False)
            results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens]
