Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Mac OSX on Apple Silicon #465

Merged
merged 21 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions smartsim/_core/_install/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class BuildError(Exception):

class Architecture(enum.Enum):
X64 = ("x86_64", "amd64")
ARM64 = ("arm64",)

@classmethod
def from_str(cls, string: str, /) -> "Architecture":
Expand Down Expand Up @@ -438,6 +439,8 @@ def fail_to_format(reason: str) -> BuildError: # pragma: no cover
raise fail_to_format(f"Unknown operating system: {self._os}")
if self._architecture == Architecture.X64:
arch = "x64"
elif self._architecture == Architecture.ARM64:
arch = "arm64v8"
else: # pragma: no cover
raise fail_to_format(f"Unknown architecture: {self._architecture}")
return self.rai_build_path / f"deps/{os_}-{arch}-{device}"
Expand All @@ -454,7 +457,7 @@ def _get_deps_to_fetch_for(
# is used as script in the SmartSim `setup.py`.
fetchable_deps: t.Sequence[t.Tuple[bool, _RAIBuildDependency]] = (
(True, _DLPackRepository("v0.5_RAI")),
(self.fetch_torch, _PTArchive(os_, device, "2.0.1")),
(self.fetch_torch, choose_PT_variant(os_, arch, device, "2.0.1")),
(self.fetch_tf, _TFArchive(os_, arch, device, "2.13.1")),
(self.fetch_onnx, _ORTArchive(os_, device, "1.16.3")),
)
Expand Down Expand Up @@ -760,31 +763,13 @@ def _extract_download(
zip_file.extractall(target)


@t.final
@dataclass(frozen=True)
class _PTArchive(_WebZip, _RAIBuildDependency):
os_: OperatingSystem
architecture: Architecture
device: TDeviceStr
version: str

@property
def url(self) -> str:
if self.os_ == OperatingSystem.LINUX:
if self.device == "gpu":
pt_build = "cu117"
else:
pt_build = "cpu"
# pylint: disable-next=line-too-long
libtorch_arch = f"libtorch-cxx11-abi-shared-without-deps-{self.version}%2B{pt_build}.zip"
elif self.os_ == OperatingSystem.DARWIN:
if self.device == "gpu":
raise BuildError("RedisAI does not currently support GPU on Macos")
pt_build = "cpu"
libtorch_arch = f"libtorch-macos-{self.version}.zip"
else:
raise BuildError(f"Unexpected OS for the PT Archive: {self.os_}")
return f"https://download.pytorch.org/libtorch/{pt_build}/{libtorch_arch}"

@property
def __rai_dependency_name__(self) -> str:
return f"libtorch@{self.url}"
Expand All @@ -797,6 +782,49 @@ def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
return target


@t.final
class _PTArchive_Linux(_PTArchive):
@property
def url(self) -> str:
if self.device == "gpu":
pt_build = "cu117"
else:
pt_build = "cpu"
# pylint: disable-next=line-too-long
libtorch_archive= f"libtorch-cxx11-abi-shared-without-deps-{self.version}%2B{pt_build}.zip"
return f"https://download.pytorch.org/libtorch/{pt_build}/{libtorch_archive}"


@t.final
class _PTArchive_MacOSX(_PTArchive):
@property
def url(self) -> str:
if self.device == "gpu":
raise BuildError("RedisAI does not currently support GPU on Mac OSX")
if self.architecture == Architecture.X64:
libtorch_archive= f"libtorch-macos-{self.version}.zip"
return f"https://download.pytorch.org/libtorch/{pt_build}/{libtorch_archive}"
elif self.architecture == Architecture.ARM64:
libtorch_archive = f"libtorch-macos-arm64-{self.version}.zip"
root_url = "https://github.com/CrayLabs/ml_lib_builder/releases/download/v0.1/"
out = f"{root_url}/{libtorch_archive}"
return out


def choose_PT_variant(
os_: OperatingSystem,
device: TDeviceStr,
arch: Architecture,
version: str
) -> t.Union[_PTArchive_Linux, _PTArchive_MacOSX]:
if os_ == OperatingSystem.DARWIN:
return _PTArchive_MacOSX(os_, device, arch, version)
elif os_ == OperatingSystem.LINUX:
return _PTArchive_Linux(os_, device, arch, version)
else:
raise BuildError(f"Unsupported OS for pyTorch: {os_}")


@t.final
@dataclass(frozen=True)
class _TFArchive(_WebTGZ, _RAIBuildDependency):
Expand Down
53 changes: 52 additions & 1 deletion tests/install/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
import pytest

import smartsim._core._install.builder as build
from smartsim._core._install.buildenv import RedisAIVersion

# The tests in this file belong to the group_a group
pytestmark = pytest.mark.group_a

RAI_versions = RedisAIVersion("1.2.7")

for_each_device = pytest.mark.parametrize("device", ["cpu", "gpu"])

Expand Down Expand Up @@ -66,7 +68,7 @@ def test_rai_builder_raises_on_unsupported_op_sys(monkeypatch, mock_os):
"mock_arch",
[
pytest.param(arch_, id=f"arch='{arch_}'")
for arch_ in ("i386", "i686", "i86pc", "aarch64", "arm64", "armv7l", "")
for arch_ in ("i386", "i686", "i86pc", "aarch64", "armv7l", "")
],
)
def test_rai_builder_raises_on_unsupported_architecture(monkeypatch, mock_arch):
Expand Down Expand Up @@ -205,3 +207,52 @@ def _some_long_io_op(_):
build._threaded_map(_some_long_io_op, [])
end = time.time()
assert end - start < sleep_duration

def test_correct_pt_variant_os():
# Check that all Linux variants return Linux
for linux_variant in build.OperatingSystem.LINUX.value:
os_ = build.OperatingSystem.from_str(linux_variant)
assert isinstance(
build.choose_PT_variant(os_, "x86_64", "cpu", RAI_versions.torch),
build._PTArchive_Linux
)
# Check that ARM64 and X86_64 Mac OSX return the Mac variant
all_archs = (build.Architecture.ARM64, build.Architecture.X64)
for arch in all_archs:
os_ = build.OperatingSystem.DARWIN
assert isinstance(
build.choose_PT_variant(os_, arch, "cpu", RAI_versions.torch),
build._PTArchive_MacOSX
)

def test_PTArchive_MacOSX_url():
os_ = build.OperatingSystem.DARWIN
arch = build.Architecture.X64
pt_version = RAI_versions.torch

pt_linux_cpu = build._PTArchive_Linux(
os_,
build.Architecture.X64,
"cpu",
pt_version
)
x64_prefix = "https://download.pytorch.org/libtorch/"
assert x64_prefix in pt_linux_cpu.url

pt_macosx_cpu = build._PTArchive_MacOSX(
os_,
build.Architecture.ARM64,
"cpu",
pt_version
)
arm64_prefix = "https://github.com/CrayLabs/ml_lib_builder/releases/download/"
assert arm64_prefix in pt_macosx_cpu.url

def test_PTArchive_MacOSX_gpu_error():
with pytest.raises(build.BuildError, match="support GPU on Mac OSX"):
build._PTArchive_MacOSX(
build.OperatingSystem.DARWIN,
build.Architecture.ARM64,
"gpu",
RAI_versions.torch
).url
Loading