1use std::ops::Deref;
8
9use anyhow::bail;
10use camino::Utf8PathBuf;
11use mas_iana::oauth::OAuthClientAuthenticationMethod;
12use mas_jose::jwk::PublicJsonWebKeySet;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize, de::Error};
15use serde_with::serde_as;
16use ulid::Ulid;
17use url::Url;
18
19use super::ConfigurationSection;
20
// Either an inline public JSON Web Key Set or a URL.
//
// With `rename_all = "snake_case"` serde names the two variants `jwks` and
// `jwks_uri`.
//
// NOTE(review): plain `//` comments are used here on purpose. `///` doc
// comments on a `JsonSchema` type end up as descriptions in the generated
// schema (confirm against the schemars documentation before converting).
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
    // The key set itself, given inline.
    Jwks(PublicJsonWebKeySet),
    // A URL; presumably the location the key set is served from — confirm
    // with the code consuming this type.
    JwksUri(Url),
}
27
28impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
29 fn from(jwks: PublicJsonWebKeySet) -> Self {
30 Self::Jwks(jwks)
31 }
32}
33
/// Where the secret of a client comes from.
///
/// This type is not (de)serialized directly: `ClientConfig` converts it to
/// and from `ClientSecretRaw`, which exposes the two variants as the
/// `client_secret_file` and `client_secret` keys.
#[derive(Clone, Debug)]
pub enum ClientSecret {
    /// Path of a file holding the secret. The file is only read when
    /// `ClientConfig::client_secret` is called.
    File(Utf8PathBuf),
    /// The secret itself, given inline.
    Value(String),
}
43
// Flat, serde-facing form of an optional `ClientSecret`: two optional keys
// instead of an enum. The `TryFrom`/`From` impls in this file convert between
// the two forms and reject the case where both keys are set.
//
// NOTE(review): plain `//` comments are used on purpose; `///` on a
// `JsonSchema` type would be emitted as schema descriptions (confirm against
// the schemars documentation).
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
struct ClientSecretRaw {
    // Path of a file containing the secret. Described to schemars as an
    // optional string, and left out of the output when unset.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    client_secret_file: Option<Utf8PathBuf>,

