it all works!

This commit is contained in:
Jacob Lifshay 2024-04-08 23:38:50 -07:00
parent d6ebd3a4a6
commit 40e8445848
Signed by: programmerjake
SSH key fingerprint: SHA256:B1iRVvUJkvd7upMIiMqn6OyxvD2SgJkAH3ZnUOj6z+c
13 changed files with 471 additions and 47 deletions

View file

@ -1,13 +1,14 @@
use crate::client::async_http_client;
use clap::builder::{TryMapValueParser, TypedValueParser, ValueParserFactory};
use clio::CachedInput;
use eyre::Context;
use eyre::{ensure, Context, OptionExt};
use futures::future::try_join_all;
use indexmap::IndexMap;
use openidconnect::{
core::{CoreClient, CoreProviderMetadata},
ClientId, ClientSecret, IssuerUrl, RedirectUrl, Scope,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::num::NonZeroU16;
@ -20,17 +21,18 @@ pub struct Config {
pub sqlite_db: String,
#[serde(default = "default_db_thread_channel_capacity")]
pub db_thread_channel_capacity: NonZeroU16,
pub server_base_url: Url,
pub oidc: IndexMap<String, OIDCProvider>,
}
impl ValueParserFactory for Config {
type Parser = TryMapValueParser<
<CachedInput as ValueParserFactory>::Parser,
fn(CachedInput) -> eyre::Result<Config>,
fn(CachedInput) -> Result<Config, String>,
>;
fn value_parser() -> Self::Parser {
CachedInput::value_parser().try_map(Config::load)
CachedInput::value_parser().try_map(|v| Config::load(&v).map_err(|e| format!("{e:#}")))
}
}
@ -42,13 +44,28 @@ impl Config {
toml::from_str(input).wrap_err_with(|| path.to_string())
}
pub fn load(input: CachedInput) -> eyre::Result<Config> {
pub fn load(input: &CachedInput) -> eyre::Result<Config> {
log::debug!("loading config from: {}", input.path());
let s = std::str::from_utf8(input.get_data()).wrap_err_with(|| input.path().to_string())?;
Self::load_str(s, input.path())
}
pub async fn resolve(&mut self) -> eyre::Result<()> {
try_join_all(self.oidc.iter_mut().map(|(_, provider)| provider.resolve())).await?;
pub async fn resolve(&mut self, resolve_provider_metadata: bool) -> eyre::Result<()> {
let expected_server_base_url: Url = self
.server_base_url
.origin()
.ascii_serialization()
.parse()
.ok()
.ok_or_eyre("invalid server_base_url")?;
ensure!(
self.server_base_url == expected_server_base_url,
"invalid server_base_url -- expected {expected_server_base_url}"
);
try_join_all(self.oidc.iter_mut().map(|(name, provider)| {
provider.resolve(name, resolve_provider_metadata, &self.server_base_url)
}))
.await?;
Ok(())
}
}
@ -75,21 +92,34 @@ impl OIDCProvider {
pub fn state(&self) -> &OIDCProviderState {
self.state.as_ref().expect("resolve called by main")
}
pub async fn resolve(&mut self) -> eyre::Result<()> {
let provider_metadata =
CoreProviderMetadata::discover_async(self.issuer_url.clone(), async_http_client)
.await?;
let client = CoreClient::from_provider_metadata(
provider_metadata.clone(),
self.client_id.clone(),
Some(self.secret.clone()),
pub async fn resolve(
&mut self,
name: &str,
resolve_provider_metadata: bool,
server_base_url: &Url,
) -> eyre::Result<()> {
let expected_redirect_url =
server_base_url.join(&format!("/subscription/callback/{name}"))?;
ensure!(
self.redirect_url.as_str() == expected_redirect_url.as_str(),
"oidc.{name:?}.redirect_url should be {expected_redirect_url}"
);
let client = client.disable_openid_scope();
let client = client.set_redirect_uri(self.redirect_url.clone());
self.state = Some(OIDCProviderState {
client,
provider_metadata,
});
if resolve_provider_metadata {
let provider_metadata =
CoreProviderMetadata::discover_async(self.issuer_url.clone(), async_http_client)
.await?;
let client = CoreClient::from_provider_metadata(
provider_metadata.clone(),
self.client_id.clone(),
Some(self.secret.clone()),
);
let client = client.disable_openid_scope();
let client = client.set_redirect_uri(self.redirect_url.clone());
self.state = Some(OIDCProviderState {
client,
provider_metadata,
});
}
Ok(())
}
}