Skip to content

Commit

Permalink
Use create_tmp_dir when creating a temporary directory (mlflow#10462)
Browse files: browse the repository at this point in the history
Signed-off-by: harupy <[email protected]>
Signed-off-by: mlflow-automation <[email protected]>
Co-authored-by: mlflow-automation <[email protected]>
Signed-off-by: swathi <[email protected]>
  • Loading branch information
2 people authored and KonakanchiSwathi committed Nov 29, 2023
1 parent 7a1cfe4 commit 6b39495
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mlflow/data/http_dataset_source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import posixpath
import re
import tempfile
from typing import Any, Dict
from urllib.parse import urlparse

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.file_utils import create_tmp_dir
from mlflow.utils.rest_utils import augmented_raise_for_status, cloud_storage_http_request


Expand Down Expand Up @@ -63,7 +63,7 @@ def load(self, dst_path=None) -> str:
basename = "dataset_source"

if dst_path is None:
dst_path = tempfile.mkdtemp()
dst_path = create_tmp_dir()

dst_path = os.path.join(dst_path, basename)
with open(dst_path, "wb") as f:
Expand Down
5 changes: 2 additions & 3 deletions mlflow/store/artifact/artifact_repo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
import posixpath
import tempfile
from abc import ABCMeta, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed

Expand All @@ -10,7 +9,7 @@
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
from mlflow.utils.annotations import developer_stable
from mlflow.utils.file_utils import ArtifactProgressBar
from mlflow.utils.file_utils import ArtifactProgressBar, create_tmp_dir
from mlflow.utils.validation import bad_path_message, path_not_unique

# Constants used to determine max level of parallelism to use while uploading/downloading artifacts.
Expand Down Expand Up @@ -172,7 +171,7 @@ def download_artifacts(self, artifact_path, dst_path=None):
error_code=INVALID_PARAMETER_VALUE,
)
else:
dst_path = tempfile.mkdtemp()
dst_path = create_tmp_dir()

def _download_file(src_artifact_path, dst_local_dir_path):
dst_local_file_path = self._create_download_destination(
Expand Down
24 changes: 16 additions & 8 deletions mlflow/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,19 +833,27 @@ def _handle_readonly_on_windows(func, path, exc_info):
func(path)


def create_tmp_dir():
def _get_tmp_dir():
from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime

if is_in_databricks_runtime() and get_repl_id() is not None:
if is_in_databricks_runtime():
try:
repl_local_tmp_dir = _get_dbutils().entry_point.getReplLocalTempDir()
return _get_dbutils().entry_point.getReplLocalTempDir()
except Exception:
repl_local_tmp_dir = os.path.join("/tmp", "repl_tmp_data", get_repl_id())
pass

os.makedirs(repl_local_tmp_dir, exist_ok=True)
return tempfile.mkdtemp(dir=repl_local_tmp_dir)
else:
return tempfile.mkdtemp()
if repl_id := get_repl_id():
return os.path.join("/tmp", "repl_tmp_data", repl_id)

return None


def create_tmp_dir():
if directory := _get_tmp_dir():
os.makedirs(directory, exist_ok=True)
return tempfile.mkdtemp(dir=directory)

return tempfile.mkdtemp()


@cache_return_value_per_process
Expand Down

0 comments on commit 6b39495

Please sign in to comment.