    // The secret given inline. Left out of the output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    client_secret: Option<String>,
}
59
60impl TryFrom<ClientSecretRaw> for Option<ClientSecret> {
61 type Error = anyhow::Error;
62
63 fn try_from(value: ClientSecretRaw) -> Result<Self, Self::Error> {
64 match (value.client_secret, value.client_secret_file) {
65 (None, None) => Ok(None),
66 (None, Some(path)) => Ok(Some(ClientSecret::File(path))),
67 (Some(client_secret), None) => Ok(Some(ClientSecret::Value(client_secret))),
68 (Some(_), Some(_)) => {
69 bail!("Cannot specify both `client_secret` and `client_secret_file`")
70 }
71 }
72 }
73}
74
75impl From<Option<ClientSecret>> for ClientSecretRaw {
76 fn from(value: Option<ClientSecret>) -> Self {
77 match value {
78 Some(ClientSecret::File(path)) => ClientSecretRaw {
79 client_secret_file: Some(path),
80 client_secret: None,
81 },
82 Some(ClientSecret::Value(client_secret)) => ClientSecretRaw {
83 client_secret_file: None,
84 client_secret: Some(client_secret),
85 },
86 None => ClientSecretRaw {
87 client_secret_file: None,
88 client_secret: None,
89 },
90 }
91 }
92}
93
// Authentication method a client is configured with.
//
// serde uses the snake_case variant names (`none`, `client_secret_basic`, ...),
// which are the same strings the `Display` impl in this file writes. The
// per-variant notes below restate what `ClientConfig::validate` enforces.
//
// NOTE(review): plain `//` comments are used on purpose; `///` on a
// `JsonSchema` type would be emitted as schema descriptions (confirm against
// the schemars documentation).
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthMethodConfig {
    // `none`: no client secret, `jwks` or `jwks_uri` may be configured.
    None,

    // `client_secret_basic`: a client secret is required; no `jwks`/`jwks_uri`.
    ClientSecretBasic,

    // `client_secret_post`: a client secret is required; no `jwks`/`jwks_uri`.
    ClientSecretPost,

    // `client_secret_jwt`: a client secret is required; no `jwks`/`jwks_uri`.
    ClientSecretJwt,

    // `private_key_jwt`: exactly one of `jwks` / `jwks_uri` is required and no
    // client secret may be configured.
    PrivateKeyJwt,
}
117
118impl std::fmt::Display for ClientAuthMethodConfig {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 match self {
121 ClientAuthMethodConfig::None => write!(f, "none"),
122 ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
123 ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
124 ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
125 ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
126 }
127 }
128}
129
// Configuration of a single client.
//
// Which of `client_secret`, `jwks` and `jwks_uri` may or must be present
// depends on `client_auth_method`; `ClientConfig::validate` enforces this.
//
// NOTE(review): plain `//` comments are used on purpose; `///` on a
// `JsonSchema` type would be emitted as schema descriptions (confirm against
// the schemars documentation).
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
    // Identifier of the client. Stored as a ULID, but described in the schema
    // as a 26-character string restricted to the alphabet in the pattern below.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub client_id: Ulid,

    // Authentication method of the client. Private: read it through the
    // `client_auth_method()` accessor, which maps it to the
    // `OAuthClientAuthenticationMethod` type.
    client_auth_method: ClientAuthMethodConfig,

    // Optional name of the client; left out of the output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,

    // Secret of the client, if any. (De)serialized through `ClientSecretRaw`
    // and flattened, so the `client_secret` / `client_secret_file` keys sit
    // next to the other keys of this struct.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    // Inline public key set. Only accepted by `validate` for `private_key_jwt`,
    // and mutually exclusive with `jwks_uri`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<PublicJsonWebKeySet>,

    // URL alternative to `jwks`. Only accepted by `validate` for
    // `private_key_jwt`, and mutually exclusive with `jwks`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    // Redirect URIs of the client (presumably the ones it is allowed to use —
    // confirm with the consumer). Defaults to an empty list, which is left out
    // of the output.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub redirect_uris: Vec<Url>,
}
170
171impl ClientConfig {
172 fn validate(&self) -> Result<(), Box<figment::error::Error>> {
173 let auth_method = self.client_auth_method;
174 match self.client_auth_method {
175 ClientAuthMethodConfig::PrivateKeyJwt => {
176 if self.jwks.is_none() && self.jwks_uri.is_none() {
177 let error = figment::error::Error::custom(
178 "jwks or jwks_uri is required for private_key_jwt",
179 );
180 return Err(Box::new(error.with_path("client_auth_method")));
181 }
182
183 if self.jwks.is_some() && self.jwks_uri.is_some() {
184 let error =
185 figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
186 return Err(Box::new(error.with_path("jwks")));
187 }
188
189 if self.client_secret.is_some() {
190 let error = figment::error::Error::custom(
191 "client_secret is not allowed with private_key_jwt",
192 );
193 return Err(Box::new(error.with_path("client_secret")));
194 }
195 }
196
197 ClientAuthMethodConfig::ClientSecretPost
198 | ClientAuthMethodConfig::ClientSecretBasic
199 | ClientAuthMethodConfig::ClientSecretJwt => {
200 if self.client_secret.is_none() {
201 let error = figment::error::Error::custom(format!(
202 "client_secret is required for {auth_method}"
203 ));
204 return Err(Box::new(error.with_path("client_auth_method")));
205 }
206
207 if self.jwks.is_some() {
208 let error = figment::error::Error::custom(format!(
209 "jwks is not allowed with {auth_method}"
210 ));
211 return Err(Box::new(error.with_path("jwks")));
212 }
213
214 if self.jwks_uri.is_some() {
215 let error = figment::error::Error::custom(format!(
216 "jwks_uri is not allowed with {auth_method}"
217 ));
218 return Err(Box::new(error.with_path("jwks_uri")));
219 }
220 }
221
222 ClientAuthMethodConfig::None => {
223 if self.client_secret.is_some() {
224 let error = figment::error::Error::custom(
225 "client_secret is not allowed with none authentication method",
226 );
227 return Err(Box::new(error.with_path("client_secret")));
228 }
229
230 if self.jwks.is_some() {
231 let error = figment::error::Error::custom(
232 "jwks is not allowed with none authentication method",
233 );
234 return Err(Box::new(error));
235 }
236
237 if self.jwks_uri.is_some() {
238 let error = figment::error::Error::custom(
239 "jwks_uri is not allowed with none authentication method",
240 );
241 return Err(Box::new(error));
242 }
243 }
244 }
245
246 Ok(())
247 }
248
249 #[must_use]
251 pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
252 match self.client_auth_method {
253 ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
254 ClientAuthMethodConfig::ClientSecretBasic => {
255 OAuthClientAuthenticationMethod::ClientSecretBasic
256 }
257 ClientAuthMethodConfig::ClientSecretPost => {
258 OAuthClientAuthenticationMethod::ClientSecretPost
259 }
260 ClientAuthMethodConfig::ClientSecretJwt => {
261 OAuthClientAuthenticationMethod::ClientSecretJwt
262 }
263 ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
264 }
265 }
266
267 pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
275 Ok(match &self.client_secret {
276 Some(ClientSecret::File(path)) => Some(tokio::fs::read_to_string(path).await?),
277 Some(ClientSecret::Value(client_secret)) => Some(client_secret.clone()),
278 None => None,
279 })
280 }
281}
282
// List of client configurations.
//
// `#[serde(transparent)]` makes this (de)serialize as the bare list, and the
// schemars attribute describes it as a `Vec<ClientConfig>`. The default value
// is the empty list.
//
// NOTE(review): plain `//` comments are used on purpose; `///` on a
// `JsonSchema` type would be emitted as schema descriptions (confirm against
// the schemars documentation).
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
287
288impl ClientsConfig {
289 pub(crate) fn is_default(&self) -> bool {
291 self.0.is_empty()
292 }
293}
294
295impl Deref for ClientsConfig {
296 type Target = Vec<ClientConfig>;
297
298 fn deref(&self) -> &Self::Target {
299 &self.0
300 }
301}
302
303impl IntoIterator for ClientsConfig {
304 type Item = ClientConfig;
305 type IntoIter = std::vec::IntoIter<ClientConfig>;
306
307 fn into_iter(self) -> Self::IntoIter {
308 self.0.into_iter()
309 }
310}
311
312impl ConfigurationSection for ClientsConfig {
313 const PATH: Option<&'static str> = Some("clients");
314
315 fn validate(
316 &self,
317 figment: &figment::Figment,
318 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
319 for (index, client) in self.0.iter().enumerate() {
320 client.validate().map_err(|mut err| {
321 err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
323 err.profile = Some(figment::Profile::Default);
324 err.path.insert(0, Self::PATH.unwrap().to_owned());
325 err.path.insert(1, format!("{index}"));
326 err
327 })?;
328 }
329
330 Ok(())
331 }
332}
333
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads a YAML file with one client per authentication method and checks
    /// the parsed identifiers, redirect URIs and client secrets, including the
    /// secrets that have to be read from a file.
    #[tokio::test]
    async fn load_config() {
        // The body runs on a blocking thread; presumably so that the
        // `Handle::block_on` call further down is allowed — confirm against
        // the tokio documentation.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r#"
                      clients:
                        - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                          client_auth_method: none
                          redirect_uris:
                            - https://exemple.fr/callback

                        - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                          client_auth_method: client_secret_basic
                          client_secret_file: secret

                        - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                          client_auth_method: client_secret_post
                          client_secret: c1!3n753c237

                        - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                          client_auth_method: client_secret_jwt
                          client_secret_file: secret

                        - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                          client_auth_method: private_key_jwt
                          jwks:
                            keys:
                            - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                            - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // The file that the `client_secret_file: secret` entries point at.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<ClientsConfig>("clients")?;

                assert_eq!(config.0.len(), 5);

                assert_eq!(
                    config.0[0].client_id,
                    Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
                );
                assert_eq!(
                    config.0[0].redirect_uris,
                    vec!["https://exemple.fr/callback".parse().unwrap()]
                );

                assert_eq!(
                    config.0[1].client_id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );
                assert_eq!(config.0[1].redirect_uris, Vec::new());

                // How each secret was parsed: absent, file-backed or inline.
                assert!(config.0[0].client_secret.is_none());
                assert!(matches!(config.0[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.0[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.0[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.0[4].client_secret.is_none());

                // Resolving the secrets must give the same value whether it
                // was inline or read from the `secret` file.
                Handle::current().block_on(async move {
                    assert_eq!(config.0[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}