diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index cfc8ccc8..01181633 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -13,13 +13,10 @@ #include #include #include -#include -#include +#include #include #include -#include "errors.hpp" - namespace mscclpp { #define MSCCLPP_UNIQUE_ID_BYTES 128 @@ -303,23 +300,6 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) { return TransportFlags(transport1) ^ transport2; } -/// Get the number of available InfiniBand devices. -/// -/// @return The number of available InfiniBand devices. -int getIBDeviceCount(); - -/// Get the name of the InfiniBand device associated with the specified transport. -/// -/// @param ibTransport The InfiniBand transport to get the device name for. -/// @return The name of the InfiniBand device associated with the specified transport. -std::string getIBDeviceName(Transport ibTransport); - -/// Get the InfiniBand transport associated with the specified device name. -/// -/// @param ibDeviceName The name of the InfiniBand device to get the transport for. -/// @return The InfiniBand transport associated with the specified device name. -Transport getIBTransportByDeviceName(const std::string& ibDeviceName); - class Context; class Connection; diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index 5d76983e..6848688e 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -6,6 +6,7 @@ #include #include +#include #include namespace mscclpp { diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index 8b7d8b19..30655300 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -193,6 +193,9 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind); } +/// Check if NVLink SHARP (NVLS) is supported. +/// +/// @return True if NVLink SHARP (NVLS) is supported, false otherwise. bool isNvlsSupported(); /// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index c8ef3d27..17721155 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_UTILS_HPP_ #include +#include #include namespace mscclpp { @@ -37,6 +38,23 @@ struct ScopedTimer : public Timer { std::string getHostName(int maxlen, const char delim); +/// Get the number of available InfiniBand devices. +/// +/// @return The number of available InfiniBand devices. +int getIBDeviceCount(); + +/// Get the name of the InfiniBand device associated with the specified transport. +/// +/// @param ibTransport The InfiniBand transport to get the device name for. +/// @return The name of the InfiniBand device associated with the specified transport. +std::string getIBDeviceName(Transport ibTransport); + +/// Get the InfiniBand transport associated with the specified device name. +/// +/// @param ibDeviceName The name of the InfiniBand device to get the transport for. +/// @return The InfiniBand transport associated with the specified device name. +Transport getIBTransportByDeviceName(const std::string& ibDeviceName); + } // namespace mscclpp #endif // MSCCLPP_UTILS_HPP_ diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 410ad246..c66624ca 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -4,6 +4,13 @@ import os as _os from ._mscclpp import ( + ErrorCode, + BaseError, + Error, + SysError, + CudaError, + CuError, + IbError, Communicator, Connection, connect_nvls_collective, diff --git a/python/mscclpp/error_py.cpp b/python/mscclpp/error_py.cpp index af78ac88..ff532d10 100644 --- a/python/mscclpp/error_py.cpp +++ b/python/mscclpp/error_py.cpp @@ -9,7 +9,19 @@ namespace nb = nanobind; using namespace mscclpp; -void register_error(nb::module_& m) { +#define REGISTER_EXCEPTION_TRANSLATOR(name_) \ + nb::register_exception_translator( \ + [](const std::exception_ptr &p, void *payload) { \ + try { \ + std::rethrow_exception(p); \ + } catch (const name_ &e) { \ + PyErr_SetObject(reinterpret_cast(payload), \ + PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \ + } \ + }, \ + m.attr(#name_).ptr()); + +void register_error(nb::module_ &m) { nb::enum_(m, "ErrorCode") .value("SystemError", ErrorCode::SystemError) .value("InternalError", ErrorCode::InternalError) @@ -19,24 +31,21 @@ void register_error(nb::module_& m) { .value("Aborted", ErrorCode::Aborted) .value("ExecutorError", ErrorCode::ExecutorError); - nb::class_(m, "BaseError") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")) - .def("get_error_code", &BaseError::getErrorCode) - .def("what", &BaseError::what); + nb::exception(m, "BaseError"); + REGISTER_EXCEPTION_TRANSLATOR(BaseError); - nb::class_(m, "Error") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")) - .def("get_error_code", &Error::getErrorCode); + nb::exception(m, "Error", m.attr("BaseError").ptr()); + REGISTER_EXCEPTION_TRANSLATOR(Error); - nb::class_(m, "SysError") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")); + nb::exception(m, "SysError", m.attr("BaseError").ptr()); + REGISTER_EXCEPTION_TRANSLATOR(SysError); - nb::class_(m, "CudaError") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")); + nb::exception(m, "CudaError", m.attr("BaseError").ptr()); + REGISTER_EXCEPTION_TRANSLATOR(CudaError); - nb::class_(m, "CuError") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")); + nb::exception(m, "CuError", m.attr("BaseError").ptr()); + REGISTER_EXCEPTION_TRANSLATOR(CuError); - nb::class_(m, "IbError") - .def(nb::init(), nb::arg("message"), nb::arg("errorCode")); + nb::exception(m, "IbError", m.attr("BaseError").ptr()); + REGISTER_EXCEPTION_TRANSLATOR(IbError); } diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 976d7436..691a169a 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -12,6 +12,8 @@ import pytest from mscclpp import ( + ErrorCode, + Error, DataType, EndpointConfig, ExecutionPlan, @@ -44,7 +46,7 @@ def all_ranks_on_the_same_node(mpi_group: MpiGroup): @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("ifIpPortTrio", ["eth0:localhost:50000", ethernet_interface_name, ""]) +@pytest.mark.parametrize("ifIpPortTrio", [f"{ethernet_interface_name}:localhost:50000", ethernet_interface_name, ""]) def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str): if (ethernet_interface_name in ni.interfaces()) is False: pytest.skip(f"{ethernet_interface_name} is not an interface to use on this node") @@ -146,7 +148,12 @@ def create_group_and_connection(mpi_group: MpiGroup, transport: str): if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False: pytest.skip("cannot use nvlink/nvls for cross node") group = mscclpp_comm.CommGroup(mpi_group.comm) - connection = create_connection(group, transport) + try: + connection = create_connection(group, transport) + except Error as e: + if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage: + pytest.skip("IB not supported on this node") + raise return group, connection diff --git a/src/ib.cc b/src/ib.cc index d9d72d1a..107abc77 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -367,10 +367,10 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { - throw mscclpp::Error("No active port found", ErrorCode::InternalError); + throw mscclpp::Error("No active port found", ErrorCode::InvalidUsage); } } else if (!this->isPortUsable(port)) { - throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError); + throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); } qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend)); return qps.back().get();