mas_storage_pg/oauth2/
authorization_grant.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
23/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
24/// connection
25pub struct PgOAuth2AuthorizationGrantRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
30    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
31    /// PostgreSQL connection
32    pub fn new(conn: &'c mut PgConnection) -> Self {
33        Self { conn }
34    }
35}
36
37#[allow(clippy::struct_excessive_bools)]
38struct GrantLookup {
39    oauth2_authorization_grant_id: Uuid,
40    created_at: DateTime<Utc>,
41    cancelled_at: Option<DateTime<Utc>>,
42    fulfilled_at: Option<DateTime<Utc>>,
43    exchanged_at: Option<DateTime<Utc>>,
44    scope: String,
45    state: Option<String>,
46    nonce: Option<String>,
47    redirect_uri: String,
48    response_mode: String,
49    response_type_code: bool,
50    response_type_id_token: bool,
51    authorization_code: Option<String>,
52    code_challenge: Option<String>,
53    code_challenge_method: Option<String>,
54    login_hint: Option<String>,
55    locale: Option<String>,
56    oauth2_client_id: Uuid,
57    oauth2_session_id: Option<Uuid>,
58}
59
60impl TryFrom<GrantLookup> for AuthorizationGrant {
61    type Error = DatabaseInconsistencyError;
62
63    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
64        let id = value.oauth2_authorization_grant_id.into();
65        let scope: Scope = value.scope.parse().map_err(|e| {
66            DatabaseInconsistencyError::on("oauth2_authorization_grants")
67                .column("scope")
68                .row(id)
69                .source(e)
70        })?;
71
72        let stage = match (
73            value.fulfilled_at,
74            value.exchanged_at,
75            value.cancelled_at,
76            value.oauth2_session_id,
77        ) {
78            (None, None, None, None) => AuthorizationGrantStage::Pending,
79            (Some(fulfilled_at), None, None, Some(session_id)) => {
80                AuthorizationGrantStage::Fulfilled {
81                    session_id: session_id.into(),
82                    fulfilled_at,
83                }
84            }
85            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
86                AuthorizationGrantStage::Exchanged {
87                    session_id: session_id.into(),
88                    fulfilled_at,
89                    exchanged_at,
90                }
91            }
92            (None, None, Some(cancelled_at), None) => {
93                AuthorizationGrantStage::Cancelled { cancelled_at }
94            }
95            _ => {
96                return Err(
97                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
98                        .column("stage")
99                        .row(id),
100                );
101            }
102        };
103
104        let pkce = match (value.code_challenge, value.code_challenge_method) {
105            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
106                Some(Pkce {
107                    challenge_method: PkceCodeChallengeMethod::Plain,
108                    challenge,
109                })
110            }
111            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
112                challenge_method: PkceCodeChallengeMethod::S256,
113                challenge,
114            }),
115            (None, None) => None,
116            _ => {
117                return Err(
118                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
119                        .column("code_challenge_method")
120                        .row(id),
121                );
122            }
123        };
124
125        let code: Option<AuthorizationCode> =
126            match (value.response_type_code, value.authorization_code, pkce) {
127                (false, None, None) => None,
128                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
129                _ => {
130                    return Err(
131                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
132                            .column("authorization_code")
133                            .row(id),
134                    );
135                }
136            };
137
138        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
139            DatabaseInconsistencyError::on("oauth2_authorization_grants")
140                .column("redirect_uri")
141                .row(id)
142                .source(e)
143        })?;
144
145        let response_mode = value.response_mode.parse().map_err(|e| {
146            DatabaseInconsistencyError::on("oauth2_authorization_grants")
147                .column("response_mode")
148                .row(id)
149                .source(e)
150        })?;
151
152        Ok(AuthorizationGrant {
153            id,
154            stage,
155            client_id: value.oauth2_client_id.into(),
156            code,
157            scope,
158            state: value.state,
159            nonce: value.nonce,
160            response_mode,
161            redirect_uri,
162            created_at: value.created_at,
163            response_type_id_token: value.response_type_id_token,
164            login_hint: value.login_hint,
165            locale: value.locale,
166        })
167    }
168}
169
170#[async_trait]
171impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
172    type Error = DatabaseError;
173
174    #[tracing::instrument(
175        name = "db.oauth2_authorization_grant.add",
176        skip_all,
177        fields(
178            db.query.text,
179            grant.id,
180            grant.scope = %scope,
181            %client.id,
182        ),
183        err,
184    )]
185    async fn add(
186        &mut self,
187        rng: &mut (dyn RngCore + Send),
188        clock: &dyn Clock,
189        client: &Client,
190        redirect_uri: Url,
191        scope: Scope,
192        code: Option<AuthorizationCode>,
193        state: Option<String>,
194        nonce: Option<String>,
195        response_mode: ResponseMode,
196        response_type_id_token: bool,
197        login_hint: Option<String>,
198        locale: Option<String>,
199    ) -> Result<AuthorizationGrant, Self::Error> {
200        let code_challenge = code
201            .as_ref()
202            .and_then(|c| c.pkce.as_ref())
203            .map(|p| &p.challenge);
204        let code_challenge_method = code
205            .as_ref()
206            .and_then(|c| c.pkce.as_ref())
207            .map(|p| p.challenge_method.to_string());
208        let code_str = code.as_ref().map(|c| &c.code);
209
210        let created_at = clock.now();
211        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
212        tracing::Span::current().record("grant.id", tracing::field::display(id));
213
214        sqlx::query!(
215            r#"
216                INSERT INTO oauth2_authorization_grants (
217                     oauth2_authorization_grant_id,
218                     oauth2_client_id,
219                     redirect_uri,
220                     scope,
221                     state,
222                     nonce,
223                     response_mode,
224                     code_challenge,
225                     code_challenge_method,
226                     response_type_code,
227                     response_type_id_token,
228                     authorization_code,
229                     login_hint,
230                     locale,
231                     created_at
232                )
233                VALUES
234                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
235            "#,
236            Uuid::from(id),
237            Uuid::from(client.id),
238            redirect_uri.to_string(),
239            scope.to_string(),
240            state,
241            nonce,
242            response_mode.to_string(),
243            code_challenge,
244            code_challenge_method,
245            code.is_some(),
246            response_type_id_token,
247            code_str,
248            login_hint,
249            locale,
250            created_at,
251        )
252        .traced()
253        .execute(&mut *self.conn)
254        .await?;
255
256        Ok(AuthorizationGrant {
257            id,
258            stage: AuthorizationGrantStage::Pending,
259            code,
260            redirect_uri,
261            client_id: client.id,
262            scope,
263            state,
264            nonce,
265            response_mode,
266            created_at,
267            response_type_id_token,
268            login_hint,
269            locale,
270        })
271    }
272
273    #[tracing::instrument(
274        name = "db.oauth2_authorization_grant.lookup",
275        skip_all,
276        fields(
277            db.query.text,
278            grant.id = %id,
279        ),
280        err,
281    )]
282    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
283        let res = sqlx::query_as!(
284            GrantLookup,
285            r#"
286                SELECT oauth2_authorization_grant_id
287                     , created_at
288                     , cancelled_at
289                     , fulfilled_at
290                     , exchanged_at
291                     , scope
292                     , state
293                     , redirect_uri
294                     , response_mode
295                     , nonce
296                     , oauth2_client_id
297                     , authorization_code
298                     , response_type_code
299                     , response_type_id_token
300                     , code_challenge
301                     , code_challenge_method
302                     , login_hint
303                     , locale
304                     , oauth2_session_id
305                FROM
306                    oauth2_authorization_grants
307
308                WHERE oauth2_authorization_grant_id = $1
309            "#,
310            Uuid::from(id),
311        )
312        .traced()
313        .fetch_optional(&mut *self.conn)
314        .await?;
315
316        let Some(res) = res else { return Ok(None) };
317
318        Ok(Some(res.try_into()?))
319    }
320
321    #[tracing::instrument(
322        name = "db.oauth2_authorization_grant.find_by_code",
323        skip_all,
324        fields(
325            db.query.text,
326        ),
327        err,
328    )]
329    async fn find_by_code(
330        &mut self,
331        code: &str,
332    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
333        let res = sqlx::query_as!(
334            GrantLookup,
335            r#"
336                SELECT oauth2_authorization_grant_id
337                     , created_at
338                     , cancelled_at
339                     , fulfilled_at
340                     , exchanged_at
341                     , scope
342                     , state
343                     , redirect_uri
344                     , response_mode
345                     , nonce
346                     , oauth2_client_id
347                     , authorization_code
348                     , response_type_code
349                     , response_type_id_token
350                     , code_challenge
351                     , code_challenge_method
352                     , login_hint
353                     , locale
354                     , oauth2_session_id
355                FROM
356                    oauth2_authorization_grants
357
358                WHERE authorization_code = $1
359            "#,
360            code,
361        )
362        .traced()
363        .fetch_optional(&mut *self.conn)
364        .await?;
365
366        let Some(res) = res else { return Ok(None) };
367
368        Ok(Some(res.try_into()?))
369    }
370
371    #[tracing::instrument(
372        name = "db.oauth2_authorization_grant.fulfill",
373        skip_all,
374        fields(
375            db.query.text,
376            %grant.id,
377            client.id = %grant.client_id,
378            %session.id,
379        ),
380        err,
381    )]
382    async fn fulfill(
383        &mut self,
384        clock: &dyn Clock,
385        session: &Session,
386        grant: AuthorizationGrant,
387    ) -> Result<AuthorizationGrant, Self::Error> {
388        let fulfilled_at = clock.now();
389        let res = sqlx::query!(
390            r#"
391                UPDATE oauth2_authorization_grants
392                SET fulfilled_at = $2
393                  , oauth2_session_id = $3
394                WHERE oauth2_authorization_grant_id = $1
395            "#,
396            Uuid::from(grant.id),
397            fulfilled_at,
398            Uuid::from(session.id),
399        )
400        .traced()
401        .execute(&mut *self.conn)
402        .await?;
403
404        DatabaseError::ensure_affected_rows(&res, 1)?;
405
406        // XXX: check affected rows & new methods
407        let grant = grant
408            .fulfill(fulfilled_at, session)
409            .map_err(DatabaseError::to_invalid_operation)?;
410
411        Ok(grant)
412    }
413
414    #[tracing::instrument(
415        name = "db.oauth2_authorization_grant.exchange",
416        skip_all,
417        fields(
418            db.query.text,
419            %grant.id,
420            client.id = %grant.client_id,
421        ),
422        err,
423    )]
424    async fn exchange(
425        &mut self,
426        clock: &dyn Clock,
427        grant: AuthorizationGrant,
428    ) -> Result<AuthorizationGrant, Self::Error> {
429        let exchanged_at = clock.now();
430        let res = sqlx::query!(
431            r#"
432                UPDATE oauth2_authorization_grants
433                SET exchanged_at = $2
434                WHERE oauth2_authorization_grant_id = $1
435            "#,
436            Uuid::from(grant.id),
437            exchanged_at,
438        )
439        .traced()
440        .execute(&mut *self.conn)
441        .await?;
442
443        DatabaseError::ensure_affected_rows(&res, 1)?;
444
445        let grant = grant
446            .exchange(exchanged_at)
447            .map_err(DatabaseError::to_invalid_operation)?;
448
449        Ok(grant)
450    }
451}