tmc_langs_framework/
archive.rs

1//! Contains types that abstract over the various archive formats.
2
3use crate::TmcError;
4use blake3::{Hash, Hasher};
5use serde::Deserialize;
6use std::{
7    fmt::Display,
8    io::{BufReader, Cursor, Read, Seek, Write},
9    ops::ControlFlow::{self, Break},
10    path::{Path, PathBuf},
11    str::FromStr,
12};
13use tar::Builder;
14use tmc_langs_util::file_util;
15use walkdir::WalkDir;
16use zip::{DateTime, ZipWriter, write::SimpleFileOptions};
17
18/// Wrapper unifying the API of all the different compression formats supported by langs.
19/// Unfortunately the API is more complicated due to tar only supporting iterating through the files one by one,
20/// while zip only supports accessing by index.
21pub struct Archive<T: Read + Seek>(ArchiveInner<T>);
22
23enum ArchiveInner<T: Read + Seek> {
24    Tar(tar::Archive<T>),
25    TarZstd(tar::Archive<zstd::Decoder<'static, BufReader<T>>>),
26    Zip(zip::ZipArchive<T>),
27    // This variant is only used for dummy values when swapping out the inner archive when we only have a &mut Archive
28    Empty,
29}
30
31impl<T: Read + Seek> Archive<T> {
32    pub fn new(archive: T, compression: Compression) -> Result<Self, TmcError> {
33        match compression {
34            Compression::Tar => Ok(Self::tar(archive)),
35            Compression::TarZstd => Self::tar_zstd(archive),
36            Compression::Zip => Self::zip(archive),
37        }
38    }
39
40    pub fn tar(archive: T) -> Self {
41        let archive = tar::Archive::new(archive);
42        Self(ArchiveInner::Tar(archive))
43    }
44
45    pub fn tar_zstd(archive: T) -> Result<Self, TmcError> {
46        let archive = zstd::Decoder::new(archive).map_err(TmcError::ZstdRead)?;
47        let archive = tar::Archive::new(archive);
48        Ok(Self(ArchiveInner::TarZstd(archive)))
49    }
50
51    pub fn zip(archive: T) -> Result<Self, TmcError> {
52        let archive = zip::ZipArchive::new(archive)?;
53        Ok(Self(ArchiveInner::Zip(archive)))
54    }
55
56    pub fn extract(self, target_directory: &Path) -> Result<(), TmcError> {
57        match self {
58            Self(ArchiveInner::Tar(mut tar)) => {
59                tar.unpack(target_directory).map_err(TmcError::TarRead)?
60            }
61            Self(ArchiveInner::TarZstd(mut zstd)) => {
62                zstd.unpack(target_directory).map_err(TmcError::TarRead)?
63            }
64            Self(ArchiveInner::Zip(mut zip)) => zip.extract(target_directory)?,
65            Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
66        }
67        Ok(())
68    }
69
70    pub fn iter(&mut self) -> Result<ArchiveIterator<'_, T>, TmcError> {
71        self.reset()?;
72        match self {
73            Self(ArchiveInner::Tar(archive)) => {
74                let iter =
75                    ArchiveIterator::Tar(archive.entries_with_seek().map_err(TmcError::TarRead)?);
76                Ok(iter)
77            }
78            Self(ArchiveInner::TarZstd(archive)) => {
79                let iter = ArchiveIterator::TarZstd(archive.entries().map_err(TmcError::TarRead)?);
80                Ok(iter)
81            }
82            Self(ArchiveInner::Zip(archive)) => Ok(ArchiveIterator::Zip(0, archive)),
83            Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
84        }
85    }
86
87    pub fn by_path(&mut self, path: &str) -> Result<Entry<'_, T>, TmcError> {
88        self.reset()?;
89        match self {
90            Self(ArchiveInner::Tar(archive)) => {
91                for entry in archive.entries_with_seek().map_err(TmcError::TarRead)? {
92                    let entry = entry.map_err(TmcError::TarRead)?;
93                    if entry.path().map_err(TmcError::TarRead)? == Path::new(path) {
94                        return Ok(Entry::Tar(entry));
95                    }
96                }
97                Err(TmcError::TarRead(std::io::Error::other(format!(
98                    "Could not find {path} in tar"
99                ))))
100            }
101            Self(ArchiveInner::TarZstd(archive)) => {
102                for entry in archive.entries().map_err(TmcError::TarRead)? {
103                    let entry = entry.map_err(TmcError::TarRead)?;
104                    if entry.path().map_err(TmcError::TarRead)? == Path::new(path) {
105                        return Ok(Entry::TarZstd(entry));
106                    }
107                }
108                Err(TmcError::TarRead(std::io::Error::other(format!(
109                    "Could not find {path} in tar"
110                ))))
111            }
112            Self(ArchiveInner::Zip(archive)) => {
113                archive.by_name(path).map(Entry::Zip).map_err(Into::into)
114            }
115            Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
116        }
117    }
118
119    pub fn compression(&self) -> Compression {
120        match self {
121            Self(ArchiveInner::Tar(_)) => Compression::Tar,
122            Self(ArchiveInner::TarZstd(_)) => Compression::TarZstd,
123            Self(ArchiveInner::Zip(_)) => Compression::Zip,
124            Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
125        }
126    }
127
128    pub fn into_inner(self) -> T {
129        match self {
130            Self(ArchiveInner::Tar(archive)) => archive.into_inner(),
131            Self(ArchiveInner::TarZstd(archive)) => archive.into_inner().finish().into_inner(),
132            Self(ArchiveInner::Zip(archive)) => archive.into_inner(),
133            Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
134        }
135    }
136
137    /// tar's entries functions require the archive's position to be at 0,
138    /// but resetting the position is awkward, hence this helper function
139    fn reset(&mut self) -> Result<(), TmcError> {
140        let mut swap = ArchiveInner::Empty;
141        std::mem::swap(&mut self.0, &mut swap);
142        let mut swap = match swap {
143            ArchiveInner::Tar(archive) => {
144                let mut inner = archive.into_inner();
145                inner
146                    .seek(std::io::SeekFrom::Start(0))
147                    .map_err(TmcError::Seek)?;
148                ArchiveInner::Tar(tar::Archive::new(inner))
149            }
150            ArchiveInner::TarZstd(archive) => {
151                let mut inner = archive.into_inner().finish().into_inner();
152                inner
153                    .seek(std::io::SeekFrom::Start(0))
154                    .map_err(TmcError::Seek)?;
155                let decoder = zstd::Decoder::new(inner).map_err(TmcError::ZstdRead)?;
156                ArchiveInner::TarZstd(tar::Archive::new(decoder))
157            }
158            ArchiveInner::Zip(_) => {
159                // no-op
160                swap
161            }
162            ArchiveInner::Empty => unreachable!("This is a bug."),
163        };
164        // swap the value back in
165        std::mem::swap(&mut self.0, &mut swap);
166        Ok(())
167    }
168}
169
170pub enum ArchiveIterator<'a, T: Read + Seek> {
171    Tar(tar::Entries<'a, T>),
172    TarZstd(tar::Entries<'a, zstd::Decoder<'static, BufReader<T>>>),
173    Zip(usize, &'a mut zip::ZipArchive<T>),
174}
175
176impl<T: Read + Seek> ArchiveIterator<'_, T> {
177    /// Returns Break(None) when there's nothing left to iterate.
178    pub fn with_next<U, F: FnMut(Entry<'_, T>) -> Result<ControlFlow<Option<U>>, TmcError>>(
179        &mut self,
180        mut f: F,
181    ) -> Result<ControlFlow<Option<U>>, TmcError> {
182        match self {
183            Self::Tar(iter) => {
184                let next = iter
185                    .next()
186                    .map(|e| e.map(Entry::Tar))
187                    .transpose()
188                    .map_err(TmcError::TarRead)?;
189                if let Some(next) = next {
190                    let res = f(next)?;
191                    Ok(res)
192                } else {
193                    Ok(Break(None))
194                }
195            }
196            Self::TarZstd(iter) => {
197                let next = iter
198                    .next()
199                    .map(|e| e.map(Entry::TarZstd))
200                    .transpose()
201                    .map_err(TmcError::TarRead)?;
202                if let Some(next) = next {
203                    let res = f(next)?;
204                    Ok(res)
205                } else {
206                    Ok(Break(None))
207                }
208            }
209            Self::Zip(i, archive) => {
210                if *i < archive.len() {
211                    let next = archive.by_index(*i)?;
212                    *i += 1;
213                    let res = f(Entry::Zip(next))?;
214                    Ok(res)
215                } else {
216                    Ok(Break(None))
217                }
218            }
219        }
220    }
221}
222
223pub enum Entry<'a, T: Read> {
224    Tar(tar::Entry<'a, T>),
225    TarZstd(tar::Entry<'a, zstd::Decoder<'static, BufReader<T>>>),
226    Zip(zip::read::ZipFile<'a, T>),
227}
228
229impl<T: Read> Entry<'_, T> {
230    pub fn path(&self) -> Result<PathBuf, TmcError> {
231        match self {
232            Self::Tar(entry) => {
233                let name = entry.path().map_err(TmcError::TarRead)?.into_owned();
234                Ok(name)
235            }
236            Self::TarZstd(entry) => {
237                let name = entry.path().map_err(TmcError::TarRead)?.into_owned();
238                Ok(name)
239            }
240            Self::Zip(entry) => {
241                let name = entry
242                    .enclosed_name()
243                    .ok_or_else(|| TmcError::ZipName(entry.name().to_string()))?
244                    .to_path_buf();
245                Ok(name)
246            }
247        }
248    }
249
250    pub fn is_dir(&self) -> bool {
251        match self {
252            Self::Tar(entry) => matches!(entry.header().entry_type(), tar::EntryType::Directory),
253            Self::TarZstd(entry) => {
254                matches!(entry.header().entry_type(), tar::EntryType::Directory)
255            }
256            Self::Zip(entry) => entry.is_dir(),
257        }
258    }
259
260    pub fn is_file(&self) -> bool {
261        match self {
262            Self::Tar(entry) => matches!(entry.header().entry_type(), tar::EntryType::Regular),
263            Self::TarZstd(entry) => {
264                matches!(entry.header().entry_type(), tar::EntryType::Regular)
265            }
266            Self::Zip(entry) => entry.is_file(),
267        }
268    }
269}
270
271impl<T: Read> Read for Entry<'_, T> {
272    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
273        match self {
274            Self::Tar(archive) => archive.read(buf),
275            Self::TarZstd(archive) => archive.read(buf),
276            Self::Zip(archive) => archive.read(buf),
277        }
278    }
279}
280
281/// Supported compression methods.
282#[derive(Debug, Clone, Copy, Deserialize)]
283#[cfg_attr(feature = "ts-rs", derive(ts_rs::TS))]
284pub enum Compression {
285    /// .tar
286    #[serde(rename = "tar")]
287    Tar,
288    /// .zip
289    #[serde(rename = "zip")]
290    Zip,
291    /// .tar.ztd
292    #[serde(rename = "zstd")]
293    TarZstd,
294}
295
296impl Compression {
297    pub fn compress(self, path: &Path, hash: bool) -> Result<(Vec<u8>, Option<Hash>), TmcError> {
298        let mut builder = ArchiveBuilder::new(Cursor::new(Vec::new()), self, None, true, hash);
299        walk_dir_for_compression(path, |entry, relative_path| {
300            if entry.path().is_dir() {
301                builder.add_directory(entry.path(), relative_path)?;
302            } else if entry.path().is_file() {
303                builder.add_file(entry.path(), relative_path)?;
304            }
305            Ok(())
306        })?;
307        let (cursor, hash) = builder.finish()?;
308        Ok((cursor.into_inner(), hash))
309    }
310}
311
312fn walk_dir_for_compression(
313    root: &Path,
314    mut f: impl FnMut(&walkdir::DirEntry, &str) -> Result<(), TmcError>,
315) -> Result<(), TmcError> {
316    let parent = root.parent().map(PathBuf::from).unwrap_or_default();
317    for entry in WalkDir::new(root)
318        .sort_by_file_name()
319        .into_iter()
320        // filter windows lock files
321        .filter_entry(|e| e.file_name() != file_util::LOCK_FILE_NAME)
322    {
323        let entry = entry?;
324        let stripped = entry
325            .path()
326            .strip_prefix(&parent)
327            .expect("entries are within parent");
328        let path_str = stripped
329            .to_str()
330            .ok_or_else(|| TmcError::InvalidUtf8(stripped.to_path_buf()))?;
331        f(&entry, path_str)?;
332    }
333    Ok(())
334}
335
336impl Display for Compression {
337    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        match self {
339            Self::Tar => write!(f, "tar"),
340            Self::Zip => write!(f, "zip"),
341            Self::TarZstd => write!(f, "zstd"),
342        }
343    }
344}
345
346impl FromStr for Compression {
347    type Err = &'static str;
348
349    fn from_str(s: &str) -> Result<Self, Self::Err> {
350        let format = match s {
351            "tar" => Compression::Tar,
352            "zip" => Compression::Zip,
353            "zstd" => Compression::TarZstd,
354            _ => return Err("invalid format"),
355        };
356        Ok(format)
357    }
358}
359
360pub struct ArchiveBuilder<W: Write + Seek> {
361    size_limit_b: Option<usize>,
362    size_limit_mb: Option<u32>,
363    size_total_b: usize,
364    hasher: Option<Hasher>,
365    kind: Kind<W>,
366}
367
368enum Kind<W: Write + Seek> {
369    Tar {
370        builder: Builder<W>,
371    },
372    TarZstd {
373        writer: W,
374        builder: Builder<Cursor<Vec<u8>>>,
375    },
376    Zip {
377        builder: Box<ZipWriter<W>>,
378        deterministic: bool,
379    },
380}
381
382impl<W: Write + Seek> ArchiveBuilder<W> {
383    pub fn new(
384        writer: W,
385        compression: Compression,
386        size_limit_mb: Option<u32>,
387        deterministic: bool,
388        hash: bool,
389    ) -> Self {
390        let size_limit_b = size_limit_mb.map(|slmb| {
391            usize::try_from(slmb)
392                .unwrap_or(usize::MAX)
393                .saturating_mul(1000 * 1000)
394        });
395        let hasher = if hash { Some(Hasher::new()) } else { None };
396        let kind = match compression {
397            Compression::Tar => {
398                let mut builder = Builder::new(writer);
399                if deterministic {
400                    builder.mode(tar::HeaderMode::Deterministic);
401                }
402                Kind::Tar { builder }
403            }
404            Compression::TarZstd => {
405                let mut builder = Builder::new(Cursor::new(vec![]));
406                if deterministic {
407                    builder.mode(tar::HeaderMode::Deterministic);
408                }
409                Kind::TarZstd { writer, builder }
410            }
411            Compression::Zip => Kind::Zip {
412                builder: Box::new(ZipWriter::new(writer)),
413                deterministic,
414            },
415        };
416        Self {
417            size_limit_b,
418            size_limit_mb,
419            size_total_b: 0,
420            hasher,
421            kind,
422        }
423    }
424
425    /// Does not include any files within the directory.
426    pub fn add_directory(&mut self, source: &Path, path_in_archive: &str) -> Result<(), TmcError> {
427        log::trace!("adding directory {path_in_archive}");
428        self.hash(path_in_archive.as_bytes());
429        match &mut self.kind {
430            Kind::Tar { builder } => {
431                builder
432                    .append_dir(path_in_archive, source)
433                    .map_err(TmcError::TarWrite)?;
434            }
435            Kind::TarZstd { builder, .. } => {
436                builder
437                    .append_dir(path_in_archive, source)
438                    .map_err(TmcError::TarWrite)?;
439            }
440            Kind::Zip {
441                builder,
442                deterministic,
443            } => builder.add_directory(path_in_archive, zip_file_options(*deterministic))?,
444        }
445        Ok(())
446    }
447
448    pub fn add_file(&mut self, source: &Path, path_in_archive: &str) -> Result<(), TmcError> {
449        log::trace!("writing file {} as {}", source.display(), path_in_archive);
450        self.hash(path_in_archive.as_bytes());
451        let bytes = file_util::read_file(source)?;
452        self.size_total_b += bytes.len();
453        if let (Some(size_limit_b), Some(size_limit_mb)) = (self.size_limit_b, self.size_limit_mb) {
454            if self.size_total_b > size_limit_b {
455                return Err(TmcError::ArchiveSizeLimitExceeded {
456                    limit: size_limit_mb,
457                });
458            }
459        }
460        self.hash(&bytes);
461        match &mut self.kind {
462            Kind::Tar { builder } => builder
463                .append_path_with_name(source, path_in_archive)
464                .map_err(TmcError::TarWrite)?,
465            Kind::TarZstd { builder, .. } => builder
466                .append_path_with_name(source, path_in_archive)
467                .map_err(TmcError::TarWrite)?,
468            Kind::Zip {
469                builder,
470                deterministic,
471            } => {
472                builder.start_file(path_in_archive, zip_file_options(*deterministic))?;
473                builder
474                    .write_all(&bytes)
475                    .map_err(|e| TmcError::ZipWrite(source.into(), e))?;
476            }
477        }
478        Ok(())
479    }
480
481    pub fn finish(self) -> Result<(W, Option<Hash>), TmcError> {
482        let res = match self.kind {
483            Kind::Tar { builder } => builder.into_inner().map_err(TmcError::TarWrite)?,
484            Kind::TarZstd {
485                mut writer,
486                builder,
487            } => {
488                let tar_data = builder.into_inner().map_err(TmcError::TarWrite)?;
489                zstd::stream::copy_encode(tar_data.get_ref().as_slice(), &mut writer, 0)
490                    .map_err(TmcError::ZstdWrite)?;
491                writer
492            }
493            Kind::Zip { builder, .. } => builder.finish()?,
494        };
495        let hash = self.hasher.map(|h| h.finalize());
496        Ok((res, hash))
497    }
498
499    fn hash(&mut self, input: &[u8]) {
500        self.hasher.as_mut().map(|h| h.update(input));
501    }
502}
503
504fn zip_file_options(deterministic: bool) -> SimpleFileOptions {
505    let file_options = SimpleFileOptions::default().unix_permissions(0o755);
506    if deterministic {
507        file_options.last_modified_time(
508            DateTime::from_date_and_time(2023, 1, 1, 0, 0, 0).expect("known to work"),
509        )
510    } else {
511        file_options
512    }
513}
514
#[cfg(test)]
mod test {
    use super::*;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `len` bytes of `b'a'`.
    fn temp_file_of_size(len: usize) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&vec![b'a'; len]).unwrap();
        file
    }

    #[test]
    fn exceeding_file_limit_causes_error() {
        let mut builder = ArchiveBuilder::new(
            Cursor::new(Vec::new()),
            Compression::Tar,
            Some(1),
            true,
            true,
        );

        // exactly 1 MB fits within the 1 MB limit
        let at_limit = temp_file_of_size(1000 * 1000);
        builder
            .add_file(at_limit.path(), "file")
            .expect("should not be over size limit");

        // a single extra byte pushes the total over the limit
        let one_more = temp_file_of_size(1);
        assert!(
            builder.add_file(one_more.path(), "file").is_err(),
            "should be over size limit"
        );
    }
}