Importing core training config stuff from dolma refine
jakep-allenai committed Sep 19, 2024
1 parent bab32aa commit 7d2c447
Showing 11 changed files with 1,517 additions and 1 deletion.
Empty file added pdelfin/train/core/__init__.py
96 changes: 96 additions & 0 deletions pdelfin/train/core/adapters.py
@@ -0,0 +1,96 @@
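"""Model-loading helpers: resolve a ModelConfig to a transformers model and merge any cached LoRA adapter into its base weights."""
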
import json
from logging import Logger
from typing import Optional, Type

import smart_open
import torch
from peft.peft_model import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelWithLMHead,
    AutoTokenizer,
)

from .config import ModelConfig
from .loggers import get_logger
from .paths import cached_path, exists, get_cache_dir, join_path, resource_to_filename

__all__ = ["load_model", "cache_merged_model"]


def get_model_cls(config: ModelConfig) -> Type[AutoModelWithLMHead]:
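    """Return the transformers Auto class matching the configured architecture."""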
    if config.arch == "seq2seq":
        return AutoModelForSeq2SeqLM  # pyright: ignore
    elif config.arch == "causal" or config.arch == "vllm":
        return AutoModelForCausalLM  # pyright: ignore
    else:
        raise ValueError(f"Unsupported model architecture: {config.arch}")


def get_adapter_config(config: ModelConfig) -> dict:
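    """Load adapter_config.json for ``config.name_or_path`` if present; return an empty dict otherwise."""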
    local_path = cached_path(config.name_or_path)
    if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
        with smart_open.open(adapter_config_path, "rt", encoding="utf-8") as f:
            return json.load(f)
    return {}


def load_model(config: ModelConfig, logger: Optional[Logger] = None) -> AutoModelWithLMHead:
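    """Load the configured model; if the path contains a LoRA adapter, load its base model and merge the adapter weights."""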
    logger = logger or get_logger(__file__, level="INFO")

    logger.info(f"Loading model from {config.name_or_path}")
    local_path = cached_path(config.name_or_path)
    if local_path != config.name_or_path:
        logger.info(f"Model cached at {local_path}")

    if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
        logger.info(f"Loading LoRA adapter from {adapter_config_path}")
        with smart_open.open(adapter_config_path) as f:
            adapter_config = json.load(f)
        base_model_name_or_path = adapter_config["base_model_name_or_path"]
        enable_lora = True
    else:
        base_model_name_or_path = local_path
        enable_lora = False

    model = get_model_cls(config).from_pretrained(
        base_model_name_or_path,
        device_map="auto",
        trust_remote_code=config.trust_remote_code,
        # low_cpu_mem_usage=model_config.low_cpu_mem_usage,
        use_flash_attention_2=True if config.use_flash_attn else False,
        revision=config.model_revision,
        torch_dtype=torch.bfloat16 if config.use_flash_attn else getattr(torch, config.dtype),
    )
    logger.info(f"Successfully loaded base model from {base_model_name_or_path}")

    if enable_lora:
        peft_model = PeftModel.from_pretrained(model, local_path)
        model = peft_model.merge_and_unload()
        logger.info(f"Successfully loaded LoRA adapter from base model: {base_model_name_or_path}")

    return model


def cache_merged_model(config: ModelConfig, logger: Optional[Logger] = None) -> str:
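    """Return a local path to the model with any LoRA adapter merged in, saving the merged weights to the cache on first use."""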
    logger = logger or get_logger(__file__, level="INFO")

    base_local_path = cached_path(config.name_or_path)
    adapter_config = get_adapter_config(config)
    if not adapter_config:
        logger.info("No adapter config found; using base model")
        return base_local_path

    local_fn = resource_to_filename(json.dumps({"adapter": adapter_config, "model": config.name_or_path}))
    merged_local_path = f"{get_cache_dir()}/{local_fn}"

    if not exists(merged_local_path):
        model = load_model(config=config, logger=logger)
        tokenizer = AutoTokenizer.from_pretrained(base_local_path)

        logger.info(f"Saving merged model to {merged_local_path}")
        model.save_pretrained(merged_local_path)
        tokenizer.save_pretrained(merged_local_path)

    return merged_local_path
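
For orientation, here is a minimal usage sketch of the two entry points added in this file. The ModelConfig field names (arch, name_or_path, trust_remote_code, use_flash_attn, model_revision, dtype) are the ones the code above reads; the constructor call and the example values are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch; the ModelConfig construction below is assumed, not part of this commit.
from pdelfin.train.core.adapters import cache_merged_model, load_model
from pdelfin.train.core.config import ModelConfig

config = ModelConfig(
    arch="causal",                      # "seq2seq", "causal", or "vllm"
    name_or_path="my-org/my-lora-run",  # local dir, HF hub id, or remote path (whatever cached_path resolves)
    trust_remote_code=False,
    use_flash_attn=True,                # also forces torch.bfloat16 in load_model
    model_revision="main",
    dtype="float16",                    # only consulted when use_flash_attn is False
)

# Load the model, merging the LoRA adapter into its base if adapter_config.json is present ...
model = load_model(config)

# ... or materialize the merged weights on disk and get back the cached directory,
# e.g. to hand off to an inference engine that expects a plain checkpoint.
merged_dir = cache_merged_model(config)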