[PP-2233] Prevent deadlocks in axis.import_identifiers and axis.reap_… #2330

Open · wants to merge 1 commit into main
93 changes: 51 additions & 42 deletions src/palace/manager/celery/tasks/axis.py
@@ -1,8 +1,9 @@
import time
from datetime import datetime
from datetime import datetime, timedelta

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError, StaleDataError

@@ -25,14 +26,16 @@
from palace.manager.util.backoff import exponential_backoff
from palace.manager.util.datetime_helpers import datetime_utc, utc_now

DEFAULT_BATCH_SIZE: int = 25
DEFAULT_BATCH_SIZE: int = 10
DEFAULT_START_TIME = datetime_utc(1970, 1, 1)
TARGET_MAX_EXECUTION_SECONDS = 120


@shared_task(queue=QueueNames.default, bind=True)
def import_all_collections(
task: Task, import_all: bool = False, batch_size: int = DEFAULT_BATCH_SIZE
task: Task,
import_all: bool = False,
batch_size: int = DEFAULT_BATCH_SIZE,
sub_task_execution_interval_in_secs: int = 10,
) -> None:
"""
A shared task that loops through all Axis360 Api based collections and kick off an
@@ -43,11 +46,19 @@ def import_all_collections(
for collection in get_collections_by_protocol(
task=task, session=session, protocol_class=Axis360API
):
# since collections will likely have overlapping identifiers, minimize the possibility
# of deadlocks by staggering the execution of the list_identifiers_for_import task
# by sub_task_execution_interval_in_secs.
eta = utc_now() + timedelta(
seconds=(sub_task_execution_interval_in_secs * count)
)

task.log.info(
f'Queued collection("{collection.name}" [id={collection.id}] for importing...'
f'Queued collection("{collection.name}" [id={collection.id}] for importing to begin execution at {eta}...'
)
list_identifiers_for_import.apply_async(
kwargs={"collection_id": collection.id},
eta=eta,
link=import_identifiers.s(
collection_id=collection.id,
batch_size=batch_size,
@@ -177,7 +188,6 @@ def import_identifiers(
collection_id: int,
processed_count: int = 0,
batch_size: int = DEFAULT_BATCH_SIZE,
target_max_execution_time_in_seconds: float = TARGET_MAX_EXECUTION_SECONDS,
) -> None:
"""
This method creates new or updates new editions and license pools for each identifier in the list of identifiers.
@@ -197,7 +207,7 @@ def log_run_end_message() -> None:
def log_run_end_message() -> None:
task.log.info(
f"Finished importing identifiers for collection ({collection_name}, id={collection_id}), "
f"task(id={task.request.id})"
f"processed_count={processed_count}, task(id={task.request.id})"
)

if identifiers is None:
@@ -214,53 +224,34 @@ def log_run_end_message() -> None:
)
log_run_end_message()
return

api = create_api(session=session, collection=collection)
start_seconds = time.perf_counter()
total_imported_in_current_task = 0
while len(identifiers) > 0:
batch = identifiers[:batch_size]

batch = identifiers[:batch_size]
batch_length = len(batch)
if batch_length > 0:
try:
for metadata, circulation in api.availability_by_title_ids(
title_ids=batch
):
process_book(task, session, api, metadata, circulation)
except (ObjectDeletedError, StaleDataError) as e:
except (ObjectDeletedError, StaleDataError, OperationalError) as e:
wait_time = exponential_backoff(task.request.retries)
task.log.exception(
f"Something unexpected went wrong while processing a batch of titles for collection "
f'"{collection_name}" task(id={task.request.id} due to {e}. Retrying in {wait_time} seconds.'
)
raise task.retry(countdown=wait_time)

batch_length = len(batch)
task.log.info(
f"Imported {batch_length} identifiers for collection ({collection_name}, id={collection_id})"
)
total_imported_in_current_task += batch_length
task.log.info(
f"Total imported {total_imported_in_current_task} identifiers in current task for collection ({collection_name}, id={collection_id})"
)

# remove identifiers processed in previous batch
identifiers = identifiers[len(batch) :]
identifiers_list_length = len(identifiers)
# measure elapsed seconds
elapsed_seconds = time.perf_counter() - start_seconds

if elapsed_seconds > target_max_execution_time_in_seconds:
task.log.info(
f"Execution time exceeded max allowable seconds (max={target_max_execution_time_in_seconds}): "
f"elapsed seconds={elapsed_seconds}"
)
break
# measure elapsed seconds
elapsed_seconds = time.perf_counter() - start_seconds

processed_count += total_imported_in_current_task
# remove identifiers processed in previous batch
identifiers = identifiers[len(batch) :]
processed_count += batch_length

task.log.info(
f"Imported {processed_count} identifiers in run for collection ({collection_name}, id={collection_id})"
)
task.log.info(
f'Batch of {batch_length} identifiers for collection (name="{collection_name}", id={collection_id}) imported in {elapsed_seconds} seconds.'
)

if len(identifiers) > 0:
task.log.info(
@@ -325,23 +316,41 @@ def get_collections_by_protocol(


@shared_task(queue=QueueNames.default, bind=True)
def reap_all_collections(task: Task) -> None:
def reap_all_collections(
task: Task, sub_task_execution_interval_in_secs: int = 10
) -> None:
"""
A shared task that kicks off a reap collection task for each Axis 360 collection.
"""

count = 0
with task.session() as session:
for collection in get_collections_by_protocol(task, session, Axis360API):

# since collections will likely have overlapping identifiers, minimize the possibility
# of deadlocks by staggering the execution of the list_identifiers_for_import task
# by sub_task_execution_interval_in_secs.
eta = utc_now() + timedelta(
seconds=(sub_task_execution_interval_in_secs * count)
)

task.log.info(
f'Queued collection("{collection.name}" [id={collection.id}] for reaping...'
f'Queued collection("{collection.name}" [id={collection.id}] for reaping at {eta}.'
)
reap_collection.delay(collection_id=collection.id)
reap_collection.apply_async(
kwargs={"collection_id": collection.id}, eta=eta
)
count += 1

task.log.info(f"Finished queuing reap collection tasks.")


@shared_task(queue=QueueNames.default, bind=True)
def reap_collection(
task: Task, collection_id: int, offset: int = 0, batch_size: int = 25
task: Task,
collection_id: int,
offset: int = 0,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> None:
"""
Update the license pools associated with a subset of identifiers in a collection
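Aside: a minimal standalone sketch of the staggering idea used in `import_all_collections` and `reap_all_collections` above. Each per-collection subtask is scheduled with an `eta` offset by a fixed interval, so tasks that touch overlapping identifiers are less likely to start at the same instant and deadlock against each other. The app name, broker URL, task name, and `queue_staggered` helper below are hypothetical placeholders for illustration, not part of this PR.

```python
from datetime import datetime, timedelta, timezone

from celery import Celery

# Hypothetical app; the real tasks live in the Palace Celery application.
app = Celery("sketch", broker="memory://")


@app.task
def import_one_collection(collection_id: int) -> None:
    # Placeholder for the real per-collection import work.
    print(f"importing collection {collection_id}")


def queue_staggered(collection_ids: list[int], interval_secs: int = 10) -> None:
    """Queue one task per collection, each starting interval_secs later than the previous one."""
    for count, collection_id in enumerate(collection_ids):
        eta = datetime.now(timezone.utc) + timedelta(seconds=interval_secs * count)
        import_one_collection.apply_async(
            kwargs={"collection_id": collection_id}, eta=eta
        )
```

Passing `sub_task_execution_interval_in_secs=0` (as the updated tests do) collapses all the ETAs to "now", which keeps test runs fast while leaving the staggering on by default in production.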
26 changes: 18 additions & 8 deletions tests/manager/celery/tasks/test_axis.py
@@ -79,7 +79,10 @@ def test_import_all_collections(
with patch.object(
axis, "list_identifiers_for_import"
) as mock_list_identifiers_for_import:
import_all_collections.delay().wait()
# turn off subtask execution interval.
import_all_collections.apply_async(
kwargs={"sub_task_execution_interval_in_secs": 0}
).wait()

assert mock_list_identifiers_for_import.apply_async.call_count == 1
assert (
@@ -217,7 +220,10 @@ def test_import_items(
f"Finished importing identifiers for collection ({collection.name}, id={collection.id})"
in caplog.text
)
assert f"Imported {len(title_ids)} identifiers" in caplog.text
assert (
f'Batch of 2 identifiers for collection (name="{collection.name}", '
f"id={collection.id}) imported"
) in caplog.text


def test_import_identifiers_with_requeue(
@@ -248,7 +254,6 @@ def test_import_identifiers_with_requeue(
identifiers=identifiers,
collection_id=collection.id,
batch_size=1,
target_max_execution_time_in_seconds=0,
).wait()

assert mock_api.availability_by_title_ids.call_count == 2
@@ -277,12 +282,17 @@ def test_reap_all_collections(
db.default_collection()
collection2 = db.collection(name="test_collection", protocol=Axis360API.label())
with patch.object(axis, "reap_collection") as mock_reap_collection:
reap_all_collections.delay().wait()
reap_all_collections.apply_async(
kwargs={"sub_task_execution_interval_in_secs": 0}
).wait()

assert mock_reap_collection.delay.call_count == 1
assert mock_reap_collection.delay.call_args_list[0].kwargs == {
"collection_id": collection2.id,
}
assert mock_reap_collection.apply_async.call_count == 1
assert (
mock_reap_collection.apply_async.call_args_list[0].kwargs["kwargs"][
"collection_id"
]
== collection2.id
)
assert "Finished queuing reap collection tasks" in caplog.text


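Aside: the other half of the deadlock mitigation is catching `OperationalError` (how Postgres deadlock victims surface through SQLAlchemy) and re-queuing the task with a backoff, as the `import_identifiers` change above does with `exponential_backoff(task.request.retries)`. A rough standalone sketch of that pattern follows; `simple_backoff`, `process_batch`, and `do_batch_work` are hypothetical stand-ins, and the simple doubling backoff is only an assumption about what the real `exponential_backoff` helper does.

```python
from celery import shared_task
from sqlalchemy.exc import OperationalError


def simple_backoff(retries: int) -> int:
    # Stand-in for palace.manager.util.backoff.exponential_backoff: 3s, 6s, 12s, ...
    return 3 * (2**retries)


def do_batch_work(title_ids: list[str]) -> None:
    # Placeholder for availability_by_title_ids + process_book in the real task.
    ...


@shared_task(bind=True, max_retries=5)
def process_batch(self, title_ids: list[str]) -> None:
    try:
        do_batch_work(title_ids)
    except OperationalError as e:
        # A deadlock victim is safe to retry; back off so the competing
        # transaction has time to finish before this one runs again.
        wait_time = simple_backoff(self.request.retries)
        raise self.retry(countdown=wait_time, exc=e)
```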