use crate::client::async_http_client; use clap::builder::{TryMapValueParser, TypedValueParser, ValueParserFactory}; use clio::CachedInput; use eyre::{ensure, Context, OptionExt}; use futures::future::try_join_all; use indexmap::IndexMap; use openidconnect::{ core::{CoreClient, CoreProviderMetadata}, ClientId, ClientSecret, IssuerUrl, RedirectUrl, Scope, }; use reqwest::Url; use serde::{Deserialize, Serialize}; use std::num::NonZeroU16; pub fn default_db_thread_channel_capacity() -> NonZeroU16 { NonZeroU16::new(10).unwrap() } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct Config { pub sqlite_db: String, #[serde(default = "default_db_thread_channel_capacity")] pub db_thread_channel_capacity: NonZeroU16, pub server_base_url: Url, pub oidc: IndexMap, } impl ValueParserFactory for Config { type Parser = TryMapValueParser< ::Parser, fn(CachedInput) -> Result, >; fn value_parser() -> Self::Parser { CachedInput::value_parser().try_map(|v| Config::load(&v).map_err(|e| format!("{e:#}"))) } } impl Config { #[allow(dead_code)] pub const EXAMPLE: &'static str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/config.toml.sample")); pub fn load_str(input: &str, path: impl ToString) -> eyre::Result { toml::from_str(input).wrap_err_with(|| path.to_string()) } pub fn load(input: &CachedInput) -> eyre::Result { log::debug!("loading config from: {}", input.path()); let s = std::str::from_utf8(input.get_data()).wrap_err_with(|| input.path().to_string())?; Self::load_str(s, input.path()) } pub async fn resolve(&mut self, resolve_provider_metadata: bool) -> eyre::Result<()> { let expected_server_base_url: Url = self .server_base_url .origin() .ascii_serialization() .parse() .ok() .ok_or_eyre("invalid server_base_url")?; ensure!( self.server_base_url == expected_server_base_url, "invalid server_base_url -- expected {expected_server_base_url}" ); try_join_all(self.oidc.iter_mut().map(|(name, provider)| { provider.resolve(name, resolve_provider_metadata, &self.server_base_url) })) .await?; Ok(()) } } #[derive(Debug, Clone)] pub struct OIDCProviderState { pub client: CoreClient, pub provider_metadata: CoreProviderMetadata, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct OIDCProvider { pub pretty_name: String, pub issuer_url: IssuerUrl, pub client_id: ClientId, pub secret: ClientSecret, pub redirect_url: RedirectUrl, pub scopes: Vec, #[serde(skip)] state: Option, } impl OIDCProvider { pub fn state(&self) -> &OIDCProviderState { self.state.as_ref().expect("resolve called by main") } pub async fn resolve( &mut self, name: &str, resolve_provider_metadata: bool, server_base_url: &Url, ) -> eyre::Result<()> { let expected_redirect_url = server_base_url.join(&format!("/subscription/callback/{name}"))?; ensure!( self.redirect_url.as_str() == expected_redirect_url.as_str(), "oidc.{name:?}.redirect_url should be {expected_redirect_url}" ); if resolve_provider_metadata { let provider_metadata = CoreProviderMetadata::discover_async(self.issuer_url.clone(), async_http_client) .await?; let client = CoreClient::from_provider_metadata( provider_metadata.clone(), self.client_id.clone(), Some(self.secret.clone()), ); let client = client.disable_openid_scope(); let client = client.set_redirect_uri(self.redirect_url.clone()); self.state = Some(OIDCProviderState { client, provider_metadata, }); } Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_example() -> eyre::Result<()> { Config::load_str(Config::EXAMPLE, "config.toml.sample")?; Ok(()) } }