// mas_storage/upstream_oauth2/session.rs
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use async_trait::async_trait;
8use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
9use rand_core::RngCore;
10use ulid::Ulid;
11
12use crate::{Clock, repository_impl};
13
14/// An [`UpstreamOAuthSessionRepository`] helps interacting with
15/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend
16#[async_trait]
17pub trait UpstreamOAuthSessionRepository: Send + Sync {
18    /// The error type returned by the repository
19    type Error;
20
21    /// Lookup a session by its ID
22    ///
23    /// Returns `None` if the session does not exist
24    ///
25    /// # Parameters
26    ///
27    /// * `id`: the ID of the session to lookup
28    ///
29    /// # Errors
30    ///
31    /// Returns [`Self::Error`] if the underlying repository fails
32    async fn lookup(
33        &mut self,
34        id: Ulid,
35    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
36
37    /// Add a session to the database
38    ///
39    /// Returns the newly created session
40    ///
41    /// # Parameters
42    ///
43    /// * `rng`: the random number generator to use
44    /// * `clock`: the clock source
45    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
46    ///   create the session
47    /// * `state`: the authorization grant `state` parameter sent to the
48    ///   upstream OAuth provider
49    /// * `code_challenge_verifier`: the code challenge verifier used in this
50    ///   session, if PKCE is being used
51    /// * `nonce`: the `nonce` used in this session
52    ///
53    /// # Errors
54    ///
55    /// Returns [`Self::Error`] if the underlying repository fails
56    async fn add(
57        &mut self,
58        rng: &mut (dyn RngCore + Send),
59        clock: &dyn Clock,
60        upstream_oauth_provider: &UpstreamOAuthProvider,
61        state: String,
62        code_challenge_verifier: Option<String>,
63        nonce: String,
64    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
65
66    /// Mark a session as completed and associate the given link
67    ///
68    /// Returns the updated session
69    ///
70    /// # Parameters
71    ///
72    /// * `clock`: the clock source
73    /// * `upstream_oauth_authorization_session`: the session to update
74    /// * `upstream_oauth_link`: the link to associate with the session
75    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
76    ///   present
77    /// * `extra_callback_parameters`: the extra query parameters returned in
78    ///   the callback, if any
79    ///
80    /// # Errors
81    ///
82    /// Returns [`Self::Error`] if the underlying repository fails
83    async fn complete_with_link(
84        &mut self,
85        clock: &dyn Clock,
86        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
87        upstream_oauth_link: &UpstreamOAuthLink,
88        id_token: Option<String>,
89        extra_callback_parameters: Option<serde_json::Value>,
90        userinfo: Option<serde_json::Value>,
91    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
92
93    /// Mark a session as consumed
94    ///
95    /// Returns the updated session
96    ///
97    /// # Parameters
98    ///
99    /// * `clock`: the clock source
100    /// * `upstream_oauth_authorization_session`: the session to consume
101    ///
102    /// # Errors
103    ///
104    /// Returns [`Self::Error`] if the underlying repository fails
105    async fn consume(
106        &mut self,
107        clock: &dyn Clock,
108        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
109    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
110}
111
112repository_impl!(UpstreamOAuthSessionRepository:
113    async fn lookup(
114        &mut self,
115        id: Ulid,
116    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
117
118    async fn add(
119        &mut self,
120        rng: &mut (dyn RngCore + Send),
121        clock: &dyn Clock,
122        upstream_oauth_provider: &UpstreamOAuthProvider,
123        state: String,
124        code_challenge_verifier: Option<String>,
125        nonce: String,
126    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
127
128    async fn complete_with_link(
129        &mut self,
130        clock: &dyn Clock,
131        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
132        upstream_oauth_link: &UpstreamOAuthLink,
133        id_token: Option<String>,
134        extra_callback_parameters: Option<serde_json::Value>,
135        userinfo: Option<serde_json::Value>,
136    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
137
138    async fn consume(
139        &mut self,
140        clock: &dyn Clock,
141        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
142    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
143);