subscribe-list/src/config.rs
2024-04-08 23:39:40 -07:00

136 lines
4.3 KiB
Rust

use crate::client::async_http_client;
use clap::builder::{TryMapValueParser, TypedValueParser, ValueParserFactory};
use clio::CachedInput;
use eyre::{ensure, Context, OptionExt};
use futures::future::try_join_all;
use indexmap::IndexMap;
use openidconnect::{
core::{CoreClient, CoreProviderMetadata},
ClientId, ClientSecret, IssuerUrl, RedirectUrl, Scope,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::num::NonZeroU16;
/// Serde default for the `db_thread_channel_capacity` config field.
///
/// Returns a capacity of 10.
pub fn default_db_thread_channel_capacity() -> NonZeroU16 {
    const DEFAULT_CAPACITY: u16 = 10;
    match NonZeroU16::new(DEFAULT_CAPACITY) {
        Some(capacity) => capacity,
        // DEFAULT_CAPACITY is a non-zero literal, so this arm cannot be taken.
        None => unreachable!("default channel capacity must be non-zero"),
    }
}
/// Top-level application configuration.
///
/// Deserialized from TOML by `Config::load_str` / `Config::load`, and
/// validated afterwards by `Config::resolve`.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Config {
/// SQLite database location.
/// NOTE(review): presumably a filesystem path -- confirm where it is opened.
pub sqlite_db: String,
/// Falls back to `default_db_thread_channel_capacity()` when the key is
/// absent from the config.
/// NOTE(review): presumably the capacity of the channel feeding the
/// database thread -- confirm at the use site.
#[serde(default = "default_db_thread_channel_capacity")]
pub db_thread_channel_capacity: NonZeroU16,
/// Base URL of this server. `Config::resolve` rejects it unless it equals
/// the URL parsed back from its own origin serialization.
pub server_base_url: Url,
/// OIDC providers keyed by name. `Config::resolve` passes each key to
/// `OIDCProvider::resolve`, which uses it to build the expected redirect URL.
pub oidc: IndexMap<String, OIDCProvider>,
}
/// Lets clap accept a `Config` directly as an argument value: the argument is
/// first parsed as a `CachedInput`, whose contents are then loaded as a config.
impl ValueParserFactory for Config {
    type Parser = TryMapValueParser<
        <CachedInput as ValueParserFactory>::Parser,
        fn(CachedInput) -> Result<Config, String>,
    >;

    fn value_parser() -> Self::Parser {
        /// Loads the config from `source`, rendering any failure with the
        /// alternate (`{:#}`) format of the error report.
        fn load_from_input(source: CachedInput) -> Result<Config, String> {
            match Config::load(&source) {
                Ok(config) => Ok(config),
                Err(e) => Err(format!("{e:#}")),
            }
        }

        // Bind to the exact function-pointer type named in `Self::Parser`.
        let convert: fn(CachedInput) -> Result<Config, String> = load_from_input;
        CachedInput::value_parser().try_map(convert)
    }
}
impl Config {
    /// Contents of `config.toml.sample` from the crate root, embedded at
    /// compile time.
    // NOTE(review): the `allow` presumably covers builds in which nothing
    // outside the tests references this constant -- confirm.
    #[allow(dead_code)]
    pub const EXAMPLE: &'static str =
        include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/config.toml.sample"));

    /// Parses `input` as TOML into a [`Config`].
    ///
    /// # Errors
    ///
    /// Returns the deserialization error wrapped with `path.to_string()` as
    /// context.
    pub fn load_str(input: &str, path: impl ToString) -> eyre::Result<Config> {
        let parsed: Result<Config, _> = toml::from_str(input);
        parsed.wrap_err_with(|| path.to_string())
    }

    /// Loads a [`Config`] from an already-read input.
    ///
    /// # Errors
    ///
    /// Fails if the input bytes are not valid UTF-8 or if [`Config::load_str`]
    /// fails; both errors carry the input's path as context.
    pub fn load(input: &CachedInput) -> eyre::Result<Config> {
        let path = input.path();
        log::debug!("loading config from: {}", path);
        let bytes = input.get_data();
        let text = std::str::from_utf8(bytes).wrap_err_with(|| path.to_string())?;
        Self::load_str(text, path)
    }

