mas_storage_pg/oauth2/
client.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    string::ToString,
10};
11
12use async_trait::async_trait;
13use mas_data_model::{Client, JwksOrJwksUri};
14use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
15use mas_jose::jwk::PublicJsonWebKeySet;
16use mas_storage::{Clock, oauth2::OAuth2ClientRepository};
17use oauth2_types::{oidc::ApplicationType, requests::GrantType};
18use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
19use rand::RngCore;
20use sqlx::PgConnection;
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use url::Url;
24use uuid::Uuid;
25
26use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
27
28/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
29pub struct PgOAuth2ClientRepository<'c> {
30    conn: &'c mut PgConnection,
31}
32
33impl<'c> PgOAuth2ClientRepository<'c> {
34    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
35    /// connection
36    pub fn new(conn: &'c mut PgConnection) -> Self {
37        Self { conn }
38    }
39}
40
41#[allow(clippy::struct_excessive_bools)]
42#[derive(Debug)]
43struct OAuth2ClientLookup {
44    oauth2_client_id: Uuid,
45    metadata_digest: Option<String>,
46    encrypted_client_secret: Option<String>,
47    application_type: Option<String>,
48    redirect_uris: Vec<String>,
49    grant_type_authorization_code: bool,
50    grant_type_refresh_token: bool,
51    grant_type_client_credentials: bool,
52    grant_type_device_code: bool,
53    client_name: Option<String>,
54    logo_uri: Option<String>,
55    client_uri: Option<String>,
56    policy_uri: Option<String>,
57    tos_uri: Option<String>,
58    jwks_uri: Option<String>,
59    jwks: Option<serde_json::Value>,
60    id_token_signed_response_alg: Option<String>,
61    userinfo_signed_response_alg: Option<String>,
62    token_endpoint_auth_method: Option<String>,
63    token_endpoint_auth_signing_alg: Option<String>,
64    initiate_login_uri: Option<String>,
65}
66
67impl TryInto<Client> for OAuth2ClientLookup {
68    type Error = DatabaseInconsistencyError;
69
70    fn try_into(self) -> Result<Client, Self::Error> {
71        let id = Ulid::from(self.oauth2_client_id);
72
73        let redirect_uris: Result<Vec<Url>, _> =
74            self.redirect_uris.iter().map(|s| s.parse()).collect();
75        let redirect_uris = redirect_uris.map_err(|e| {
76            DatabaseInconsistencyError::on("oauth2_clients")
77                .column("redirect_uris")
78                .row(id)
79                .source(e)
80        })?;
81
82        let application_type = self
83            .application_type
84            .map(|s| s.parse())
85            .transpose()
86            .map_err(|e| {
87                DatabaseInconsistencyError::on("oauth2_clients")
88                    .column("application_type")
89                    .row(id)
90                    .source(e)
91            })?;
92
93        let mut grant_types = Vec::new();
94        if self.grant_type_authorization_code {
95            grant_types.push(GrantType::AuthorizationCode);
96        }
97        if self.grant_type_refresh_token {
98            grant_types.push(GrantType::RefreshToken);
99        }
100        if self.grant_type_client_credentials {
101            grant_types.push(GrantType::ClientCredentials);
102        }
103        if self.grant_type_device_code {
104            grant_types.push(GrantType::DeviceCode);
105        }
106
107        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
108            DatabaseInconsistencyError::on("oauth2_clients")
109                .column("logo_uri")
110                .row(id)
111                .source(e)
112        })?;
113
114        let client_uri = self
115            .client_uri
116            .map(|s| s.parse())
117            .transpose()
118            .map_err(|e| {
119                DatabaseInconsistencyError::on("oauth2_clients")
120                    .column("client_uri")
121                    .row(id)
122                    .source(e)
123            })?;
124
125        let policy_uri = self
126            .policy_uri
127            .map(|s| s.parse())
128            .transpose()
129            .map_err(|e| {
130                DatabaseInconsistencyError::on("oauth2_clients")
131                    .column("policy_uri")
132                    .row(id)
133                    .source(e)
134            })?;
135
136        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
137            DatabaseInconsistencyError::on("oauth2_clients")
138                .column("tos_uri")
139                .row(id)
140                .source(e)
141        })?;
142
143        let id_token_signed_response_alg = self
144            .id_token_signed_response_alg
145            .map(|s| s.parse())
146            .transpose()
147            .map_err(|e| {
148                DatabaseInconsistencyError::on("oauth2_clients")
149                    .column("id_token_signed_response_alg")
150                    .row(id)
151                    .source(e)
152            })?;
153
154        let userinfo_signed_response_alg = self
155            .userinfo_signed_response_alg
156            .map(|s| s.parse())
157            .transpose()
158            .map_err(|e| {
159                DatabaseInconsistencyError::on("oauth2_clients")
160                    .column("userinfo_signed_response_alg")
161                    .row(id)
162                    .source(e)
163            })?;
164
165        let token_endpoint_auth_method = self
166            .token_endpoint_auth_method
167            .map(|s| s.parse())
168            .transpose()
169            .map_err(|e| {
170                DatabaseInconsistencyError::on("oauth2_clients")
171                    .column("token_endpoint_auth_method")
172                    .row(id)
173                    .source(e)
174            })?;
175
176        let token_endpoint_auth_signing_alg = self
177            .token_endpoint_auth_signing_alg
178            .map(|s| s.parse())
179            .transpose()
180            .map_err(|e| {
181                DatabaseInconsistencyError::on("oauth2_clients")
182                    .column("token_endpoint_auth_signing_alg")
183                    .row(id)
184                    .source(e)
185            })?;
186
187        let initiate_login_uri = self
188            .initiate_login_uri
189            .map(|s| s.parse())
190            .transpose()
191            .map_err(|e| {
192                DatabaseInconsistencyError::on("oauth2_clients")
193                    .column("initiate_login_uri")
194                    .row(id)
195                    .source(e)
196            })?;
197
198        let jwks = match (self.jwks, self.jwks_uri) {
199            (None, None) => None,
200            (Some(jwks), None) => {
201                let jwks = serde_json::from_value(jwks).map_err(|e| {
202                    DatabaseInconsistencyError::on("oauth2_clients")
203                        .column("jwks")
204                        .row(id)
205                        .source(e)
206                })?;
207                Some(JwksOrJwksUri::Jwks(jwks))
208            }
209            (None, Some(jwks_uri)) => {
210                let jwks_uri = jwks_uri.parse().map_err(|e| {
211                    DatabaseInconsistencyError::on("oauth2_clients")
212                        .column("jwks_uri")
213                        .row(id)
214                        .source(e)
215                })?;
216
217                Some(JwksOrJwksUri::JwksUri(jwks_uri))
218            }
219            _ => {
220                return Err(DatabaseInconsistencyError::on("oauth2_clients")
221                    .column("jwks(_uri)")
222                    .row(id));
223            }
224        };
225
226        Ok(Client {
227            id,
228            client_id: id.to_string(),
229            metadata_digest: self.metadata_digest,
230            encrypted_client_secret: self.encrypted_client_secret,
231            application_type,
232            redirect_uris,
233            grant_types,
234            client_name: self.client_name,
235            logo_uri,
236            client_uri,
237            policy_uri,
238            tos_uri,
239            jwks,
240            id_token_signed_response_alg,
241            userinfo_signed_response_alg,
242            token_endpoint_auth_method,
243            token_endpoint_auth_signing_alg,
244            initiate_login_uri,
245        })
246    }
247}
248
249#[async_trait]
250impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
251    type Error = DatabaseError;
252
253    #[tracing::instrument(
254        name = "db.oauth2_client.lookup",
255        skip_all,
256        fields(
257            db.query.text,
258            oauth2_client.id = %id,
259        ),
260        err,
261    )]
262    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
263        let res = sqlx::query_as!(
264            OAuth2ClientLookup,
265            r#"
266                SELECT oauth2_client_id
267                     , metadata_digest
268                     , encrypted_client_secret
269                     , application_type
270                     , redirect_uris
271                     , grant_type_authorization_code
272                     , grant_type_refresh_token
273                     , grant_type_client_credentials
274                     , grant_type_device_code
275                     , client_name
276                     , logo_uri
277                     , client_uri
278                     , policy_uri
279                     , tos_uri
280                     , jwks_uri
281                     , jwks
282                     , id_token_signed_response_alg
283                     , userinfo_signed_response_alg
284                     , token_endpoint_auth_method
285                     , token_endpoint_auth_signing_alg
286                     , initiate_login_uri
287                FROM oauth2_clients c
288
289                WHERE oauth2_client_id = $1
290            "#,
291            Uuid::from(id),
292        )
293        .traced()
294        .fetch_optional(&mut *self.conn)
295        .await?;
296
297        let Some(res) = res else { return Ok(None) };
298
299        Ok(Some(res.try_into()?))
300    }
301
302    #[tracing::instrument(
303        name = "db.oauth2_client.find_by_metadata_digest",
304        skip_all,
305        fields(
306            db.query.text,
307        ),
308        err,
309    )]
310    async fn find_by_metadata_digest(
311        &mut self,
312        digest: &str,
313    ) -> Result<Option<Client>, Self::Error> {
314        let res = sqlx::query_as!(
315            OAuth2ClientLookup,
316            r#"
317                SELECT oauth2_client_id
318                    , metadata_digest
319                    , encrypted_client_secret
320                    , application_type
321                    , redirect_uris
322                    , grant_type_authorization_code
323                    , grant_type_refresh_token
324                    , grant_type_client_credentials
325                    , grant_type_device_code
326                    , client_name
327                    , logo_uri
328                    , client_uri
329                    , policy_uri
330                    , tos_uri
331                    , jwks_uri
332                    , jwks
333                    , id_token_signed_response_alg
334                    , userinfo_signed_response_alg
335                    , token_endpoint_auth_method
336                    , token_endpoint_auth_signing_alg
337                    , initiate_login_uri
338                FROM oauth2_clients
339                WHERE metadata_digest = $1
340            "#,
341            digest,
342        )
343        .traced()
344        .fetch_optional(&mut *self.conn)
345        .await?;
346
347        let Some(res) = res else { return Ok(None) };
348
349        Ok(Some(res.try_into()?))
350    }
351
352    #[tracing::instrument(
353        name = "db.oauth2_client.load_batch",
354        skip_all,
355        fields(
356            db.query.text,
357        ),
358        err,
359    )]
360    async fn load_batch(
361        &mut self,
362        ids: BTreeSet<Ulid>,
363    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
364        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
365        let res = sqlx::query_as!(
366            OAuth2ClientLookup,
367            r#"
368                SELECT oauth2_client_id
369                     , metadata_digest
370                     , encrypted_client_secret
371                     , application_type
372                     , redirect_uris
373                     , grant_type_authorization_code
374                     , grant_type_refresh_token
375                     , grant_type_client_credentials
376                     , grant_type_device_code
377                     , client_name
378                     , logo_uri
379                     , client_uri
380                     , policy_uri
381                     , tos_uri
382                     , jwks_uri
383                     , jwks
384                     , id_token_signed_response_alg
385                     , userinfo_signed_response_alg
386                     , token_endpoint_auth_method
387                     , token_endpoint_auth_signing_alg
388                     , initiate_login_uri
389                FROM oauth2_clients c
390
391                WHERE oauth2_client_id = ANY($1::uuid[])
392            "#,
393            &ids,
394        )
395        .traced()
396        .fetch_all(&mut *self.conn)
397        .await?;
398
399        res.into_iter()
400            .map(|r| {
401                r.try_into()
402                    .map(|c: Client| (c.id, c))
403                    .map_err(DatabaseError::from)
404            })
405            .collect()
406    }
407
408    #[tracing::instrument(
409        name = "db.oauth2_client.add",
410        skip_all,
411        fields(
412            db.query.text,
413            client.id,
414            client.name = client_name
415        ),
416        err,
417    )]
418    async fn add(
419        &mut self,
420        rng: &mut (dyn RngCore + Send),
421        clock: &dyn Clock,
422        redirect_uris: Vec<Url>,
423        metadata_digest: Option<String>,
424        encrypted_client_secret: Option<String>,
425        application_type: Option<ApplicationType>,
426        grant_types: Vec<GrantType>,
427        client_name: Option<String>,
428        logo_uri: Option<Url>,
429        client_uri: Option<Url>,
430        policy_uri: Option<Url>,
431        tos_uri: Option<Url>,
432        jwks_uri: Option<Url>,
433        jwks: Option<PublicJsonWebKeySet>,
434        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
435        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
436        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
437        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
438        initiate_login_uri: Option<Url>,
439    ) -> Result<Client, Self::Error> {
440        let now = clock.now();
441        let id = Ulid::from_datetime_with_source(now.into(), rng);
442        tracing::Span::current().record("client.id", tracing::field::display(id));
443
444        let jwks_json = jwks
445            .as_ref()
446            .map(serde_json::to_value)
447            .transpose()
448            .map_err(DatabaseError::to_invalid_operation)?;
449
450        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
451
452        sqlx::query!(
453            r#"
454                INSERT INTO oauth2_clients
455                    ( oauth2_client_id
456                    , metadata_digest
457                    , encrypted_client_secret
458                    , application_type
459                    , redirect_uris
460                    , grant_type_authorization_code
461                    , grant_type_refresh_token
462                    , grant_type_client_credentials
463                    , grant_type_device_code
464                    , client_name
465                    , logo_uri
466                    , client_uri
467                    , policy_uri
468                    , tos_uri
469                    , jwks_uri
470                    , jwks
471                    , id_token_signed_response_alg
472                    , userinfo_signed_response_alg
473                    , token_endpoint_auth_method
474                    , token_endpoint_auth_signing_alg
475                    , initiate_login_uri
476                    , is_static
477                    )
478                VALUES
479                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
480                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
481            "#,
482            Uuid::from(id),
483            metadata_digest,
484            encrypted_client_secret,
485            application_type.as_ref().map(ToString::to_string),
486            &redirect_uris_array,
487            grant_types.contains(&GrantType::AuthorizationCode),
488            grant_types.contains(&GrantType::RefreshToken),
489            grant_types.contains(&GrantType::ClientCredentials),
490            grant_types.contains(&GrantType::DeviceCode),
491            client_name,
492            logo_uri.as_ref().map(Url::as_str),
493            client_uri.as_ref().map(Url::as_str),
494            policy_uri.as_ref().map(Url::as_str),
495            tos_uri.as_ref().map(Url::as_str),
496            jwks_uri.as_ref().map(Url::as_str),
497            jwks_json,
498            id_token_signed_response_alg
499                .as_ref()
500                .map(ToString::to_string),
501            userinfo_signed_response_alg
502                .as_ref()
503                .map(ToString::to_string),
504            token_endpoint_auth_method.as_ref().map(ToString::to_string),
505            token_endpoint_auth_signing_alg
506                .as_ref()
507                .map(ToString::to_string),
508            initiate_login_uri.as_ref().map(Url::as_str),
509        )
510        .traced()
511        .execute(&mut *self.conn)
512        .await?;
513
514        let jwks = match (jwks, jwks_uri) {
515            (None, None) => None,
516            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
517            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
518            _ => return Err(DatabaseError::invalid_operation()),
519        };
520
521        Ok(Client {
522            id,
523            client_id: id.to_string(),
524            metadata_digest: None,
525            encrypted_client_secret,
526            application_type,
527            redirect_uris,
528            grant_types,
529            client_name,
530            logo_uri,
531            client_uri,
532            policy_uri,
533            tos_uri,
534            jwks,
535            id_token_signed_response_alg,
536            userinfo_signed_response_alg,
537            token_endpoint_auth_method,
538            token_endpoint_auth_signing_alg,
539            initiate_login_uri,
540        })
541    }
542
543    #[tracing::instrument(
544        name = "db.oauth2_client.upsert_static",
545        skip_all,
546        fields(
547            db.query.text,
548            client.id = %client_id,
549        ),
550        err,
551    )]
552    async fn upsert_static(
553        &mut self,
554        client_id: Ulid,
555        client_name: Option<String>,
556        client_auth_method: OAuthClientAuthenticationMethod,
557        encrypted_client_secret: Option<String>,
558        jwks: Option<PublicJsonWebKeySet>,
559        jwks_uri: Option<Url>,
560        redirect_uris: Vec<Url>,
561    ) -> Result<Client, Self::Error> {
562        let jwks_json = jwks
563            .as_ref()
564            .map(serde_json::to_value)
565            .transpose()
566            .map_err(DatabaseError::to_invalid_operation)?;
567
568        let client_auth_method = client_auth_method.to_string();
569        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
570
571        sqlx::query!(
572            r#"
573                INSERT INTO oauth2_clients
574                    ( oauth2_client_id
575                    , encrypted_client_secret
576                    , redirect_uris
577                    , grant_type_authorization_code
578                    , grant_type_refresh_token
579                    , grant_type_client_credentials
580                    , grant_type_device_code
581                    , token_endpoint_auth_method
582                    , jwks
583                    , client_name
584                    , jwks_uri
585                    , is_static
586                    )
587                VALUES
588                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
589                ON CONFLICT (oauth2_client_id)
590                DO
591                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
592                             , redirect_uris = EXCLUDED.redirect_uris
593                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
594                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
595                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
596                             , grant_type_device_code = EXCLUDED.grant_type_device_code
597                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
598                             , jwks = EXCLUDED.jwks
599                             , client_name = EXCLUDED.client_name
600                             , jwks_uri = EXCLUDED.jwks_uri
601                             , is_static = TRUE
602            "#,
603            Uuid::from(client_id),
604            encrypted_client_secret,
605            &redirect_uris_array,
606            true,
607            true,
608            true,
609            true,
610            client_auth_method,
611            jwks_json,
612            client_name,
613            jwks_uri.as_ref().map(Url::as_str),
614        )
615        .traced()
616        .execute(&mut *self.conn)
617        .await?;
618
619        let jwks = match (jwks, jwks_uri) {
620            (None, None) => None,
621            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
622            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
623            _ => return Err(DatabaseError::invalid_operation()),
624        };
625
626        Ok(Client {
627            id: client_id,
628            client_id: client_id.to_string(),
629            metadata_digest: None,
630            encrypted_client_secret,
631            application_type: None,
632            redirect_uris,
633            grant_types: vec![
634                GrantType::AuthorizationCode,
635                GrantType::RefreshToken,
636                GrantType::ClientCredentials,
637            ],
638            client_name,
639            logo_uri: None,
640            client_uri: None,
641            policy_uri: None,
642            tos_uri: None,
643            jwks,
644            id_token_signed_response_alg: None,
645            userinfo_signed_response_alg: None,
646            token_endpoint_auth_method: None,
647            token_endpoint_auth_signing_alg: None,
648            initiate_login_uri: None,
649        })
650    }
651
652    #[tracing::instrument(
653        name = "db.oauth2_client.all_static",
654        skip_all,
655        fields(
656            db.query.text,
657        ),
658        err,
659    )]
660    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
661        let res = sqlx::query_as!(
662            OAuth2ClientLookup,
663            r#"
664                SELECT oauth2_client_id
665                     , metadata_digest
666                     , encrypted_client_secret
667                     , application_type
668                     , redirect_uris
669                     , grant_type_authorization_code
670                     , grant_type_refresh_token
671                     , grant_type_client_credentials
672                     , grant_type_device_code
673                     , client_name
674                     , logo_uri
675                     , client_uri
676                     , policy_uri
677                     , tos_uri
678                     , jwks_uri
679                     , jwks
680                     , id_token_signed_response_alg
681                     , userinfo_signed_response_alg
682                     , token_endpoint_auth_method
683                     , token_endpoint_auth_signing_alg
684                     , initiate_login_uri
685                FROM oauth2_clients c
686                WHERE is_static = TRUE
687            "#,
688        )
689        .traced()
690        .fetch_all(&mut *self.conn)
691        .await?;
692
693        res.into_iter()
694            .map(|r| r.try_into().map_err(DatabaseError::from))
695            .collect()
696    }
697
698    #[tracing::instrument(
699        name = "db.oauth2_client.delete_by_id",
700        skip_all,
701        fields(
702            db.query.text,
703            client.id = %id,
704        ),
705        err,
706    )]
707    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
708        // Delete the authorization grants
709        {
710            let span = info_span!(
711                "db.oauth2_client.delete_by_id.authorization_grants",
712                { DB_QUERY_TEXT } = tracing::field::Empty,
713            );
714
715            sqlx::query!(
716                r#"
717                    DELETE FROM oauth2_authorization_grants
718                    WHERE oauth2_client_id = $1
719                "#,
720                Uuid::from(id),
721            )
722            .record(&span)
723            .execute(&mut *self.conn)
724            .instrument(span)
725            .await?;
726        }
727
728        // Delete the user consents
729        {
730            let span = info_span!(
731                "db.oauth2_client.delete_by_id.consents",
732                { DB_QUERY_TEXT } = tracing::field::Empty,
733            );
734
735            sqlx::query!(
736                r#"
737                    DELETE FROM oauth2_consents
738                    WHERE oauth2_client_id = $1
739                "#,
740                Uuid::from(id),
741            )
742            .record(&span)
743            .execute(&mut *self.conn)
744            .instrument(span)
745            .await?;
746        }
747
748        // Delete the OAuth 2 sessions related data
749        {
750            let span = info_span!(
751                "db.oauth2_client.delete_by_id.access_tokens",
752                { DB_QUERY_TEXT } = tracing::field::Empty,
753            );
754
755            sqlx::query!(
756                r#"
757                    DELETE FROM oauth2_access_tokens
758                    WHERE oauth2_session_id IN (
759                        SELECT oauth2_session_id
760                        FROM oauth2_sessions
761                        WHERE oauth2_client_id = $1
762                    )
763                "#,
764                Uuid::from(id),
765            )
766            .record(&span)
767            .execute(&mut *self.conn)
768            .instrument(span)
769            .await?;
770        }
771
772        {
773            let span = info_span!(
774                "db.oauth2_client.delete_by_id.refresh_tokens",
775                { DB_QUERY_TEXT } = tracing::field::Empty,
776            );
777
778            sqlx::query!(
779                r#"
780                    DELETE FROM oauth2_refresh_tokens
781                    WHERE oauth2_session_id IN (
782                        SELECT oauth2_session_id
783                        FROM oauth2_sessions
784                        WHERE oauth2_client_id = $1
785                    )
786                "#,
787                Uuid::from(id),
788            )
789            .record(&span)
790            .execute(&mut *self.conn)
791            .instrument(span)
792            .await?;
793        }
794
795        {
796            let span = info_span!(
797                "db.oauth2_client.delete_by_id.sessions",
798                { DB_QUERY_TEXT } = tracing::field::Empty,
799            );
800
801            sqlx::query!(
802                r#"
803                    DELETE FROM oauth2_sessions
804                    WHERE oauth2_client_id = $1
805                "#,
806                Uuid::from(id),
807            )
808            .record(&span)
809            .execute(&mut *self.conn)
810            .instrument(span)
811            .await?;
812        }
813
814        // Now delete the client itself
815        let res = sqlx::query!(
816            r#"
817                DELETE FROM oauth2_clients
818                WHERE oauth2_client_id = $1
819            "#,
820            Uuid::from(id),
821        )
822        .traced()
823        .execute(&mut *self.conn)
824        .await?;
825
826        DatabaseError::ensure_affected_rows(&res, 1)
827    }
828}