Skip to content

Commit

Permalink
refactor(oma-refresh)!: move download release to MirrorSources impl
Browse files Browse the repository at this point in the history
  • Loading branch information
eatradish committed Jan 17, 2025
1 parent dd43d69 commit d28c33e
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 321 deletions.
19 changes: 18 additions & 1 deletion oma-fetch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use checksum::Checksum;
use download::{EmptySource, SingleDownloader, SuccessSummary};
use futures::{Future, StreamExt};

use reqwest::Client;
use reqwest::{Client, Method, RequestBuilder};
use tracing::debug;

pub mod checksum;
mod download;
Expand Down Expand Up @@ -272,3 +273,19 @@ impl DownloadManager<'_> {
Ok(Summary { success, failed })
}
}

pub fn build_request_with_basic_auth(
client: &Client,
method: Method,
auth: &Option<(String, String)>,
url: &str,
) -> RequestBuilder {
let mut req = client.request(method, url);

if let Some((user, password)) = auth {
debug!("auth user: {}", user);
req = req.basic_auth(user, Some(password));
}

req
}
313 changes: 10 additions & 303 deletions oma-refresh/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use aho_corasick::BuildError;
use apt_auth_config::AuthConfig;
use bon::{builder, Builder};
use chrono::Utc;
use futures::StreamExt;
use nix::{
errno::Errno,
fcntl::{
Expand Down Expand Up @@ -45,8 +44,7 @@ use reqwest::StatusCode;

use sysinfo::{Pid, System};
use tokio::{
fs::{self, File},
io::AsyncWriteExt,
fs::{self},
process::Command,
task::spawn_blocking,
};
Expand Down Expand Up @@ -352,13 +350,14 @@ impl<'a> OmaRefresh<'a> {
let mut mirror_sources =
MirrorSources::from_sourcelist(sourcelist, replacer, self.auth_config)?;

let tasks = mirror_sources.0.iter().enumerate().map(|(index, m)| {
self.get_release_file(m, replacer, index, mirror_sources.0.len(), callback)
});

let results = futures::stream::iter(tasks)
.buffer_unordered(self.threads)
.collect::<Vec<_>>()
let results = mirror_sources
.fetch_all_release(
&self.client,

Check warning on line 355 in oma-refresh/src/db.rs

View workflow job for this annotation

GitHub Actions / clippy

this expression creates a reference which is immediately dereferenced by the compiler

warning: this expression creates a reference which is immediately dereferenced by the compiler --> oma-refresh/src/db.rs:355:17 | 355 | &self.client, | ^^^^^^^^^^^^ help: change this to: `self.client` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_borrow = note: `#[warn(clippy::needless_borrow)]` on by default
replacer,
&self.download_dir,
self.threads,
callback,
)
.await;

debug!("download_releases: results: {:?}", results);
Expand Down Expand Up @@ -464,298 +463,6 @@ impl<'a> OmaRefresh<'a> {
Ok(())
}

async fn get_release_file<'b, F, Fut>(
&self,
entry: &MirrorSource<'b, 'a>,
replacer: &DatabaseFilenameReplacer,
progress_index: usize,
total: usize,
callback: &F,
) -> Result<()>
where
F: Fn(Event) -> Fut,
Fut: Future<Output = ()>,
{
match entry.from()? {
OmaSourceEntryFrom::Http => {
self.download_http_release(entry, replacer, progress_index, total, callback)
.await
}
OmaSourceEntryFrom::Local => {
self.download_local_release(entry, replacer, progress_index, total, callback)
.await
}
}
}

async fn download_local_release<'b, F, Fut>(
&self,
entry: &MirrorSource<'b, 'a>,
replacer: &DatabaseFilenameReplacer,
index: usize,
total: usize,
callback: &F,
) -> Result<()>
where
F: Fn(Event) -> Fut,
Fut: Future<Output = ()>,
{
let dist_path_with_protocol = entry.dist_path();
let dist_path = dist_path_with_protocol
.strip_prefix("file:")
.unwrap_or(dist_path_with_protocol);
let dist_path = Path::new(dist_path);

let mut name = None;

let msg = entry.get_human_download_url(None)?;

callback(Event::DownloadEvent(oma_fetch::Event::NewProgressSpinner {
index,
msg: format!("({}/{}) {}", index, total, msg),
}))
.await;

let mut is_release = false;

for (index, entry) in ["InRelease", "Release"].iter().enumerate() {
let p = dist_path.join(entry);

let dst = if dist_path_with_protocol.ends_with('/') {
format!("{}{}", dist_path_with_protocol, entry)
} else {
format!("{}/{}", dist_path_with_protocol, entry)
};

let file_name = replacer.replace(&dst)?;

let dst = self.download_dir.join(&file_name);

if p.exists() {
if dst.exists() {
debug!("get_release_file: Removing {}", dst.display());
fs::remove_file(&dst)
.await
.map_err(|e| RefreshError::OperateFile(dst.clone(), e))?;
}

debug!("get_release_file: Symlink {}", dst.display());
fs::symlink(p, &dst)
.await
.map_err(|e| RefreshError::OperateFile(dst.clone(), e))?;

if index == 1 {
is_release = true;
}

name = Some(file_name);
break;
}
}

if name.is_none() && entry.is_flat() {
// Flat repo no release
return Ok(());
}

if is_release {
let p = dist_path.join("Release.gpg");
let entry = "Release.gpg";

let dst = if dist_path_with_protocol.ends_with('/') {
format!("{}{}", dist_path_with_protocol, entry)
} else {
format!("{}/{}", dist_path_with_protocol, entry)
};

let file_name = replacer.replace(&dst)?;

let dst = self.download_dir.join(&file_name);

if p.exists() {
if dst.exists() {
fs::remove_file(&dst)
.await
.map_err(|e| RefreshError::OperateFile(dst.clone(), e))?;
}

fs::symlink(p, self.download_dir.join(file_name))
.await
.map_err(|e| RefreshError::OperateFile(dst.clone(), e))?;
}
}

callback(Event::DownloadEvent(oma_fetch::Event::ProgressDone(index))).await;

let name = name.ok_or_else(|| RefreshError::NoInReleaseFile(entry.url().to_string()))?;
entry.set_release_file_name(name);

Ok(())
}

async fn download_http_release<'b, F, Fut>(
&self,
entry: &MirrorSource<'b, 'a>,
replacer: &DatabaseFilenameReplacer,
index: usize,
total: usize,
callback: &F,
) -> std::result::Result<(), RefreshError>
where
F: Fn(Event) -> Fut,
Fut: Future<Output = ()>,
{
let dist_path = entry.dist_path();

let mut r = None;
let mut u = None;
let mut is_release = false;

let msg = entry.get_human_download_url(None)?;

callback(Event::DownloadEvent(oma_fetch::Event::NewProgressSpinner {
index,
msg: format!("({}/{}) {}", index, total, msg),
}))
.await;

for (index, file_name) in ["InRelease", "Release"].iter().enumerate() {
let url = format!("{}/{}", dist_path, file_name);
let request = self.request_get_builder(&url, entry);

let resp = request
.send()
.await
.and_then(|resp| resp.error_for_status());

r = Some(resp);

if r.as_ref().unwrap().is_ok() {
u = Some(url);
if index == 1 {
is_release = true;
}
break;
}
}

let r = r.unwrap();

callback(Event::DownloadEvent(oma_fetch::Event::ProgressDone(index))).await;

if r.is_err() && entry.is_flat() {
// Flat repo no release
return Ok(());
}

let resp = r
.map_err(|e| SingleDownloadError::ReqwestError { source: e })
.map_err(|e| RefreshError::DownloadFailed(Some(e)))?;

let url = u.unwrap();
let file_name = replacer.replace(&url)?;

self.download_file(&file_name, resp, entry, index, total, &callback)
.await
.map_err(|e| RefreshError::DownloadFailed(Some(e)))?;

entry.set_release_file_name(file_name);

if is_release && !entry.trusted() {
let url = format!("{}/{}", dist_path, "Release.gpg");

let request = self.request_get_builder(&url, entry);
let resp = request
.send()
.await
.and_then(|resp| resp.error_for_status())
.map_err(|e| SingleDownloadError::ReqwestError { source: e })
.map_err(|e| RefreshError::DownloadFailed(Some(e)))?;

let file_name = replacer.replace(&url)?;

self.download_file(&file_name, resp, entry, index, total, &callback)
.await
.map_err(|e| RefreshError::DownloadFailed(Some(e)))?;
}

Ok(())
}

