Skip to content

Commit

Permalink
style: format
Browse files (browse the repository at this point in the history)
  • Loading branch information
brosoul committed Nov 7, 2024
1 parent 12ed3a5 commit 3854e14
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 97 deletions.
10 changes: 5 additions & 5 deletions vllm/model_executor/model_loader/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1175,15 +1175,14 @@ def __init__(self, load_config: LoadConfig):
else:
self.stream_loader_config = StreamConfig(
**load_config.model_loader_extra_config)

self.stream_model = self.stream_loader_config.construct_stream_model()

def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):
self.stream_loader_config.verify_with_model_config(model_config)
self.stream_loader_config.verify_with_parallel_config(parallel_config)


def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
Expand All @@ -1197,9 +1196,10 @@ def load_model(self, *, model_config: ModelConfig,
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)

model.load_weights(self.stream_model.get_weights_iterator(device_config.device))


model.load_weights(
self.stream_model.get_weights_iterator(device_config.device))

return model.eval()

def download_model(self, model_config: ModelConfig) -> None:
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/model_loader/stream/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,6 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.


SUPPORTED_STREAM_STORAGE = ("s3", "tos", "local")
DOWNLOAD_CACHE_DIR = ".cache"
39 changes: 20 additions & 19 deletions vllm/model_executor/model_loader/stream/file.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,6 +29,7 @@


class LoadFile:

def __init__(self, file_source: str) -> None:
self.file_source = file_source

Expand All @@ -46,6 +47,7 @@ def download(self, target_dir):


class LocalFile(LoadFile):

def __init__(self, file: Union[str, Path]) -> None:
if not Path(file).exists():
raise ValueError(f"file {file} not exist")
Expand All @@ -55,9 +57,8 @@ def __init__(self, file: Union[str, Path]) -> None:

def load_whole_file(self, num_threads: int = 1):
if num_threads != 1:
logger.warning(
"num_threads %s is not supported for local file.", num_threads
)
logger.warning("num_threads %s is not supported for local file.",
num_threads)

tensor_bytes = np.memmap(
self.file,
Expand All @@ -80,6 +81,7 @@ def load_to_buffer(self, offset: int, count: int):


class RemoteFile(LoadFile):

def __init__(self, file: str, file_source: str) -> None:
self.file = file
super().__init__(file_source=file_source)
Expand All @@ -93,6 +95,7 @@ def download_file(self, target_dir: str, num_threads: int = 1):


class S3File(RemoteFile):

def __init__(
self,
scheme: str,
Expand Down Expand Up @@ -121,15 +124,15 @@ def __init__(
self.s3_client.head_object(Bucket=bucket_name, Key=bucket_path)
except Exception as e:
raise ValueError(
f"S3 bucket path {bucket_path} not exist for {e}."
) from e
f"S3 bucket path {bucket_path} not exist for {e}.") from e

file = scheme + "://" + bucket_name + "/" + bucket_path
super().__init__(file=file, file_source=scheme)

@classmethod
def from_uri(cls, file_uri: str, **kwargs):
scheme, bucket_name, bucket_path = _parse_bucket_info_from_uri(file_uri)
scheme, bucket_name, bucket_path = _parse_bucket_info_from_uri(
file_uri)
cls(scheme, bucket_name, bucket_path, **kwargs)

def load_whole_file(self, num_threads: int = 1):
Expand All @@ -150,9 +153,9 @@ def load_whole_file(self, num_threads: int = 1):

def load_to_bytes(self, offset: int, count: int):
range_header = f"bytes={offset}-{offset+count-1}"
resp = self.s3_client.get_object(
Bucket=self.bucket_name, Key=self.bucket_path, Range=range_header
)
resp = self.s3_client.get_object(Bucket=self.bucket_name,
Key=self.bucket_path,
Range=range_header)
return read_to_bytes_io(resp.get("Body"))

def download_file(
Expand All @@ -162,13 +165,11 @@ def download_file(
force_download: bool = False,
):
try:
meta_data = self.s3_client.head_object(
Bucket=self.bucket_name, Key=self.bucket_path
)
meta_data = self.s3_client.head_object(Bucket=self.bucket_name,
Key=self.bucket_path)
except Exception as e:
raise ValueError(
"S3 bucket path %s not exist for %s.", self.bucket_path, e
) from e
raise ValueError("S3 bucket path %s not exist for %s.",
self.bucket_path, e) from e

# ensure target dir exist
target_path = Path(target_dir)
Expand All @@ -181,10 +182,10 @@ def download_file(
etag = meta_data.get("ETag", "")
file_size = meta_data.get("ContentLength", 0)

meta_data_file = meta_file(local_path=target_path, file_name=_file_name)
if not need_to_download(
local_file, meta_data_file, file_size, etag, force_download
):
meta_data_file = meta_file(local_path=target_path,
file_name=_file_name)
if not need_to_download(local_file, meta_data_file, file_size, etag,
force_download):
logger.info("file `%s` already exist.", self.bucket_path)
return

Expand Down
53 changes: 24 additions & 29 deletions vllm/model_executor/model_loader/stream/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -52,13 +52,13 @@ def get_safetensors_metas(file: LoadFile):
dtype=tensor_meta["dtype"],
shape=tensor_meta["shape"],
data_offsets=tensor_meta["data_offsets"],
)
)
))
# Ensure tensors chunks could be split continuously
return sorted(metas, key=lambda obj: obj.real_offset)


