1use crate::database::Database;
4
5use crate::query_builder::QueryBuilder;
6
7use indexmap::set::IndexSet;
8use std::cmp;
9use std::collections::{BTreeMap, HashMap};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
13pub type Result<T, E = FixtureError> = std::result::Result<T, E>;
14
15pub struct FixtureSnapshot<DB> {
21 tables: BTreeMap<TableName, Table>,
22 db: PhantomData<DB>,
23}
24
25#[derive(Debug, thiserror::Error)]
26#[error("could not create fixture: {0}")]
27pub struct FixtureError(String);
28
29pub struct Fixture<DB> {
30 ops: Vec<FixtureOp>,
31 db: PhantomData<DB>,
32}
33
34enum FixtureOp {
35 Insert {
36 table: TableName,
37 columns: Vec<ColumnName>,
38 rows: Vec<Vec<Value>>,
39 },
40 }
42
/// Reference-counted table name; clones only bump a refcount.
type TableName = Arc<str>;
/// Reference-counted column name; clones only bump a refcount.
type ColumnName = Arc<str>;
/// A column value kept as SQL source text (the tests use pre-quoted literals such as
/// `'asdf'` or `3.14`).
type Value = String;
46
47struct Table {
48 name: TableName,
49 columns: IndexSet<ColumnName>,
50 rows: Vec<Vec<Value>>,
51 foreign_keys: HashMap<ColumnName, (TableName, ColumnName)>,
52}
53
/// Early-returns `Err(FixtureError(..))` from the enclosing function when `$cond`
/// is false. The message is built with `format!` from the literal and any trailing
/// format arguments.
macro_rules! fixture_assert {
    ($cond:expr, $msg:literal $($arg:tt)*) => {
        // `$cond` is an `expr` fragment, so it is negated as a single unit.
        if !$cond {
            return Err(FixtureError(format!($msg $($arg)*)));
        }
    };
}
61
62impl<DB: Database> FixtureSnapshot<DB> {
63 pub fn additive_fixture(&self) -> Result<Fixture<DB>> {
74 let visit_order = self.calculate_visit_order()?;
75
76 let mut ops = Vec::new();
77
78 for table_name in visit_order {
79 let table = self.tables.get(&table_name).unwrap();
80
81 ops.push(FixtureOp::Insert {
82 table: table_name,
83 columns: table.columns.iter().cloned().collect(),
84 rows: table.rows.clone(),
85 });
86 }
87
88 Ok(Fixture { ops, db: self.db })
89 }
90
91 fn calculate_visit_order(&self) -> Result<Vec<TableName>> {
96 let mut table_depths = HashMap::with_capacity(self.tables.len());
97 let mut visited_set = IndexSet::with_capacity(self.tables.len());
98
99 for table in self.tables.values() {
100 foreign_key_depth(&self.tables, table, &mut table_depths, &mut visited_set)?;
101 visited_set.clear();
102 }
103
104 let mut table_names: Vec<TableName> = table_depths.keys().cloned().collect();
105 table_names.sort_by_key(|name| table_depths.get(name).unwrap());
106 Ok(table_names)
107 }
108}
109
110#[allow(clippy::to_string_trait_impl)]
113impl<DB: Database> ToString for Fixture<DB>
114where
115 for<'a> <DB as Database>::Arguments<'a>: Default,
116{
117 fn to_string(&self) -> String {
118 let mut query = QueryBuilder::<DB>::new("");
119
120 for op in &self.ops {
121 match op {
122 FixtureOp::Insert {
123 table,
124 columns,
125 rows,
126 } => {
127 if columns.is_empty() || rows.is_empty() {
129 continue;
130 }
131
132 query.push(format_args!("INSERT INTO {table} ("));
133
134 let mut separated = query.separated(", ");
135
136 for column in columns {
137 separated.push(column);
138 }
139
140 query.push(")\n");
141
142 query.push_values(rows, |mut separated, row| {
143 for value in row {
144 separated.push(value);
145 }
146 });
147
148 query.push(";\n");
149 }
150 }
151 }
152
153 query.into_sql()
154 }
155}
156
157fn foreign_key_depth(
158 tables: &BTreeMap<TableName, Table>,
159 table: &Table,
160 depths: &mut HashMap<TableName, usize>,
161 visited_set: &mut IndexSet<TableName>,
162) -> Result<usize> {
163 if let Some(&depth) = depths.get(&table.name) {
164 return Ok(depth);
165 }
166
167 fixture_assert!(
169 visited_set.insert(table.name.clone()),
170 "foreign key cycle detected: {:?} -> {:?}",
171 visited_set,
172 table.name
173 );
174
175 let mut refdepth = 0;
176
177 for (colname, (refname, refcol)) in &table.foreign_keys {
178 let referenced = tables.get(refname).ok_or_else(|| {
179 FixtureError(format!(
180 "table {:?} in foreign key `{}.{} references {}.{}` does not exist",
181 refname, table.name, colname, refname, refcol
182 ))
183 })?;
184
185 refdepth = cmp::max(
186 refdepth,
187 foreign_key_depth(tables, referenced, depths, visited_set)?,
188 );
189 }
190
191 let depth = refdepth + 1;
192
193 depths.insert(table.name.clone(), depth);
194
195 Ok(depth)
196}
197
#[test]
#[cfg(feature = "any")]
fn test_additive_fixture() -> Result<()> {
    use crate::any::Any;

    /// Builds a single-row `Table` from string slices to keep the fixtures compact.
    /// Each foreign key is given as `(column, referenced_table, referenced_column)`.
    fn make_table(
        name: &str,
        columns: &[&str],
        row: &[&str],
        foreign_keys: &[(&str, &str, &str)],
    ) -> Table {
        Table {
            name: name.into(),
            columns: columns.iter().copied().map(Arc::<str>::from).collect(),
            rows: vec![row
                .iter()
                .map(|&value| Value::from(value))
                .collect::<Vec<Value>>()],
            foreign_keys: foreign_keys
                .iter()
                .map(|&(column, ref_table, ref_column)| {
                    (
                        Arc::<str>::from(column),
                        (Arc::<str>::from(ref_table), Arc::<str>::from(ref_column)),
                    )
                })
                .collect(),
        }
    }

    let foo = make_table(
        "foo",
        &["foo_id", "foo_a", "foo_b"],
        &["1", "'asdf'", "true"],
        &[],
    );

    // `bar` references `foo`.
    let bar = make_table(
        "bar",
        &["bar_id", "foo_id", "bar_a", "bar_b"],
        &["1234", "1", "'2022-07-22 23:27:48.775113301+00:00'", "3.14"],
        &[("foo_id", "foo", "foo_id")],
    );

    // `baz` references both `foo` and `bar`, so it must be inserted last.
    let baz = make_table(
        "baz",
        &["baz_id", "bar_id", "foo_id", "baz_a", "baz_b"],
        &[
            "5678",
            "1234",
            "1",
            "'2022-07-22 23:27:48.775113301+00:00'",
            "3.14",
        ],
        &[("foo_id", "foo", "foo_id"), ("bar_id", "bar", "bar_id")],
    );

    let snapshot = FixtureSnapshot {
        tables: [foo, bar, baz]
            .into_iter()
            .map(|entry| (Arc::clone(&entry.name), entry))
            .collect(),
        db: PhantomData::<Any>,
    };

    let fixture = snapshot.additive_fixture()?;

    let expected = concat!(
        "INSERT INTO foo (foo_id, foo_a, foo_b)\n",
        "VALUES (1, 'asdf', true);\n",
        "INSERT INTO bar (bar_id, foo_id, bar_a, bar_b)\n",
        "VALUES (1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n",
        "INSERT INTO baz (baz_id, bar_id, foo_id, baz_a, baz_b)\n",
        "VALUES (5678, 1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n",
    );
    assert_eq!(fixture.to_string(), expected);

    Ok(())
}