Skip to content

Commit df1a7e0

Browse files
stainless-app[bot]stainless-bot
authored andcommitted
feat(api): update via SDK Studio (#8)
1 parent a29cd55 commit df1a7e0

File tree

4 files changed

+109
-13
lines changed

4 files changed

+109
-13
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ from luma_ai import LumaAI
3131

3232
client = LumaAI(
3333
auth_token="My Auth Token",
34+
# or 'production' | 'staging' | 'localhost'; defaults to "production".
35+
environment="production_api",
3436
)
3537

3638
generation = client.generations.create(
@@ -49,6 +51,8 @@ from luma_ai import AsyncLumaAI
4951

5052
client = AsyncLumaAI(
5153
auth_token="My Auth Token",
54+
# or 'production' | 'staging' | 'localhost'; defaults to "production".
55+
environment="production_api",
5256
)
5357

5458

src/luma_ai/__init__.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33
from . import types
44
from ._types import NOT_GIVEN, NoneType, NotGiven, Transport, ProxiesTypes
55
from ._utils import file_from_path
6-
from ._client import Client, LumaAI, Stream, Timeout, Transport, AsyncClient, AsyncLumaAI, AsyncStream, RequestOptions
6+
from ._client import (
7+
ENVIRONMENTS,
8+
Client,
9+
LumaAI,
10+
Stream,
11+
Timeout,
12+
Transport,
13+
AsyncClient,
14+
AsyncLumaAI,
15+
AsyncStream,
16+
RequestOptions,
17+
)
718
from ._models import BaseModel
819
from ._version import __title__, __version__
920
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
@@ -58,6 +69,7 @@
5869
"AsyncStream",
5970
"LumaAI",
6071
"AsyncLumaAI",
72+
"ENVIRONMENTS",
6173
"file_from_path",
6274
"BaseModel",
6375
"DEFAULT_TIMEOUT",

src/luma_ai/_client.py

+72-12
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import Any, Union, Mapping
7-
from typing_extensions import Self, override
6+
from typing import Any, Dict, Union, Mapping, cast
7+
from typing_extensions import Self, Literal, override
88

99
import httpx
1010

@@ -33,6 +33,7 @@
3333
)
3434

3535
__all__ = [
36+
"ENVIRONMENTS",
3637
"Timeout",
3738
"Transport",
3839
"ProxiesTypes",
@@ -44,6 +45,13 @@
4445
"AsyncClient",
4546
]
4647

48+
ENVIRONMENTS: Dict[str, str] = {
49+
"production": "http://api.lumalabs.ai/dream-machine/v1alpha",
50+
"production_api": "http://internal-api.virginia.labs.lumalabs.ai/dream-machine/v1alpha",
51+
"staging": "http://internal-api.sandbox.labs.lumalabs.ai/dream-machine/v1alpha",
52+
"localhost": "http://localhost:9600/dream-machine/v1alpha",
53+
}
54+
4755

4856
class LumaAI(SyncAPIClient):
4957
ping: resources.PingResource
@@ -54,11 +62,14 @@ class LumaAI(SyncAPIClient):
5462
# client options
5563
auth_token: str
5664

65+
_environment: Literal["production", "production_api", "staging", "localhost"] | NotGiven
66+
5767
def __init__(
5868
self,
5969
*,
6070
auth_token: str,
61-
base_url: str | httpx.URL | None = None,
71+
environment: Literal["production", "production_api", "staging", "localhost"] | NotGiven = NOT_GIVEN,
72+
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
6273
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
6374
max_retries: int = DEFAULT_MAX_RETRIES,
6475
default_headers: Mapping[str, str] | None = None,
@@ -80,10 +91,31 @@ def __init__(
8091
"""Construct a new synchronous luma_ai client instance."""
8192
self.auth_token = auth_token
8293

83-
if base_url is None:
84-
base_url = os.environ.get("LUMA_AI_BASE_URL")
85-
if base_url is None:
86-
base_url = f"http://api.lumalabs.ai/dream-machine/v1alpha"
94+
self._environment = environment
95+
96+
base_url_env = os.environ.get("LUMA_AI_BASE_URL")
97+
if is_given(base_url) and base_url is not None:
98+
# cast required because mypy doesn't understand the type narrowing
99+
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
100+
elif is_given(environment):
101+
if base_url_env and base_url is not None:
102+
raise ValueError(
103+
"Ambiguous URL; The `LUMA_AI_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
104+
)
105+
106+
try:
107+
base_url = ENVIRONMENTS[environment]
108+
except KeyError as exc:
109+
raise ValueError(f"Unknown environment: {environment}") from exc
110+
elif base_url_env is not None:
111+
base_url = base_url_env
112+
else:
113+
self._environment = environment = "production"
114+
115+
try:
116+
base_url = ENVIRONMENTS[environment]
117+
except KeyError as exc:
118+
raise ValueError(f"Unknown environment: {environment}") from exc
87119

88120
super().__init__(
89121
version=__version__,
@@ -125,6 +157,7 @@ def copy(
125157
self,
126158
*,
127159
auth_token: str | None = None,
160+
environment: Literal["production", "production_api", "staging", "localhost"] | None = None,
128161
base_url: str | httpx.URL | None = None,
129162
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
130163
http_client: httpx.Client | None = None,
@@ -160,6 +193,7 @@ def copy(
160193
return self.__class__(
161194
auth_token=auth_token or self.auth_token,
162195
base_url=base_url or self.base_url,
196+
environment=environment or self._environment,
163197
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
164198
http_client=http_client,
165199
max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -215,11 +249,14 @@ class AsyncLumaAI(AsyncAPIClient):
215249
# client options
216250
auth_token: str
217251

252+
_environment: Literal["production", "production_api", "staging", "localhost"] | NotGiven
253+
218254
def __init__(
219255
self,
220256
*,
221257
auth_token: str,
222-
base_url: str | httpx.URL | None = None,
258+
environment: Literal["production", "production_api", "staging", "localhost"] | NotGiven = NOT_GIVEN,
259+
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
223260
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
224261
max_retries: int = DEFAULT_MAX_RETRIES,
225262
default_headers: Mapping[str, str] | None = None,
@@ -241,10 +278,31 @@ def __init__(
241278
"""Construct a new async luma_ai client instance."""
242279
self.auth_token = auth_token
243280

244-
if base_url is None:
245-
base_url = os.environ.get("LUMA_AI_BASE_URL")
246-
if base_url is None:
247-
base_url = f"http://api.lumalabs.ai/dream-machine/v1alpha"
281+
self._environment = environment
282+
283+
base_url_env = os.environ.get("LUMA_AI_BASE_URL")
284+
if is_given(base_url) and base_url is not None:
285+
# cast required because mypy doesn't understand the type narrowing
286+
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
287+
elif is_given(environment):
288+
if base_url_env and base_url is not None:
289+
raise ValueError(
290+
"Ambiguous URL; The `LUMA_AI_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
291+
)
292+
293+
try:
294+
base_url = ENVIRONMENTS[environment]
295+
except KeyError as exc:
296+
raise ValueError(f"Unknown environment: {environment}") from exc
297+
elif base_url_env is not None:
298+
base_url = base_url_env
299+
else:
300+
self._environment = environment = "production"
301+
302+
try:
303+
base_url = ENVIRONMENTS[environment]
304+
except KeyError as exc:
305+
raise ValueError(f"Unknown environment: {environment}") from exc
248306

249307
super().__init__(
250308
version=__version__,
@@ -286,6 +344,7 @@ def copy(
286344
self,
287345
*,
288346
auth_token: str | None = None,
347+
environment: Literal["production", "production_api", "staging", "localhost"] | None = None,
289348
base_url: str | httpx.URL | None = None,
290349
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
291350
http_client: httpx.AsyncClient | None = None,
@@ -321,6 +380,7 @@ def copy(
321380
return self.__class__(
322381
auth_token=auth_token or self.auth_token,
323382
base_url=base_url or self.base_url,
383+
environment=environment or self._environment,
324384
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
325385
http_client=http_client,
326386
max_retries=max_retries if is_given(max_retries) else self.max_retries,

tests/test_client.py

+20
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,16 @@ def test_base_url_env(self) -> None:
552552
client = LumaAI(auth_token=auth_token, _strict_response_validation=True)
553553
assert client.base_url == "http://localhost:5000/from/env/"
554554

555+
# explicit environment arg requires explicitness
556+
with update_env(LUMA_AI_BASE_URL="http://localhost:5000/from/env"):
557+
with pytest.raises(ValueError, match=r"you must pass base_url=None"):
558+
LumaAI(auth_token=auth_token, _strict_response_validation=True, environment="production")
559+
560+
client = LumaAI(
561+
base_url=None, auth_token=auth_token, _strict_response_validation=True, environment="production"
562+
)
563+
assert str(client.base_url).startswith("http://api.lumalabs.ai/dream-machine/v1alpha")
564+
555565
@pytest.mark.parametrize(
556566
"client",
557567
[
@@ -1265,6 +1275,16 @@ def test_base_url_env(self) -> None:
12651275
client = AsyncLumaAI(auth_token=auth_token, _strict_response_validation=True)
12661276
assert client.base_url == "http://localhost:5000/from/env/"
12671277

1278+
# explicit environment arg requires explicitness
1279+
with update_env(LUMA_AI_BASE_URL="http://localhost:5000/from/env"):
1280+
with pytest.raises(ValueError, match=r"you must pass base_url=None"):
1281+
AsyncLumaAI(auth_token=auth_token, _strict_response_validation=True, environment="production")
1282+
1283+
client = AsyncLumaAI(
1284+
base_url=None, auth_token=auth_token, _strict_response_validation=True, environment="production"
1285+
)
1286+
assert str(client.base_url).startswith("http://api.lumalabs.ai/dream-machine/v1alpha")
1287+
12681288
@pytest.mark.parametrize(
12691289
"client",
12701290
[

0 commit comments

Comments
 (0)