class StreamLoader:

def __init__(
self,
file: LoadFile,
Expand Down Expand Up @@ -89,31 +89,30 @@ def _tensors_reader(
# TODO use stream nonblocking IO
for tensor_meta in tensor_metas:
tensor_buffer = self.file.load_to_buffer(
offset=tensor_meta.real_offset, count=tensor_meta.count
)
offset=tensor_meta.real_offset, count=tensor_meta.count)
tensor = torch.frombuffer(
tensor_buffer, dtype=tensor_meta.dtype
).view(tensor_meta.shape)
tensor_buffer, dtype=tensor_meta.dtype).view(tensor_meta.shape)
if is_cuda:
tensor = tensor.to(device, non_blocking=True)
tensor_meta.set_tensor(tensor)
transfer_out_queue.put(tensor_meta)

def get_weights_iterator(
self, device: Union[torch.device, str] = "cpu"
self,
device: Union[torch.device, str] = "cpu"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
tensors_per_reader: List[Tuple[TensorMeta, ...]] = (
split_continue_tensors(self.tensors_metas, self.num_thread)
)
tensors_per_reader: List[Tuple[TensorMeta,
...]] = (split_continue_tensors(
self.tensors_metas,
self.num_thread))

effective_num_readers = len(tensors_per_reader)
self._reader_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=effective_num_readers,
thread_name_prefix="SafetensorsReader",
)
transfer_out_queue: queue.SimpleQueue[Union[Exception, TensorMeta]] = (
queue.SimpleQueue()
) # type: ignore
queue.SimpleQueue()) # type: ignore
futures: List[concurrent.futures.Future] = []

barrier = threading.Barrier(effective_num_readers)
Expand All @@ -131,28 +130,27 @@ def get_weights_iterator(

try:
for _ in range(len(self.tensors_metas)):
tensor_meta: Union[TensorMeta, Exception] = (
transfer_out_queue.get(timeout=3600)
)
tensor_meta: Union[TensorMeta,
Exception] = (transfer_out_queue.get(
timeout=3600))
if isinstance(tensor_meta, Exception):
raise tensor_meta
yield tensor_meta.name, tensor_meta.tensor
except BaseException:
raise

def get_weights_iterator_wo_threads(
self, device: Union[torch.device, str] = "cpu"
self,
device: Union[torch.device, str] = "cpu"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
device = torch.device(device)
is_cuda = device.type == "cuda"
# TODO use stream nonblocking IO
for tensor_meta in self.tensors_metas:
tensor_buffer = self.file.load_to_bytes(
offset=tensor_meta.real_offset, count=tensor_meta.count
)
offset=tensor_meta.real_offset, count=tensor_meta.count)
tensor = torch.frombuffer(
tensor_buffer, dtype=tensor_meta.dtype
).view(tensor_meta.shape)
tensor_buffer, dtype=tensor_meta.dtype).view(tensor_meta.shape)

if is_cuda:
tensor = tensor.to(device, non_blocking=True)
Expand All @@ -173,8 +171,7 @@ class StreamModel:

def __post_init__(self):
scheme, bucket_name, bucket_path = _parse_bucket_info_from_uri(
self.model_uri
)
self.model_uri)
if not bucket_path.endswith("/"):
bucket_path += "/"
self.model_source_type = scheme
Expand All @@ -198,8 +195,7 @@ def __post_init__(self):
num_threads=self.num_threads,
)
objects_out = self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Delimiter="/", Prefix=bucket_path
)
Bucket=self.bucket_name, Delimiter="/", Prefix=bucket_path)
files = [
str(content.get("Key"))
for content in objects_out.get("Contents", [])
Expand All @@ -215,9 +211,9 @@ def __post_init__(self):
if len(self.safetensors_files) == 0:
raise ValueError(f"no safetensors file found in {self.model_uri}")

def download_config(
self, target_dir: str, force_download: bool = False
) -> Path:
def download_config(self,
target_dir: str,
force_download: bool = False) -> Path:
if self.model_source_type == "local":
logger.info("local config no need to download")
return Path(self.model_uri)
Expand Down Expand Up @@ -258,6 +254,5 @@ def get_weights_iterator(self, device: Union[torch.device, str] = "cpu"):
use_direct_io=self.use_direct_io,
)
for name, tensor in safetensors_loader.get_weights_iterator(
device=device
):
device=device):
yield name, tensor
Loading

0 comments on commit 3854e14

Please sign in to comment.