    /// Validates `server_base_url` and resolves every configured OIDC provider.
    ///
    /// `resolve_provider_metadata` is forwarded unchanged to each
    /// [`OIDCProvider::resolve`] call.
    ///
    /// # Errors
    ///
    /// Fails if `server_base_url` differs from the URL parsed back from its
    /// own origin serialization, or if any provider fails to resolve.
    pub async fn resolve(&mut self, resolve_provider_metadata: bool) -> eyre::Result<()> {
        // Round-trip the origin through its string form; a base URL carrying
        // anything beyond its origin will not compare equal below.
        let origin = self.server_base_url.origin().ascii_serialization();
        let expected_server_base_url = origin
            .parse::<Url>()
            .ok()
            .ok_or_eyre("invalid server_base_url")?;
        ensure!(
            self.server_base_url == expected_server_base_url,
            "invalid server_base_url -- expected {expected_server_base_url}"
        );

        // Borrow the base URL separately so the providers can be borrowed
        // mutably at the same time.
        let server_base_url = &self.server_base_url;
        let pending = self.oidc.iter_mut().map(|(name, provider)| {
            provider.resolve(name, resolve_provider_metadata, server_base_url)
        });
        try_join_all(pending).await?;
        Ok(())
    }
}
/// Runtime state that `OIDCProvider::resolve` builds after provider discovery.
#[derive(Debug, Clone)]
pub struct OIDCProviderState {
/// Client built from `provider_metadata` plus the provider's configured
/// client id, secret and redirect URL.
pub client: CoreClient,
/// Metadata returned by discovery against the provider's issuer URL.
pub provider_metadata: CoreProviderMetadata,
}
/// Configuration for a single OIDC provider, together with the runtime state
/// that `resolve` fills in.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OIDCProvider {
/// Display name of the provider.
/// NOTE(review): presumably shown to users -- confirm at the use site.
pub pretty_name: String,
/// Issuer URL that `resolve` uses for provider metadata discovery.
pub issuer_url: IssuerUrl,
/// Client id handed to the OIDC client built by `resolve`.
pub client_id: ClientId,
/// Client secret handed to the OIDC client built by `resolve`.
pub secret: ClientSecret,
/// Redirect URL; `resolve` requires it to equal the server base URL joined
/// with `/subscription/callback/<name>`.
pub redirect_url: RedirectUrl,
/// Scopes configured for this provider.
/// NOTE(review): presumably requested during authorization -- confirm at
/// the use site.
pub scopes: Vec<Scope>,
/// Skipped by serde. Set to `Some` by `resolve` only when it is called with
/// `resolve_provider_metadata == true`.
#[serde(skip)]
state: Option<OIDCProviderState>,
}
impl OIDCProvider {
    /// Returns the state populated by [`OIDCProvider::resolve`].
    ///
    /// # Panics
    ///
    /// Panics if the state has not been populated yet.
    pub fn state(&self) -> &OIDCProviderState {
        match &self.state {
            Some(resolved) => resolved,
            None => panic!("resolve called by main"),
        }
    }

    /// Validates this provider's redirect URL and optionally performs discovery.
    ///
    /// `name` is the provider's key in the config. The redirect URL must equal
    /// `server_base_url` joined with `/subscription/callback/{name}`. When
    /// `resolve_provider_metadata` is `true`, the provider metadata is
    /// discovered from `issuer_url` and a client is built and stored in
    /// `state`; otherwise `state` is left untouched.
    ///
    /// # Errors
    ///
    /// Fails if the expected redirect URL cannot be built, if it does not
    /// match the configured one, or if metadata discovery fails.
    pub async fn resolve(
        &mut self,
        name: &str,
        resolve_provider_metadata: bool,
        server_base_url: &Url,
    ) -> eyre::Result<()> {
        let callback_path = format!("/subscription/callback/{name}");
        let expected_redirect_url = server_base_url.join(&callback_path)?;
        ensure!(
            self.redirect_url.as_str() == expected_redirect_url.as_str(),
            "oidc.{name:?}.redirect_url should be {expected_redirect_url}"
        );

        if !resolve_provider_metadata {
            return Ok(());
        }

        let provider_metadata =
            CoreProviderMetadata::discover_async(self.issuer_url.clone(), async_http_client)
                .await?;
        // The implicit `openid` scope is disabled on the client.
        // NOTE(review): presumably the configured `scopes` are supplied
        // explicitly elsewhere -- confirm.
        let client = CoreClient::from_provider_metadata(
            provider_metadata.clone(),
            self.client_id.clone(),
            Some(self.secret.clone()),
        )
        .disable_openid_scope()
        .set_redirect_uri(self.redirect_url.clone());

        self.state = Some(OIDCProviderState {
            client,
            provider_metadata,
        });
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bundled sample configuration must deserialize into a `Config`.
    #[test]
    fn test_example() -> eyre::Result<()> {
        let parsed = Config::load_str(Config::EXAMPLE, "config.toml.sample");
        parsed.map(|_config| ())
    }
}