diff --git a/Cargo.lock b/Cargo.lock index 09d7ba4..b03dd48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2287,6 +2287,7 @@ dependencies = [ "futures", "indexmap 2.2.6", "listenfd", + "log", "openidconnect", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 132b3e3..7dfd291 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ eyre = "0.6.12" futures = "0.3.30" indexmap = { version = "2.2.6", features = ["serde"] } listenfd = "1.0.1" +log = "0.4.21" openidconnect = "3.5.0" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.115" diff --git a/config.toml.sample b/config.toml.sample index 9713ecb..da14b88 100644 --- a/config.toml.sample +++ b/config.toml.sample @@ -3,6 +3,7 @@ pretty_name = "Google" issuer_url = "https://accounts.google.com" client_id = "" secret = "" +redirect_url = "https://my-site/subscription/callback/google" scopes = ["email"] [oidc.debian-salsa] @@ -10,4 +11,5 @@ pretty_name = "Debian Salsa" issuer_url = "https://salsa.debian.org" client_id = "" secret = "" +redirect_url = "https://my-site/subscription/callback/debian-salsa" scopes = ["email"] diff --git a/src/app.rs b/src/app.rs index 7ff156a..edac951 100644 --- a/src/app.rs +++ b/src/app.rs @@ -5,16 +5,18 @@ use actix_web::{ get, http::{ header::{ContentType, LOCATION}, + uri::PathAndQuery, StatusCode, Uri, }, web::{self, ServiceConfig}, HttpResponse, HttpResponseBuilder, Responder, }; use openidconnect::{ - core::CoreAuthenticationFlow, http::HeaderValue, CsrfToken, EndUserEmail, Nonce, + core::CoreAuthenticationFlow, http::HeaderValue, reqwest::async_http_client, url::Url, + AuthorizationCode, CsrfToken, EndUserEmail, Nonce, RedirectUrl, TokenResponse, }; use serde::{Deserialize, Serialize}; -use std::{fmt::Write, marker::PhantomData}; +use std::{borrow::Cow, fmt::Write, marker::PhantomData}; use tinytemplate::TinyTemplate; pub trait DynTemplate { @@ -125,12 +127,76 @@ impl OIDCProvider {} const NONCE: &'static str = "nonce"; const CSRF_TOKEN: &'static str = "csrf_token"; +#[derive(Deserialize)] +struct AuthQuery { + code: AuthorizationCode, + state: CsrfToken, +} + +#[get("/subscription/callback/{provider}")] +pub async fn callback( + config: web::Data, + session: Session, + path: web::Path<(String,)>, + query: web::Query, +) -> impl Responder { + let mut resp = HttpResponse::SeeOther().body(""); + resp.headers_mut() + .insert(LOCATION, "/subscription".parse().unwrap()); + if SessionState::get(&session).is_some() { + return resp; + } + let Ok(Some(csrf_token)) = session.get::(CSRF_TOKEN) else { + return resp; + }; + let Ok(Some(nonce)) = session.get::(NONCE) else { + return resp; + }; + session.remove(CSRF_TOKEN); + session.remove(NONCE); + if csrf_token.secret() != query.state.secret() { + return resp; + } + let Some(provider) = config.oidc.get(&path.0) else { + return resp; + }; + let token = match provider + .state() + .client + .exchange_code(query.into_inner().code) + .request_async(async_http_client) + .await + { + Ok(token) => token, + Err(e) => { + log::warn!("Error getting token: {e}"); + return resp; + } + }; + dbg!(&token); + let claims = match token + .id_token() + .unwrap() + .claims(&provider.state().client.id_token_verifier(), &nonce) + { + Ok(claims) => claims, + Err(e) => { + log::warn!("Error verifying token: {e}"); + return resp; + } + }; + let Some(email) = claims.email().cloned() else { + return resp; + }; + SessionState { email }.set(&session); + resp +} + #[get("/subscription/login/{provider}")] pub async fn login( config: web::Data, session: Session, path: web::Path<(String,)>, - uri: Uri, ) -> impl Responder { let mut resp = HttpResponse::SeeOther().body(""); resp.headers_mut() @@ -179,7 +245,7 @@ pub async fn subscription(config: web::Data, session: Session) -> impl R } pub fn all_services(cfg: &mut ServiceConfig) { - cfg.service(subscription).service(login); + cfg.service(subscription).service(login).service(callback); } #[cfg(test)] diff --git a/src/config.rs b/src/config.rs index 331374b..f45ee18 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,7 +5,7 @@ use futures::future::try_join_all; use indexmap::IndexMap; use openidconnect::{ core::{CoreClient, CoreProviderMetadata}, - ClientId, ClientSecret, IssuerUrl, Scope, + ClientId, ClientSecret, IssuerUrl, RedirectUrl, Scope, }; use serde::{Deserialize, Serialize}; @@ -55,6 +55,7 @@ pub struct OIDCProvider { pub issuer_url: IssuerUrl, pub client_id: ClientId, pub secret: ClientSecret, + pub redirect_url: RedirectUrl, pub scopes: Vec, #[serde(skip)] state: Option, @@ -76,6 +77,7 @@ impl OIDCProvider { Some(self.secret.clone()), ); let client = client.disable_openid_scope(); + let client = client.set_redirect_uri(self.redirect_url.clone()); self.state = Some(OIDCProviderState { client, provider_metadata,