diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 53a507f2d79a..89548b74866b 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -148,8 +148,20 @@ def get_cuda_version(cuda_path): with open(version_file_path) as f: version_str = f.readline().replace("\n", "").replace("\r", "") return float(version_str.split(" ")[2][:2]) - except: - raise RuntimeError("Cannot read cuda version file") + except FileNotFoundError: + pass + + cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + out = py_str(out) + if proc.returncode == 0: + release_line = [l for l in out.split("\n") if "release" in l][0] + release_fields = [s.strip() for s in release_line.split(",")] + release_version = [f[1:] for f in release_fields if f.startswith("V")][0] + major_minor = ".".join(release_version.split(".")[:2]) + return float(major_minor) + raise RuntimeError("Cannot read cuda version file") @tvm._ffi.register_func("tvm_callback_libdevice_path") @@ -174,7 +186,7 @@ def find_libdevice_path(arch): selected_ver = 0 selected_path = None cuda_ver = get_cuda_version(cuda_path) - if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0): + if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1): path = os.path.join(lib_path, "libdevice.10.bc") else: for fn in os.listdir(lib_path): @@ -219,6 +231,7 @@ def parse_compute_version(compute_version): minor = int(split_ver[1]) return major, minor except (IndexError, ValueError) as err: + # pylint: disable=raise-missing-from raise RuntimeError("Compute version parsing error: " + str(err))