Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

worker/swirl/runner: Simplify AssertUnwindSafe usage #7453

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/cloudfront.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use aws_sdk_cloudfront::types::{InvalidationBatch, Paths};
use aws_sdk_cloudfront::{Client, Config};
use retry::delay::{jitter, Exponential};
use retry::OperationResult;
use std::panic::AssertUnwindSafe;
use std::time::Duration;
use tokio::runtime::Runtime;

pub struct CloudFront {
client: AssertUnwindSafe<Client>,
client: Client,
distribution_id: String,
}

Expand All @@ -27,7 +26,7 @@ impl CloudFront {
.credentials_provider(credentials)
.build();

let client = AssertUnwindSafe(Client::from_conf(config));
let client = Client::from_conf(config);

Some(Self {
client,
Expand Down
9 changes: 4 additions & 5 deletions src/worker/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ use crate::storage::Storage;
use crate::worker::swirl::PerformError;
use crates_io_index::Repository;
use reqwest::blocking::Client;
use std::panic::AssertUnwindSafe;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

pub struct Environment {
index: Mutex<Repository>,
http_client: AssertUnwindSafe<Client>,
http_client: Client,
cloudfront: Option<CloudFront>,
fastly: Option<Fastly>,
pub storage: AssertUnwindSafe<Arc<Storage>>,
pub storage: Arc<Storage>,
}

impl Environment {
Expand All @@ -25,10 +24,10 @@ impl Environment {
) -> Self {
Self {
index: Mutex::new(index),
http_client: AssertUnwindSafe(http_client),
http_client,
cloudfront,
fastly,
storage: AssertUnwindSafe(storage),
storage,
}
}

Expand Down
34 changes: 15 additions & 19 deletions src/worker/swirl/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::panic::{catch_unwind, AssertUnwindSafe, PanicInfo, UnwindSafe};
use std::panic::{catch_unwind, AssertUnwindSafe, PanicInfo};
use std::sync::mpsc::{sync_channel, SyncSender};
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -29,15 +29,15 @@ fn runnable<J: BackgroundJob>(
}

/// The core runner responsible for locking and running jobs
pub struct Runner<Context: Clone + Send + UnwindSafe + 'static> {
pub struct Runner<Context: Clone + Send + 'static> {
connection_pool: DieselPool,
thread_pool: ThreadPool,
job_registry: Arc<RwLock<HashMap<String, RunTaskFn<Context>>>>,
environment: Context,
job_start_timeout: Duration,
}

impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
impl<Context: Clone + Send + 'static> Runner<Context> {
pub fn new(connection_pool: DieselPool, environment: Context) -> Self {
Self {
connection_pool,
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
fn run_single_job(&self, sender: SyncSender<Event>) {
use diesel::result::Error::RollbackTransaction;

let job_registry = AssertUnwindSafe(self.job_registry.clone());
let job_registry = self.job_registry.clone();
let environment = self.environment.clone();

// The connection may not be `Send` so we need to clone the pool instead
Expand Down Expand Up @@ -155,11 +155,8 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
|| {
conn.transaction(|conn| {
let pool = pool.to_real_pool();
let state = AssertUnwindSafe(PerformState { conn, pool });
catch_unwind(|| {
// Ensure the whole `AssertUnwindSafe(_)` is moved
let state = state;

let state = PerformState { conn, pool };
catch_unwind(AssertUnwindSafe(|| {
let job_registry = job_registry.read();
let run_task_fn =
job_registry.get(&job.job_type).ok_or_else(|| {
Expand All @@ -169,8 +166,8 @@ impl<Context: Clone + Send + UnwindSafe + 'static> Runner<Context> {
))
})?;

run_task_fn(environment, state.0, job.data)
})
run_task_fn(environment, state, job.data)
}))
.map_err(|e| try_to_extract_panic_info(&e))
})
// TODO: Replace with flatten() once that stabilizes
Expand Down Expand Up @@ -294,7 +291,6 @@ mod tests {
use crates_io_test_db::TestDatabase;
use diesel::r2d2;
use diesel::r2d2::ConnectionManager;
use std::panic::AssertUnwindSafe;
use std::sync::{Arc, Barrier};

fn job_exists(id: i64, conn: &mut PgConnection) -> bool {
Expand Down Expand Up @@ -323,8 +319,8 @@ mod tests {
fn jobs_are_locked_when_fetched() {
#[derive(Clone)]
struct TestContext {
job_started_barrier: Arc<AssertUnwindSafe<Barrier>>,
assertions_finished_barrier: Arc<AssertUnwindSafe<Barrier>>,
job_started_barrier: Arc<Barrier>,
assertions_finished_barrier: Arc<Barrier>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -344,8 +340,8 @@ mod tests {
let test_database = TestDatabase::new();

let test_context = TestContext {
job_started_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
assertions_finished_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
job_started_barrier: Arc::new(Barrier::new(2)),
assertions_finished_barrier: Arc::new(Barrier::new(2)),
};

let runner =
Expand Down Expand Up @@ -409,7 +405,7 @@ mod tests {
fn failed_jobs_do_not_release_lock_before_updating_retry_time() {
#[derive(Clone)]
struct TestContext {
job_started_barrier: Arc<AssertUnwindSafe<Barrier>>,
job_started_barrier: Arc<Barrier>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -428,7 +424,7 @@ mod tests {
let test_database = TestDatabase::new();

let test_context = TestContext {
job_started_barrier: Arc::new(AssertUnwindSafe(Barrier::new(2))),
job_started_barrier: Arc::new(Barrier::new(2)),
};

let runner =
Expand Down Expand Up @@ -495,7 +491,7 @@ mod tests {
assert_eq!(tries, 1);
}

fn runner<Context: Clone + Send + UnwindSafe + 'static>(
fn runner<Context: Clone + Send + 'static>(
database_url: &str,
context: Context,
) -> Runner<Context> {
Expand Down