Skip to content

Commit

Permalink
feat: specify local endpoints as source (#47)
Browse files Browse the repository at this point in the history
Fix #36.

Note: this feature depends on correct `X-Forwarded-Host` and
`X-Forwarded-Proto` headers.

---------

Co-authored-by: shouya <[email protected]>
  • Loading branch information
shouya and shouya authored Feb 22, 2024
1 parent 144874d commit 42b8585
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 13 deletions.
7 changes: 6 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ pub struct Cli {
#[derive(Parser)]
enum SubCommand {
Server(ServerConfig),
Test(TestConfig),
// boxed because of the clippy::large_enum_variant warning
Test(Box<TestConfig>),
}

#[derive(Parser)]
Expand All @@ -44,6 +45,9 @@ struct TestConfig {
/// Don't print XML output (Useful for checking console.log in JS filters)
#[clap(long, short)]
quiet: bool,
/// The base URL of the feed, used for resolving relative urls
#[clap(long)]
base: Option<Url>,
}

impl TestConfig {
Expand All @@ -53,6 +57,7 @@ impl TestConfig {
self.limit_filters,
self.limit_posts,
!self.compact_output,
self.base.clone(),
)
}
}
Expand Down
24 changes: 23 additions & 1 deletion src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,46 @@ mod simplify_html;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use url::Url;

use crate::{feed::Feed, util::Result};

#[derive(Clone)]
pub struct FilterContext {
pub(crate) limit_filters: Option<usize>,
limit_filters: Option<usize>,
/// The base URL of the application. Used to construct absolute URLs
/// from a relative path.
base: Option<Url>,
}

impl FilterContext {
pub fn new() -> Self {
Self {
limit_filters: None,
base: None,
}
}

pub fn limit_filters(&self) -> Option<usize> {
self.limit_filters
}

pub fn base(&self) -> Option<&Url> {
self.base.as_ref()
}

pub fn set_limit_filters(&mut self, limit: usize) {
self.limit_filters = Some(limit);
}

pub fn set_base(&mut self, base: Url) {
self.base = Some(base);
}

pub fn subcontext(&self) -> Self {
Self {
limit_filters: None,
base: self.base.clone(),
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/filter/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ pub struct Merge {
#[async_trait::async_trait]
impl FeedFilter for Merge {
async fn run(&self, ctx: &mut FilterContext, mut feed: Feed) -> Result<Feed> {
let new_feed = self.source.fetch_feed(Some(&self.client), None).await?;
let base = ctx.base();
let new_feed = self.source.fetch_feed(Some(&self.client), base).await?;
let ctx = ctx.subcontext();
let filtered_new_feed = self.filters.run(ctx, new_feed).await?;
feed.merge(filtered_new_feed)?;
Expand Down
5 changes: 3 additions & 2 deletions src/filter_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ impl FilterPipeline {
mut context: FilterContext,
mut feed: Feed,
) -> Result<Feed> {
let limit_filters =
context.limit_filters.unwrap_or_else(|| self.num_filters());
let limit_filters = context
.limit_filters()
.unwrap_or_else(|| self.num_filters());
for filter in self.filters.iter().take(limit_filters) {
feed = filter.run(&mut context, feed).await?;
}
Expand Down
35 changes: 33 additions & 2 deletions src/server/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::time::Duration;
use axum::body::Body;
use axum::response::IntoResponse;
use axum_macros::FromRequestParts;
use http::header::HOST;
use http::StatusCode;
use mime::Mime;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -77,6 +78,8 @@ pub struct EndpointParam {
/// Limit the number of items in the feed
limit_posts: Option<usize>,
pretty_print: bool,
/// The url base of the feed, used for resolving relative urls
base: Option<Url>,
}

impl EndpointParam {
Expand All @@ -85,12 +88,14 @@ impl EndpointParam {
limit_filters: Option<usize>,
limit_posts: Option<usize>,
pretty_print: bool,
base: Option<Url>,
) -> Self {
Self {
source,
limit_filters,
limit_posts,
pretty_print,
base,
}
}

Expand All @@ -100,6 +105,7 @@ impl EndpointParam {
limit_filters: Self::parse_limit_filters(req),
limit_posts: Self::parse_limit_posts(req),
pretty_print: Self::parse_pretty_print(req),
base: Self::get_base(req),
}
}

Expand All @@ -121,6 +127,24 @@ impl EndpointParam {
.unwrap_or(false)
}

fn get_base(req: &Request) -> Option<Url> {
let host = req
.headers()
.get("X-Forwarded-Host")
.or_else(|| req.headers().get(HOST))
.and_then(|x| x.to_str().ok())?;

let proto = req
.headers()
.get("X-Forwarded-Proto")
.and_then(|x| x.to_str().ok())
.unwrap_or("http");

let base = format!("{proto}://{host}/");
let base = base.parse().ok()?;
Some(base)
}

fn get_query(req: &Request, name: &str) -> Option<String> {
let url = Url::parse(&format!("http://placeholder{}", &req.uri())).ok()?;
url
Expand Down Expand Up @@ -246,9 +270,16 @@ impl EndpointService {
param: EndpointParam,
) -> Result<EndpointOutcome> {
let source = self.find_source(&param.source)?;
let feed = source.fetch_feed(Some(&self.client), None).await?;
let feed = source
.fetch_feed(Some(&self.client), param.base.as_ref())
.await?;
let mut context = FilterContext::new();
context.limit_filters = param.limit_filters;
if let Some(limit_filters) = param.limit_filters {
context.set_limit_filters(limit_filters);
}
if let Some(base) = param.base {
context.set_base(base);
}

let mut feed = self.filters.run(context, feed).await?;

Expand Down
10 changes: 4 additions & 6 deletions src/source.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use http::request::Parts;
use serde::{Deserialize, Serialize};
use url::Url;

Expand Down Expand Up @@ -57,7 +56,7 @@ impl Source {
pub async fn fetch_feed(
&self,
client: Option<&Client>,
request: Option<&Parts>,
base: Option<&Url>,
) -> Result<Feed> {
if let Source::FromScratch(config) = self {
let feed = Feed::from(config);
Expand All @@ -69,10 +68,9 @@ impl Source {
let source_url = match self {
Source::AbsoluteUrl(url) => url.clone(),
Source::RelativeUrl(path) => {
let request =
request.ok_or_else(|| Error::Message("request not set".into()))?;
let this_url: Url = request.uri.to_string().parse()?;
this_url.join(path)?
let base =
base.ok_or_else(|| Error::Message("base_url not set".into()))?;
base.join(path)?
}
Source::FromScratch(_) => unreachable!(),
};
Expand Down

0 comments on commit 42b8585

Please sign in to comment.