fn request_get_builder<'b>(
&self,
url: &str,
source_index: &MirrorSource<'b, 'a>,
) -> reqwest::RequestBuilder {
let mut request = self.client.get(url);
if let Some(auth) = source_index.auth() {
request = request.basic_auth(&auth.login, Some(&auth.password))
}

request
}

async fn download_file<'b, F, Fut>(
&self,
file_name: &str,
mut resp: Response,
source_index: &MirrorSource<'b, 'a>,
index: usize,
total: usize,
callback: &F,
) -> std::result::Result<(), SingleDownloadError>
where
F: Fn(Event) -> Fut,
Fut: Future<Output = ()>,
{
let total_size = content_length(&resp);

callback(Event::DownloadEvent(oma_fetch::Event::NewProgressBar {
index,
msg: format!(
"({}/{}) {}",
index,
total,
source_index
.get_human_download_url(Some(file_name))
.unwrap(),
),
size: total_size,
}));

let mut f = File::create(self.download_dir.join(file_name))
.await
.map_err(|e| SingleDownloadError::Create { source: e })?;

f.set_permissions(Permissions::from_mode(0o644))
.await
.map_err(|e| SingleDownloadError::SetPermission { source: e })?;

while let Some(chunk) = resp
.chunk()
.await
.map_err(|e| SingleDownloadError::ReqwestError { source: e })?
{
callback(Event::DownloadEvent(oma_fetch::Event::ProgressInc {
index,
size: chunk.len() as u64,
}))
.await;

f.write_all(&chunk)
.await
.map_err(|e| SingleDownloadError::Write { source: e })?;
}

f.shutdown()
.await
.map_err(|e| SingleDownloadError::Flush { source: e })?;

callback(Event::DownloadEvent(oma_fetch::Event::ProgressDone(index))).await;

Ok(())
}

async fn collect_all_release_entry<'b>(
&self,
replacer: &DatabaseFilenameReplacer,
Expand Down Expand Up @@ -856,7 +563,7 @@ impl<'a> OmaRefresh<'a> {
}
}

fn content_length(resp: &Response) -> u64 {
pub fn content_length(resp: &Response) -> u64 {
let content_length = resp
.headers()
.get(CONTENT_LENGTH)
Expand Down
Loading

0 comments on commit d28c33e

Please sign in to comment.