Skip to content

Commit

Permalink
Add Beaker.session() context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jun 9, 2022
1 parent 0782c04 commit 739869e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `Beaker.session()` context manager for improving performance when calling a series of
client methods in a row.

## [v1.3.0](https://github.com/allenai/beaker-py/releases/tag/v1.3.0) - 2022-05-31

### Added
Expand Down
44 changes: 43 additions & 1 deletion beaker/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Optional
from contextlib import contextmanager
from typing import Generator, Optional

import docker
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from .config import Config
from .data_model import *
Expand Down Expand Up @@ -36,13 +40,18 @@ class Beaker:
"""

RECOVERABLE_SERVER_ERROR_CODES = (502, 503, 504)
MAX_RETRIES = 5
API_VERSION = "v3"

def __init__(self, config: Config, check_for_upgrades: bool = True):
# See if there's a newer version, and if so, suggest that the user upgrades.
if check_for_upgrades:
self._check_for_upgrades()

self._config = config
self._docker: Optional[docker.DockerClient] = None
self._session: Optional[requests.Session] = None

# Initialize service clients:
self._account = AccountClient(self)
Expand Down Expand Up @@ -126,6 +135,39 @@ def from_env(cls, check_for_upgrades: bool = True, **overrides) -> "Beaker":
"""
return cls(Config.from_env(**overrides), check_for_upgrades=check_for_upgrades)

def _make_session(self) -> requests.Session:
session = requests.Session()
retries = Retry(
total=self.MAX_RETRIES,
backoff_factor=1,
status_forcelist=self.RECOVERABLE_SERVER_ERROR_CODES,
)
session.mount("https://", HTTPAdapter(max_retries=retries))
return session

@contextmanager
def session(self) -> Generator[None, None, None]:
"""
A context manager that forces the Beaker client to reuse a single :class:`requests.Session`
for all HTTP requests to the Beaker server.
This can improve performance when calling a series of a client methods in a row.
:examples:
>>> with beaker.session():
... n_images = len(beaker.workspace.images())
... n_datasets = len(beaker.workspace.datasets())
"""
session = self._make_session()
try:
self._session = session
yield None
finally:
self._session = None
session.close()

@property
def config(self) -> Config:
"""
Expand Down
30 changes: 9 additions & 21 deletions beaker/services/service_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import io
import json
import urllib.parse
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import docker
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from beaker.config import Config
from beaker.data_model import *
Expand All @@ -19,13 +16,9 @@


class ServiceClient:
RECOVERABLE_SERVER_ERROR_CODES = (502, 503, 504)
MAX_RETRIES = 5
API_VERSION = "v3"

def __init__(self, beaker: "Beaker"):
self.beaker = beaker
self._base_url = f"{self.config.agent_address}/api/{self.API_VERSION}"
self._base_url = f"{self.config.agent_address}/api/{self.beaker.API_VERSION}"

@property
def config(self) -> Config:
Expand All @@ -35,17 +28,6 @@ def config(self) -> Config:
def docker(self) -> docker.DockerClient:
return self.beaker.docker

@contextmanager
def _session_with_backoff(self) -> Generator[requests.Session, None, None]:
session = requests.Session()
retries = Retry(
total=self.MAX_RETRIES,
backoff_factor=1,
status_forcelist=self.RECOVERABLE_SERVER_ERROR_CODES,
)
session.mount("https://", HTTPAdapter(max_retries=retries))
yield session

def request(
self,
resource: str,
Expand All @@ -58,7 +40,7 @@ def request(
base_url: Optional[str] = None,
stream: bool = False,
) -> requests.Response:
with self._session_with_backoff() as session:
def make_request(session: requests.Session) -> requests.Response:
# Build URL.
url = f"{base_url or self._base_url}/{resource}"
if query is not None:
Expand Down Expand Up @@ -110,6 +92,12 @@ def request(

return response

if self.beaker._session is not None:
return make_request(self.beaker._session)
else:
with self.beaker._make_session() as session:
return make_request(session)

def resolve_cluster_name(self, cluster_name: str) -> str:
if "/" not in cluster_name:
if self.config.default_org is not None:
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"docker": ("https://docker-py.readthedocs.io/en/stable/", None),
"requests": ("https://requests.readthedocs.io/en/stable/", None),
}

# By default, sort documented members by type within classes and modules.
Expand Down

0 comments on commit 739869e

Please sign in to comment.