mas_storage_pg/oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13    Clock, Page, Pagination,
14    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{
19    Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
20    extension::postgres::PgExpr,
21};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use uuid::Uuid;
26
27use crate::{
28    DatabaseError, DatabaseInconsistencyError,
29    filter::{Filter, StatementExt},
30    iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
31    pagination::QueryBuilderExt,
32    tracing::ExecuteExt,
33};
34
/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
///
/// The repository holds an exclusive borrow of the connection for the
/// lifetime `'c`; every query issued by the repository methods runs on that
/// connection.
pub struct PgOAuth2SessionRepository<'c> {
    // Connection on which all repository queries are executed
    conn: &'c mut PgConnection,
}
39
40impl<'c> PgOAuth2SessionRepository<'c> {
41    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// Raw row of the `oauth2_sessions` table, as fetched by the `lookup` and
/// `list` queries, before it is converted into a [`Session`].
///
/// `#[enum_def]` generates an `OAuthSessionLookupIden` enum from the field
/// names; the dynamically-built `list` query uses its variants as column
/// aliases so that `sqlx::FromRow` can map each result row back onto this
/// struct. Renaming a field therefore changes the expected column alias.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    oauth2_session_id: Uuid,
    // `None` when the session has no associated user
    user_id: Option<Uuid>,
    // `None` when the session is not attached to a browser session
    user_session_id: Option<Uuid>,
    oauth2_client_id: Uuid,
    // Individual scope tokens; parsed back into a `Scope` by `TryFrom`
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is still valid
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    // Read with an explicit `IpAddr` type override in the `lookup` query
    last_active_ip: Option<IpAddr>,
    human_name: Option<String>,
}
63
64impl TryFrom<OAuthSessionLookup> for Session {
65    type Error = DatabaseInconsistencyError;
66
67    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
68        let id = Ulid::from(value.oauth2_session_id);
69        let scope: Result<Scope, _> = value
70            .scope_list
71            .iter()
72            .map(|s| s.parse::<ScopeToken>())
73            .collect();
74        let scope = scope.map_err(|e| {
75            DatabaseInconsistencyError::on("oauth2_sessions")
76                .column("scope")
77                .row(id)
78                .source(e)
79        })?;
80
81        let state = match value.finished_at {
82            None => SessionState::Valid,
83            Some(finished_at) => SessionState::Finished { finished_at },
84        };
85
86        Ok(Session {
87            id,
88            state,
89            created_at: value.created_at,
90            client_id: value.oauth2_client_id.into(),
91            user_id: value.user_id.map(Ulid::from),
92            user_session_id: value.user_session_id.map(Ulid::from),
93            scope,
94            user_agent: value.user_agent,
95            last_active_at: value.last_active_at,
96            last_active_ip: value.last_active_ip,
97            human_name: value.human_name,
98        })
99    }
100}
101
impl Filter for OAuth2SessionFilter<'_> {
    /// Build the SQL condition matching this filter.
    ///
    /// Every criterion that is set on the filter is ANDed together
    /// (`Condition::all()`); criteria left unset add nothing. All column
    /// references are qualified with the `oauth2_sessions` table, so
    /// `_has_joins` is not needed here.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        sea_query::Condition::all()
            // Sessions belonging to a specific user
            .add_option(self.user().map(|user| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
            }))
            // Sessions of a specific client
            .add_option(self.client().map(|client| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                    .eq(Uuid::from(client.id))
            }))
            // Sessions whose client is (or is not) a static client, checked
            // against a subquery listing all static client IDs
            .add_option(self.client_kind().map(|client_kind| {
                // This builds either a:
                // `WHERE oauth2_client_id = ANY(...)`
                // or a `WHERE oauth2_client_id <> ALL(...)`
                let static_clients = Query::select()
                    .expr(Expr::col((
                        OAuth2Clients::Table,
                        OAuth2Clients::OAuth2ClientId,
                    )))
                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
                    .from(OAuth2Clients::Table)
                    .take();
                if client_kind.is_static() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .eq(Expr::any(static_clients))
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .ne(Expr::all(static_clients))
                }
            }))
            // Sessions bound to a device: the scope list must contain either
            // the stable or the unstable scope token derived from the device
            .add_option(self.device().map(|device| -> SimpleExpr {
                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
                    Condition::any()
                        .add(
                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
                                OAuth2Sessions::Table,
                                OAuth2Sessions::ScopeList,
                            )))),
                        )
                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
                        )))
                        .into()
                } else {
                    // If the device ID can't be encoded as a scope token, match no rows
                    Expr::val(false).into()
                }
            }))
            // Sessions attached to one specific browser session
            .add_option(self.browser_session().map(|browser_session| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
                    .eq(Uuid::from(browser_session.id))
            }))
            // Sessions attached to any browser session matching a nested
            // browser-session filter, evaluated as an `IN (subquery)`
            .add_option(self.browser_session_filter().map(|browser_session_filter| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
                    Query::select()
                        .expr(Expr::col((
                            UserSessions::Table,
                            UserSessions::UserSessionId,
                        )))
                        .apply_filter(browser_session_filter)
                        .from(UserSessions::Table)
                        .take(),
                )
            }))
            // Active sessions have no `finished_at`; finished ones have one
            .add_option(self.state().map(|state| {
                if state.is_active() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
                }
            }))
            // The session's scope list must contain every requested scope
            // token (Postgres array containment via `PgExpr::contains`)
            .add_option(self.scope().map(|scope| {
                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
            }))
            // `true`: sessions with a user; `false`: sessions without one
            .add_option(self.any_user().map(|any_user| {
                if any_user {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
                }
            }))
            // Strictly after the given instant; rows with a NULL
            // `last_active_at` never match
            .add_option(self.last_active_after().map(|last_active_after| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .gt(last_active_after)
            }))
            // Strictly before the given instant; rows with a NULL
            // `last_active_at` never match
            .add_option(self.last_active_before().map(|last_active_before| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .lt(last_active_before)
            }))
    }
}
194
195#[async_trait]
196impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
197    type Error = DatabaseError;
198
199    #[tracing::instrument(
200        name = "db.oauth2_session.lookup",
201        skip_all,
202        fields(
203            db.query.text,
204            session.id = %id,
205        ),
206        err,
207    )]
208    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
209        let res = sqlx::query_as!(
210            OAuthSessionLookup,
211            r#"
212                SELECT oauth2_session_id
213                     , user_id
214                     , user_session_id
215                     , oauth2_client_id
216                     , scope_list
217                     , created_at
218                     , finished_at
219                     , user_agent
220                     , last_active_at
221                     , last_active_ip as "last_active_ip: IpAddr"
222                     , human_name
223                FROM oauth2_sessions
224
225                WHERE oauth2_session_id = $1
226            "#,
227            Uuid::from(id),
228        )
229        .traced()
230        .fetch_optional(&mut *self.conn)
231        .await?;
232
233        let Some(session) = res else { return Ok(None) };
234
235        Ok(Some(session.try_into()?))
236    }
237
238    #[tracing::instrument(
239        name = "db.oauth2_session.add",
240        skip_all,
241        fields(
242            db.query.text,
243            %client.id,
244            session.id,
245            session.scope = %scope,
246        ),
247        err,
248    )]
249    async fn add(
250        &mut self,
251        rng: &mut (dyn RngCore + Send),
252        clock: &dyn Clock,
253        client: &Client,
254        user: Option<&User>,
255        user_session: Option<&BrowserSession>,
256        scope: Scope,
257    ) -> Result<Session, Self::Error> {
258        let created_at = clock.now();
259        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
260        tracing::Span::current().record("session.id", tracing::field::display(id));
261
262        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
263
264        sqlx::query!(
265            r#"
266                INSERT INTO oauth2_sessions
267                    ( oauth2_session_id
268                    , user_id
269                    , user_session_id
270                    , oauth2_client_id
271                    , scope_list
272                    , created_at
273                    )
274                VALUES ($1, $2, $3, $4, $5, $6)
275            "#,
276            Uuid::from(id),
277            user.map(|u| Uuid::from(u.id)),
278            user_session.map(|s| Uuid::from(s.id)),
279            Uuid::from(client.id),
280            &scope_list,
281            created_at,
282        )
283        .traced()
284        .execute(&mut *self.conn)
285        .await?;
286
287        Ok(Session {
288            id,
289            state: SessionState::Valid,
290            created_at,
291            user_id: user.map(|u| u.id),
292            user_session_id: user_session.map(|s| s.id),
293            client_id: client.id,
294            scope,
295            user_agent: None,
296            last_active_at: None,
297            last_active_ip: None,
298            human_name: None,
299        })
300    }
301
302    #[tracing::instrument(
303        name = "db.oauth2_session.finish_bulk",
304        skip_all,
305        fields(
306            db.query.text,
307        ),
308        err,
309    )]
310    async fn finish_bulk(
311        &mut self,
312        clock: &dyn Clock,
313        filter: OAuth2SessionFilter<'_>,
314    ) -> Result<usize, Self::Error> {
315        let finished_at = clock.now();
316        let (sql, arguments) = Query::update()
317            .table(OAuth2Sessions::Table)
318            .value(OAuth2Sessions::FinishedAt, finished_at)
319            .apply_filter(filter)
320            .build_sqlx(PostgresQueryBuilder);
321
322        let res = sqlx::query_with(&sql, arguments)
323            .traced()
324            .execute(&mut *self.conn)
325            .await?;
326
327        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
328    }
329
330    #[tracing::instrument(
331        name = "db.oauth2_session.finish",
332        skip_all,
333        fields(
334            db.query.text,
335            %session.id,
336            %session.scope,
337            client.id = %session.client_id,
338        ),
339        err,
340    )]
341    async fn finish(
342        &mut self,
343        clock: &dyn Clock,
344        session: Session,
345    ) -> Result<Session, Self::Error> {
346        let finished_at = clock.now();
347        let res = sqlx::query!(
348            r#"
349                UPDATE oauth2_sessions
350                SET finished_at = $2
351                WHERE oauth2_session_id = $1
352            "#,
353            Uuid::from(session.id),
354            finished_at,
355        )
356        .traced()
357        .execute(&mut *self.conn)
358        .await?;
359
360        DatabaseError::ensure_affected_rows(&res, 1)?;
361
362        session
363            .finish(finished_at)
364            .map_err(DatabaseError::to_invalid_operation)
365    }
366
367    #[tracing::instrument(
368        name = "db.oauth2_session.list",
369        skip_all,
370        fields(
371            db.query.text,
372        ),
373        err,
374    )]
375    async fn list(
376        &mut self,
377        filter: OAuth2SessionFilter<'_>,
378        pagination: Pagination,
379    ) -> Result<Page<Session>, Self::Error> {
380        let (sql, arguments) = Query::select()
381            .expr_as(
382                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
383                OAuthSessionLookupIden::Oauth2SessionId,
384            )
385            .expr_as(
386                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
387                OAuthSessionLookupIden::UserId,
388            )
389            .expr_as(
390                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
391                OAuthSessionLookupIden::UserSessionId,
392            )
393            .expr_as(
394                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
395                OAuthSessionLookupIden::Oauth2ClientId,
396            )
397            .expr_as(
398                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
399                OAuthSessionLookupIden::ScopeList,
400            )
401            .expr_as(
402                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
403                OAuthSessionLookupIden::CreatedAt,
404            )
405            .expr_as(
406                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
407                OAuthSessionLookupIden::FinishedAt,
408            )
409            .expr_as(
410                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
411                OAuthSessionLookupIden::UserAgent,
412            )
413            .expr_as(
414                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
415                OAuthSessionLookupIden::LastActiveAt,
416            )
417            .expr_as(
418                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
419                OAuthSessionLookupIden::LastActiveIp,
420            )
421            .expr_as(
422                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
423                OAuthSessionLookupIden::HumanName,
424            )
425            .from(OAuth2Sessions::Table)
426            .apply_filter(filter)
427            .generate_pagination(
428                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
429                pagination,
430            )
431            .build_sqlx(PostgresQueryBuilder);
432
433        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
434            .traced()
435            .fetch_all(&mut *self.conn)
436            .await?;
437
438        let page = pagination.process(edges).try_map(Session::try_from)?;
439
440        Ok(page)
441    }
442
443    #[tracing::instrument(
444        name = "db.oauth2_session.count",
445        skip_all,
446        fields(
447            db.query.text,
448        ),
449        err,
450    )]
451    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
452        let (sql, arguments) = Query::select()
453            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
454            .from(OAuth2Sessions::Table)
455            .apply_filter(filter)
456            .build_sqlx(PostgresQueryBuilder);
457
458        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
459            .traced()
460            .fetch_one(&mut *self.conn)
461            .await?;
462
463        count
464            .try_into()
465            .map_err(DatabaseError::to_invalid_operation)
466    }
467
468    #[tracing::instrument(
469        name = "db.oauth2_session.record_batch_activity",
470        skip_all,
471        fields(
472            db.query.text,
473        ),
474        err,
475    )]
476    async fn record_batch_activity(
477        &mut self,
478        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
479    ) -> Result<(), Self::Error> {
480        // Sort the activity by ID, so that when batching the updates, Postgres
481        // locks the rows in a stable order, preventing deadlocks
482        activities.sort_unstable();
483        let mut ids = Vec::with_capacity(activities.len());
484        let mut last_activities = Vec::with_capacity(activities.len());
485        let mut ips = Vec::with_capacity(activities.len());
486
487        for (id, last_activity, ip) in activities {
488            ids.push(Uuid::from(id));
489            last_activities.push(last_activity);
490            ips.push(ip);
491        }
492
493        let res = sqlx::query!(
494            r#"
495                UPDATE oauth2_sessions
496                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
497                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
498                FROM (
499                    SELECT *
500                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
501                        AS t(oauth2_session_id, last_active_at, last_active_ip)
502                ) AS t
503                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
504            "#,
505            &ids,
506            &last_activities,
507            &ips as &[Option<IpAddr>],
508        )
509        .traced()
510        .execute(&mut *self.conn)
511        .await?;
512
513        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
514
515        Ok(())
516    }
517
518    #[tracing::instrument(
519        name = "db.oauth2_session.record_user_agent",
520        skip_all,
521        fields(
522            db.query.text,
523            %session.id,
524            %session.scope,
525            client.id = %session.client_id,
526            session.user_agent = user_agent,
527        ),
528        err,
529    )]
530    async fn record_user_agent(
531        &mut self,
532        mut session: Session,
533        user_agent: String,
534    ) -> Result<Session, Self::Error> {
535        let res = sqlx::query!(
536            r#"
537                UPDATE oauth2_sessions
538                SET user_agent = $2
539                WHERE oauth2_session_id = $1
540            "#,
541            Uuid::from(session.id),
542            &*user_agent,
543        )
544        .traced()
545        .execute(&mut *self.conn)
546        .await?;
547
548        session.user_agent = Some(user_agent);
549
550        DatabaseError::ensure_affected_rows(&res, 1)?;
551
552        Ok(session)
553    }
554
555    #[tracing::instrument(
556        name = "repository.oauth2_session.set_human_name",
557        skip(self),
558        fields(
559            client.id = %session.client_id,
560            session.human_name = ?human_name,
561        ),
562        err,
563    )]
564    async fn set_human_name(
565        &mut self,
566        mut session: Session,
567        human_name: Option<String>,
568    ) -> Result<Session, Self::Error> {
569        let res = sqlx::query!(
570            r#"
571                UPDATE oauth2_sessions
572                SET human_name = $2
573                WHERE oauth2_session_id = $1
574            "#,
575            Uuid::from(session.id),
576            human_name.as_deref(),
577        )
578        .traced()
579        .execute(&mut *self.conn)
580        .await?;
581
582        session.human_name = human_name;
583
584        DatabaseError::ensure_affected_rows(&res, 1)?;
585
586        Ok(session)
587    }
588}