Skip to content

Commit 68841e8

Browse files
authored
Remove retry_policy from profiles (#2406)
1 parent 6d7abc9 commit 68841e8

File tree

9 files changed: +30 -52 lines changed

src/dstack/_internal/cli/services/profile.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dstack._internal.core.models.profiles import (
77
CreationPolicy,
88
Profile,
9-
ProfileRetryPolicy,
9+
ProfileRetry,
1010
SpotPolicy,
1111
TerminationPolicy,
1212
parse_duration,
@@ -120,10 +120,8 @@ def register_profile_args(parser: argparse.ArgumentParser):
120120

121121
retry_group = parser.add_argument_group("Retry policy")
122122
retry_group_exc = retry_group.add_mutually_exclusive_group()
123-
retry_group_exc.add_argument("--retry", action="store_const", dest="retry_policy", const=True)
124-
retry_group_exc.add_argument(
125-
"--no-retry", action="store_const", dest="retry_policy", const=False
126-
)
123+
retry_group_exc.add_argument("--retry", action="store_const", dest="retry", const=True)
124+
retry_group_exc.add_argument("--no-retry", action="store_const", dest="retry", const=False)
127125
retry_group_exc.add_argument(
128126
"--retry-duration", type=retry_duration, dest="retry_duration", metavar="DURATION"
129127
)
@@ -161,15 +159,12 @@ def apply_profile_args(
161159
if args.spot_policy is not None:
162160
profile_settings.spot_policy = args.spot_policy
163161

164-
if args.retry_policy is not None:
165-
if not profile_settings.retry_policy:
166-
profile_settings.retry_policy = ProfileRetryPolicy()
167-
profile_settings.retry_policy.retry = args.retry_policy
162+
if args.retry is not None:
163+
profile_settings.retry = args.retry
168164
elif args.retry_duration is not None:
169-
if not profile_settings.retry_policy:
170-
profile_settings.retry_policy = ProfileRetryPolicy()
171-
profile_settings.retry_policy.retry = True
172-
profile_settings.retry_policy.duration = args.retry_duration
165+
profile_settings.retry = ProfileRetry(
166+
duration=args.retry_duration,
167+
)
173168

174169

175170
def max_duration(v: str) -> int:

src/dstack/_internal/core/models/profiles.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[st
6969
return parse_duration(v)
7070

7171

72+
# Deprecated in favor of ProfileRetry().
73+
# TODO: Remove when no longer referenced.
7274
class ProfileRetryPolicy(CoreModel):
7375
retry: Annotated[bool, Field(description="Whether to retry the run on failure or not")] = False
7476
duration: Annotated[
@@ -95,14 +97,15 @@ class RetryEvent(str, Enum):
9597

9698
class ProfileRetry(CoreModel):
9799
on_events: Annotated[
98-
List[RetryEvent],
100+
Optional[List[RetryEvent]],
99101
Field(
100102
description=(
101103
"The list of events that should be handled with retry."
102-
" Supported events are `no-capacity`, `interruption`, and `error`"
104+
" Supported events are `no-capacity`, `interruption`, and `error`."
105+
" Omit to retry on all events"
103106
)
104107
),
105-
]
108+
] = None
106109
duration: Annotated[
107110
Optional[Union[int, str]],
108111
Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"),
@@ -112,7 +115,8 @@ class ProfileRetry(CoreModel):
112115

113116
@root_validator
114117
def _validate_fields(cls, values):
115-
if "on_events" in values and len(values["on_events"]) == 0:
118+
on_events = values.get("on_events", None)
119+
if on_events is not None and len(values["on_events"]) == 0:
116120
raise ValueError("`on_events` cannot be empty")
117121
return values
118122

@@ -249,8 +253,6 @@ class ProfileParams(CoreModel):
249253
description="Deprecated in favor of `idle_duration`",
250254
),
251255
] = None
252-
# The policy for resubmitting the run. Deprecated in favor of `retry`
253-
retry_policy: Optional[ProfileRetryPolicy] = None
254256

255257
_validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
256258
parse_max_duration

src/dstack/_internal/core/models/runs.py

-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
CreationPolicy,
2121
Profile,
2222
ProfileParams,
23-
ProfileRetryPolicy,
2423
RetryEvent,
2524
SpotPolicy,
2625
UtilizationPolicy,
@@ -204,9 +203,6 @@ class JobSpec(CoreModel):
204203
retry: Optional[Retry]
205204
volumes: Optional[List[MountPoint]] = None
206205
ssh_key: Optional[JobSSHKey] = None
207-
# For backward compatibility with 0.18.x when retry_policy was required.
208-
# TODO: remove in 0.19
209-
retry_policy: ProfileRetryPolicy = ProfileRetryPolicy(retry=False)
210206
working_dir: Optional[str]
211207

212208

src/dstack/_internal/core/services/profiles.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,7 @@
1212
def get_retry(profile: Profile) -> Optional[Retry]:
1313
profile_retry = profile.retry
1414
if profile_retry is None:
15-
# Handle retry_policy before retry was introduced
16-
# TODO: Remove once retry_policy no longer supported
17-
profile_retry_policy = profile.retry_policy
18-
if profile_retry_policy is None:
19-
return None
20-
if not profile_retry_policy.retry:
21-
return None
22-
duration = profile_retry_policy.duration or DEFAULT_RETRY_DURATION
23-
return Retry(
24-
on_events=[RetryEvent.NO_CAPACITY, RetryEvent.INTERRUPTION, RetryEvent.ERROR],
25-
duration=duration,
26-
)
15+
return None
2716
if isinstance(profile_retry, bool):
2817
if profile_retry:
2918
return Retry(
@@ -32,6 +21,12 @@ def get_retry(profile: Profile) -> Optional[Retry]:
3221
)
3322
return None
3423
profile_retry = profile_retry.copy()
24+
if profile_retry.on_events is None:
25+
profile_retry.on_events = [
26+
RetryEvent.NO_CAPACITY,
27+
RetryEvent.INTERRUPTION,
28+
RetryEvent.ERROR,
29+
]
3530
if profile_retry.duration is None:
3631
profile_retry.duration = DEFAULT_RETRY_DURATION
3732
return Retry.parse_obj(profile_retry)

src/dstack/api/_public/runs.py

-1
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,6 @@ def get_plan(
640640
reservation=reservation,
641641
spot_policy=spot_policy,
642642
retry=None,
643-
retry_policy=retry_policy,
644643
utilization_policy=utilization_policy,
645644
max_duration=max_duration,
646645
stop_duration=stop_duration,

src/tests/_internal/cli/services/configurators/test_profile.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
apply_profile_args,
66
register_profile_args,
77
)
8-
from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy, SpotPolicy
8+
from dstack._internal.core.models.profiles import Profile, ProfileRetry, SpotPolicy
99

1010

1111
class TestProfileArgs:
@@ -51,21 +51,21 @@ def test_spot_policy_on_demand(self):
5151
assert profile.dict() == modified.dict()
5252

5353
def test_retry(self):
54-
profile = Profile(name="test")
55-
profile.retry_policy = ProfileRetryPolicy(retry=True)
54+
profile = Profile(name="test", retry=None)
5655
modified, _ = apply_args(profile, ["--retry"])
56+
profile.retry = True
5757
assert profile.dict() == modified.dict()
5858

5959
def test_no_retry(self):
60-
profile = Profile(name="test", retry_policy=ProfileRetryPolicy(retry=True, duration=3600))
60+
profile = Profile(name="test", retry=None)
6161
modified, _ = apply_args(profile, ["--no-retry"])
62-
profile.retry_policy.retry = False
62+
profile.retry = False
6363
assert profile.dict() == modified.dict()
6464

6565
def test_retry_duration(self):
6666
profile = Profile(name="test")
6767
modified, _ = apply_args(profile, ["--retry-duration", "1h"])
68-
profile.retry_policy = ProfileRetryPolicy(retry=True, duration=3600)
68+
profile.retry = ProfileRetry(on_events=None, duration="1h")
6969
assert profile.dict() == modified.dict()
7070

7171

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
InstanceType,
1515
Resources,
1616
)
17-
from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy
17+
from dstack._internal.core.models.profiles import Profile
1818
from dstack._internal.core.models.runs import (
1919
JobProvisioningData,
2020
JobStatus,
@@ -372,7 +372,6 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
372372
repo_id=repo.name,
373373
profile=Profile(
374374
name="default",
375-
retry_policy=ProfileRetryPolicy(retry=True, duration=3600),
376375
),
377376
),
378377
)

src/tests/_internal/server/routers/test_fleets.py

-2
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,6 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
359359
"instance_types": None,
360360
"spot_policy": None,
361361
"retry": None,
362-
"retry_policy": None,
363362
"max_duration": None,
364363
"stop_duration": None,
365364
"max_price": None,
@@ -482,7 +481,6 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
482481
"instance_types": None,
483482
"spot_policy": None,
484483
"retry": None,
485-
"retry_policy": None,
486484
"max_duration": None,
487485
"stop_duration": None,
488486
"max_price": None,

src/tests/_internal/server/routers/test_runs.py

-6
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def get_dev_env_run_plan_dict(
120120
"stop_duration": None,
121121
"max_price": None,
122122
"retry": None,
123-
"retry_policy": None,
124123
"spot_policy": "spot",
125124
"idle_duration": None,
126125
"termination_idle_time": 300,
@@ -141,7 +140,6 @@ def get_dev_env_run_plan_dict(
141140
"max_price": None,
142141
"name": "string",
143142
"retry": None,
144-
"retry_policy": None,
145143
"spot_policy": "spot",
146144
"idle_duration": None,
147145
"termination_idle_time": 300,
@@ -206,7 +204,6 @@ def get_dev_env_run_plan_dict(
206204
"retry": None,
207205
"volumes": volumes,
208206
"ssh_key": None,
209-
"retry_policy": {"retry": False, "duration": None},
210207
"working_dir": ".",
211208
},
212209
"offers": [json.loads(o.json()) for o in offers],
@@ -277,7 +274,6 @@ def get_dev_env_run_dict(
277274
"stop_duration": None,
278275
"max_price": None,
279276
"retry": None,
280-
"retry_policy": None,
281277
"spot_policy": "spot",
282278
"idle_duration": None,
283279
"termination_idle_time": 300,
@@ -298,7 +294,6 @@ def get_dev_env_run_dict(
298294
"max_price": None,
299295
"name": "string",
300296
"retry": None,
301-
"retry_policy": None,
302297
"spot_policy": "spot",
303298
"idle_duration": None,
304299
"termination_idle_time": 300,
@@ -363,7 +358,6 @@ def get_dev_env_run_dict(
363358
"retry": None,
364359
"volumes": [],
365360
"ssh_key": None,
366-
"retry_policy": {"retry": False, "duration": None},
367361
"working_dir": ".",
368362
},
369363
"job_submissions": [

0 commit comments

Comments
 (0)