sqlx_core/migrate/migrator.rs

1use crate::acquire::Acquire;
2use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource};
3use std::borrow::Cow;
4use std::collections::{HashMap, HashSet};
5use std::ops::Deref;
6use std::slice;
7
8/// A resolved set of migrations, ready to be run.
9///
10/// Can be constructed statically using `migrate!()` or at runtime using [`Migrator::new()`].
11#[derive(Debug)]
12// Forbids `migrate!()` from constructing this:
13// #[non_exhaustive]
14pub struct Migrator {
15    // NOTE: these fields are semver-exempt and may be changed or removed in any future version.
16    // These have to be public for `migrate!()` to be able to initialize them in an implicitly
17    // const-promotable context. A `const fn` constructor isn't implicitly const-promotable.
18    #[doc(hidden)]
19    pub migrations: Cow<'static, [Migration]>,
20    #[doc(hidden)]
21    pub ignore_missing: bool,
22    #[doc(hidden)]
23    pub locking: bool,
24    #[doc(hidden)]
25    pub no_tx: bool,
26}
27
28fn validate_applied_migrations(
29    applied_migrations: &[AppliedMigration],
30    migrator: &Migrator,
31) -> Result<(), MigrateError> {
32    if migrator.ignore_missing {
33        return Ok(());
34    }
35
36    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
37
38    for applied_migration in applied_migrations {
39        if !migrations.contains(&applied_migration.version) {
40            return Err(MigrateError::VersionMissing(applied_migration.version));
41        }
42    }
43
44    Ok(())
45}
46
47impl Migrator {
48    #[doc(hidden)]
49    pub const DEFAULT: Migrator = Migrator {
50        migrations: Cow::Borrowed(&[]),
51        ignore_missing: false,
52        no_tx: false,
53        locking: true,
54    };
55
56    /// Creates a new instance with the given source.
57    ///
58    /// # Examples
59    ///
60    /// ```rust,no_run
61    /// # use sqlx_core::migrate::MigrateError;
62    /// # fn main() -> Result<(), MigrateError> {
63    /// # sqlx::__rt::test_block_on(async move {
64    /// # use sqlx_core::migrate::Migrator;
65    /// use std::path::Path;
66    ///
67    /// // Read migrations from a local folder: ./migrations
68    /// let m = Migrator::new(Path::new("./migrations")).await?;
69    /// # Ok(())
70    /// # })
71    /// # }
72    /// ```
73    /// See [MigrationSource] for details on structure of the `./migrations` directory.
74    pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
75    where
76        S: MigrationSource<'s>,
77    {
78        Ok(Self {
79            migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
80            ..Self::DEFAULT
81        })
82    }
83
84    /// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
85    pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
86        self.ignore_missing = ignore_missing;
87        self
88    }
89
90    /// Specify whether or not to lock the database during migration. Defaults to `true`.
91    ///
92    /// ### Warning
93    /// Disabling locking can lead to errors or data loss if multiple clients attempt to apply migrations simultaneously
94    /// without some sort of mutual exclusion.
95    ///
96    /// This should only be used if the database does not support locking, e.g. CockroachDB which talks the Postgres
97    /// protocol but does not support advisory locks used by SQLx's migrations support for Postgres.
98    pub fn set_locking(&mut self, locking: bool) -> &Self {
99        self.locking = locking;
100        self
101    }
102
103    /// Get an iterator over all known migrations.
104    pub fn iter(&self) -> slice::Iter<'_, Migration> {
105        self.migrations.iter()
106    }
107
108    /// Check if a migration version exists.
109    pub fn version_exists(&self, version: i64) -> bool {
110        self.iter().any(|m| m.version == version)
111    }
112
113    /// Run any pending migrations against the database; and, validate previously applied migrations
114    /// against the current migration source to detect accidental changes in previously-applied migrations.
115    ///
116    /// # Examples
117    ///
118    /// ```rust,no_run
119    /// # use sqlx::migrate::MigrateError;
120    /// # fn main() -> Result<(), MigrateError> {
121    /// #     sqlx::__rt::test_block_on(async move {
122    /// use sqlx::migrate::Migrator;
123    /// use sqlx::sqlite::SqlitePoolOptions;
124    ///
125    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
126    /// let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
127    /// m.run(&pool).await
128    /// #     })
129    /// # }
130    /// ```
131    pub async fn run<'a, A>(&self, migrator: A) -> Result<(), MigrateError>
132    where
133        A: Acquire<'a>,
134        <A::Connection as Deref>::Target: Migrate,
135    {
136        let mut conn = migrator.acquire().await?;
137        self.run_direct(&mut *conn).await
138    }
139
140    // Getting around the annoying "implementation of `Acquire` is not general enough" error
141    #[doc(hidden)]
142    pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
143    where
144        C: Migrate,
145    {
146        // lock the database for exclusive access by the migrator
147        if self.locking {
148            conn.lock().await?;
149        }
150
151        // creates [_migrations] table only if needed
152        // eventually this will likely migrate previous versions of the table
153        conn.ensure_migrations_table().await?;
154
155        let version = conn.dirty_version().await?;
156        if let Some(version) = version {
157            return Err(MigrateError::Dirty(version));
158        }
159
160        let applied_migrations = conn.list_applied_migrations().await?;
161        validate_applied_migrations(&applied_migrations, self)?;
162
163        let applied_migrations: HashMap<_, _> = applied_migrations
164            .into_iter()
165            .map(|m| (m.version, m))
166            .collect();
167
168        for migration in self.iter() {
169            if migration.migration_type.is_down_migration() {
170                continue;
171            }
172
173            match applied_migrations.get(&migration.version) {
174                Some(applied_migration) => {
175                    if migration.checksum != applied_migration.checksum {
176                        return Err(MigrateError::VersionMismatch(migration.version));
177                    }
178                }
179                None => {
180                    conn.apply(migration).await?;
181                }
182            }
183        }
184
185        // unlock the migrator to allow other migrators to run
186        // but do nothing as we already migrated
187        if self.locking {
188            conn.unlock().await?;
189        }
190
191        Ok(())
192    }
193
194    /// Run down migrations against the database until a specific version.
195    ///
196    /// # Examples
197    ///
198    /// ```rust,no_run
199    /// # use sqlx::migrate::MigrateError;
200    /// # fn main() -> Result<(), MigrateError> {
201    /// #     sqlx::__rt::test_block_on(async move {
202    /// use sqlx::migrate::Migrator;
203    /// use sqlx::sqlite::SqlitePoolOptions;
204    ///
205    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
206    /// let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
207    /// m.undo(&pool, 4).await
208    /// #     })
209    /// # }
210    /// ```
211    pub async fn undo<'a, A>(&self, migrator: A, target: i64) -> Result<(), MigrateError>
212    where
213        A: Acquire<'a>,
214        <A::Connection as Deref>::Target: Migrate,
215    {
216        let mut conn = migrator.acquire().await?;
217
218        // lock the database for exclusive access by the migrator
219        if self.locking {
220            conn.lock().await?;
221        }
222
223        // creates [_migrations] table only if needed
224        // eventually this will likely migrate previous versions of the table
225        conn.ensure_migrations_table().await?;
226
227        let version = conn.dirty_version().await?;
228        if let Some(version) = version {
229            return Err(MigrateError::Dirty(version));
230        }
231
232        let applied_migrations = conn.list_applied_migrations().await?;
233        validate_applied_migrations(&applied_migrations, self)?;
234
235        let applied_migrations: HashMap<_, _> = applied_migrations
236            .into_iter()
237            .map(|m| (m.version, m))
238            .collect();
239
240        for migration in self
241            .iter()
242            .rev()
243            .filter(|m| m.migration_type.is_down_migration())
244            .filter(|m| applied_migrations.contains_key(&m.version))
245            .filter(|m| m.version > target)
246        {
247            conn.revert(migration).await?;
248        }
249
250        // unlock the migrator to allow other migrators to run
251        // but do nothing as we already migrated
252        if self.locking {
253            conn.unlock().await?;
254        }
255
256        Ok(())
257    }
258}