3
3
from __future__ import annotations
4
4
5
5
import os
6
- from typing import Any , Union , Mapping
7
- from typing_extensions import Self , override
6
+ from typing import Any , Dict , Union , Mapping , cast
7
+ from typing_extensions import Self , Literal , override
8
8
9
9
import httpx
10
10
33
33
)
34
34
35
35
__all__ = [
36
+ "ENVIRONMENTS" ,
36
37
"Timeout" ,
37
38
"Transport" ,
38
39
"ProxiesTypes" ,
44
45
"AsyncClient" ,
45
46
]
46
47
48
+ ENVIRONMENTS : Dict [str , str ] = {
49
+ "production" : "http://api.lumalabs.ai/dream-machine/v1alpha" ,
50
+ "production_api" : "http://internal-api.virginia.labs.lumalabs.ai/dream-machine/v1alpha" ,
51
+ "staging" : "http://internal-api.sandbox.labs.lumalabs.ai/dream-machine/v1alpha" ,
52
+ "localhost" : "http://localhost:9600/dream-machine/v1alpha" ,
53
+ }
54
+
47
55
48
56
class LumaAI (SyncAPIClient ):
49
57
ping : resources .PingResource
@@ -54,11 +62,14 @@ class LumaAI(SyncAPIClient):
54
62
# client options
55
63
auth_token : str
56
64
65
+ _environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | NotGiven
66
+
57
67
def __init__ (
58
68
self ,
59
69
* ,
60
70
auth_token : str ,
61
- base_url : str | httpx .URL | None = None ,
71
+ environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | NotGiven = NOT_GIVEN ,
72
+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
62
73
timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
63
74
max_retries : int = DEFAULT_MAX_RETRIES ,
64
75
default_headers : Mapping [str , str ] | None = None ,
@@ -80,10 +91,31 @@ def __init__(
80
91
"""Construct a new synchronous luma_ai client instance."""
81
92
self .auth_token = auth_token
82
93
83
- if base_url is None :
84
- base_url = os .environ .get ("LUMA_AI_BASE_URL" )
85
- if base_url is None :
86
- base_url = f"http://api.lumalabs.ai/dream-machine/v1alpha"
94
+ self ._environment = environment
95
+
96
+ base_url_env = os .environ .get ("LUMA_AI_BASE_URL" )
97
+ if is_given (base_url ) and base_url is not None :
98
+ # cast required because mypy doesn't understand the type narrowing
99
+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
100
+ elif is_given (environment ):
101
+ if base_url_env and base_url is not None :
102
+ raise ValueError (
103
+ "Ambiguous URL; The `LUMA_AI_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
104
+ )
105
+
106
+ try :
107
+ base_url = ENVIRONMENTS [environment ]
108
+ except KeyError as exc :
109
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
110
+ elif base_url_env is not None :
111
+ base_url = base_url_env
112
+ else :
113
+ self ._environment = environment = "production"
114
+
115
+ try :
116
+ base_url = ENVIRONMENTS [environment ]
117
+ except KeyError as exc :
118
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
87
119
88
120
super ().__init__ (
89
121
version = __version__ ,
@@ -125,6 +157,7 @@ def copy(
125
157
self ,
126
158
* ,
127
159
auth_token : str | None = None ,
160
+ environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | None = None ,
128
161
base_url : str | httpx .URL | None = None ,
129
162
timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
130
163
http_client : httpx .Client | None = None ,
@@ -160,6 +193,7 @@ def copy(
160
193
return self .__class__ (
161
194
auth_token = auth_token or self .auth_token ,
162
195
base_url = base_url or self .base_url ,
196
+ environment = environment or self ._environment ,
163
197
timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
164
198
http_client = http_client ,
165
199
max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
@@ -215,11 +249,14 @@ class AsyncLumaAI(AsyncAPIClient):
215
249
# client options
216
250
auth_token : str
217
251
252
+ _environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | NotGiven
253
+
218
254
def __init__ (
219
255
self ,
220
256
* ,
221
257
auth_token : str ,
222
- base_url : str | httpx .URL | None = None ,
258
+ environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | NotGiven = NOT_GIVEN ,
259
+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
223
260
timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
224
261
max_retries : int = DEFAULT_MAX_RETRIES ,
225
262
default_headers : Mapping [str , str ] | None = None ,
@@ -241,10 +278,31 @@ def __init__(
241
278
"""Construct a new async luma_ai client instance."""
242
279
self .auth_token = auth_token
243
280
244
- if base_url is None :
245
- base_url = os .environ .get ("LUMA_AI_BASE_URL" )
246
- if base_url is None :
247
- base_url = f"http://api.lumalabs.ai/dream-machine/v1alpha"
281
+ self ._environment = environment
282
+
283
+ base_url_env = os .environ .get ("LUMA_AI_BASE_URL" )
284
+ if is_given (base_url ) and base_url is not None :
285
+ # cast required because mypy doesn't understand the type narrowing
286
+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
287
+ elif is_given (environment ):
288
+ if base_url_env and base_url is not None :
289
+ raise ValueError (
290
+ "Ambiguous URL; The `LUMA_AI_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
291
+ )
292
+
293
+ try :
294
+ base_url = ENVIRONMENTS [environment ]
295
+ except KeyError as exc :
296
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
297
+ elif base_url_env is not None :
298
+ base_url = base_url_env
299
+ else :
300
+ self ._environment = environment = "production"
301
+
302
+ try :
303
+ base_url = ENVIRONMENTS [environment ]
304
+ except KeyError as exc :
305
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
248
306
249
307
super ().__init__ (
250
308
version = __version__ ,
@@ -286,6 +344,7 @@ def copy(
286
344
self ,
287
345
* ,
288
346
auth_token : str | None = None ,
347
+ environment : Literal ["production" , "production_api" , "staging" , "localhost" ] | None = None ,
289
348
base_url : str | httpx .URL | None = None ,
290
349
timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
291
350
http_client : httpx .AsyncClient | None = None ,
@@ -321,6 +380,7 @@ def copy(
321
380
return self .__class__ (
322
381
auth_token = auth_token or self .auth_token ,
323
382
base_url = base_url or self .base_url ,
383
+ environment = environment or self ._environment ,
324
384
timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
325
385
http_client = http_client ,
326
386
max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
0 commit comments