From 9209561cf72c48e7c47e81a3995b389a1be88121 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Fri, 26 Jan 2024 14:10:25 +0100 Subject: [PATCH] feat: use file extensions from collection names for mime guessing --- iroh-gateway/Cargo.lock | 20 ++++++++++ iroh-gateway/Cargo.toml | 1 + iroh-gateway/src/main.rs | 86 ++++++++++++++++++++++++---------------- 3 files changed, 73 insertions(+), 34 deletions(-) diff --git a/iroh-gateway/Cargo.lock b/iroh-gateway/Cargo.lock index 8a0afce..8893263 100644 --- a/iroh-gateway/Cargo.lock +++ b/iroh-gateway/Cargo.lock @@ -2069,6 +2069,7 @@ dependencies = [ "lru", "mime", "mime_classifier", + "mime_guess", "quinn", "range-collections", "rustls", @@ -2425,6 +2426,16 @@ dependencies = [ "serde", ] +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4761,6 +4772,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.14" diff --git a/iroh-gateway/Cargo.toml b/iroh-gateway/Cargo.toml index f93e194..8c56c41 100644 --- a/iroh-gateway/Cargo.toml +++ b/iroh-gateway/Cargo.toml @@ -30,3 +30,4 @@ tokio-rustls-acme = { version = "0.2.0", features = ["axum"] } hyper-util = "0.1.2" rustls-pemfile = "1.0.2" tower-service = "0.3.2" +mime_guess = "2.0.4" diff --git a/iroh-gateway/src/main.rs b/iroh-gateway/src/main.rs index c5fcede..c4b60ee 100644 --- a/iroh-gateway/src/main.rs +++ b/iroh-gateway/src/main.rs @@ -95,7 +95,7 @@ struct Inner { #[debug("MimeClassifier")] mime_classifier: MimeClassifier, /// Cache of hashes to mime types - mime_cache: Mutex>, + mime_cache: Mutex), (u64, Mime)>>, /// Cache of hashes to collections collection_cache: Mutex>, } @@ -175,25 +175,19 @@ async fn get_collection( return Ok(res.clone()); } let (collection, headers) = get_collection_inner(hash, connection, true).await?; - let mimes = headers - .into_iter() - .map(|(hash, size, header)| { - let mime = gateway.mime_classifier.classify( - mime_classifier::LoadContext::Browsing, - mime_classifier::NoSniffFlag::Off, - mime_classifier::ApacheBugFlag::On, - &None, - &header, - ); - (hash, size, mime) - }) - .collect::>(); - { - let mut cache = gateway.mime_cache.lock().unwrap(); - for (hash, size, mime) in mimes { - cache.put(hash, (size, mime)); - } + + let mut cache = gateway.mime_cache.lock().unwrap(); + for (name, hash) in collection.iter() { + let ext = get_extension(name); + let Some((hash, size, data)) = headers.iter().find(|(h, _, _)| h == hash) else { + tracing::debug!("hash {hash:?} for name {name:?} not found in headers"); + continue; + }; + let mime = get_mime_from_ext_and_data(ext.as_deref(), &data, &gateway.mime_classifier); + let key = (*hash, ext); + cache.put(key, (*size, mime)); } + drop(cache); gateway .collection_cache @@ -203,9 +197,16 @@ async fn get_collection( Ok(collection) } +fn get_extension(name: &str) -> Option { + std::path::Path::new(name) + .extension() + .map(|s| s.to_string_lossy().to_string()) +} + /// Get the mime type for a hash from the remote node. async fn get_mime_type_inner( hash: &Hash, + ext: Option<&str>, connection: &quinn::Connection, mime_classifier: &MimeClassifier, ) -> anyhow::Result<(u64, Mime)> { @@ -223,31 +224,46 @@ async fn get_mime_type_inner( anyhow::bail!("unexpected response"); }; let _stats = at_closing.next().await?; + let mime = get_mime_from_ext_and_data(ext, &data, mime_classifier); + Ok((size, mime)) +} + +fn get_mime_from_ext_and_data( + ext: Option<&str>, + data: &[u8], + mime_classifier: &MimeClassifier, +) -> Mime { let context = mime_classifier::LoadContext::Browsing; - let no_sniff_flag = mime_classifier::NoSniffFlag::Off; + let no_sniff_flag = mime_classifier::NoSniffFlag::On; let apache_bug_flag = mime_classifier::ApacheBugFlag::On; - let supplied_type = None; - let mime = mime_classifier.classify( + let supplied_type = match ext { + None => None, + Some(ext) => mime_guess::from_ext(ext).first(), + }; + mime_classifier.classify( context, no_sniff_flag, apache_bug_flag, &supplied_type, - &data, - ); - Ok((size, mime)) + data, + ) } /// Get the mime type for a hash, either from the cache or by requesting it from the node. async fn get_mime_type( gateway: &Gateway, hash: &Hash, + name: Option<&str>, connection: &quinn::Connection, ) -> anyhow::Result<(u64, Mime)> { - if let Some(sm) = gateway.mime_cache.lock().unwrap().get(hash) { + let ext = name.map(|n| get_extension(n)).flatten(); + let key = (*hash, ext.clone()); + if let Some(sm) = gateway.mime_cache.lock().unwrap().get(&key) { return Ok(sm.clone()); } - let sm = get_mime_type_inner(hash, connection, &gateway.mime_classifier).await?; - gateway.mime_cache.lock().unwrap().put(*hash, sm.clone()); + let sm = + get_mime_type_inner(hash, ext.as_deref(), connection, &gateway.mime_classifier).await?; + gateway.mime_cache.lock().unwrap().put(key, sm.clone()); Ok(sm) } @@ -259,7 +275,7 @@ async fn handle_local_blob_request( ) -> std::result::Result, AppError> { let connection = gateway.get_default_connection().await?; let byte_range = parse_byte_range(req).await?; - let res = forward_range(&gateway, connection, &blake3_hash, byte_range).await?; + let res = forward_range(&gateway, connection, &blake3_hash, None, byte_range).await?; Ok(res) } @@ -299,7 +315,7 @@ async fn handle_ticket_index( let hash = ticket.hash(); let prefix = format!("/ticket/{}", ticket); let res = match ticket.format() { - BlobFormat::Raw => forward_range(&gateway, connection, &hash, byte_range) + BlobFormat::Raw => forward_range(&gateway, connection, &hash, None, byte_range) .await? .into_response(), BlobFormat::HashSeq => collection_index(&gateway, connection, &hash, &prefix) @@ -345,7 +361,8 @@ async fn collection_index( for (name, child_hash) in collection.iter() { let url = format!("{}/{}", link_prefix, name); let url = encode_relative_url(&url)?; - let smo = gateway.mime_cache.lock().unwrap().get(child_hash).cloned(); + let key = (*child_hash, get_extension(name)); + let smo = gateway.mime_cache.lock().unwrap().get(&key).cloned(); res.push_str(&format!("{}", url, name,)); if let Some((size, mime)) = smo { res.push_str(&format!(" ({}, {})", mime, indicatif::HumanBytes(size))); @@ -373,7 +390,7 @@ async fn forward_collection_range( let collection = get_collection(gateway, hash, &connection).await?; for (name, hash) in collection.iter() { if name == suffix { - let res = forward_range(gateway, connection, hash, range).await?; + let res = forward_range(gateway, connection, hash, Some(suffix), range).await?; return Ok(res.into_response()); } else { tracing::trace!("'{}' != '{}'", name, suffix); @@ -400,16 +417,17 @@ async fn forward_range( gateway: &Gateway, connection: quinn::Connection, hash: &Hash, + name: Option<&str>, (start, end): (Option, Option), ) -> anyhow::Result> { // we need both byte ranges and chunk ranges. // chunk ranges to request data, and byte ranges to return the data. - tracing::debug!("forward_range {:?} {:?}", start, end); + tracing::debug!("forward_range {:?} {:?} (name {name:?})", start, end); let byte_ranges = to_byte_range(start, end); let chunk_ranges = to_chunk_range(start, end); tracing::debug!("got connection"); - let (_size, mime) = get_mime_type(gateway, hash, &connection).await?; + let (_size, mime) = get_mime_type(gateway, hash, name, &connection).await?; tracing::debug!("mime: {}", mime); let chunk_ranges = RangeSpecSeq::from_ranges(vec![chunk_ranges]); let request = iroh::bytes::protocol::GetRequest::new(*hash, chunk_ranges.clone());