1use crate::TmcError;
4use blake3::{Hash, Hasher};
5use serde::Deserialize;
6use std::{
7 fmt::Display,
8 io::{BufReader, Cursor, Read, Seek, Write},
9 ops::ControlFlow::{self, Break},
10 path::{Path, PathBuf},
11 str::FromStr,
12};
13use tar::Builder;
14use tmc_langs_util::file_util;
15use walkdir::WalkDir;
16use zip::{DateTime, ZipWriter, write::SimpleFileOptions};
17
18pub struct Archive<T: Read + Seek>(ArchiveInner<T>);
22
23enum ArchiveInner<T: Read + Seek> {
24 Tar(tar::Archive<T>),
25 TarZstd(tar::Archive<zstd::Decoder<'static, BufReader<T>>>),
26 Zip(zip::ZipArchive<T>),
27 Empty,
29}
30
31impl<T: Read + Seek> Archive<T> {
32 pub fn new(archive: T, compression: Compression) -> Result<Self, TmcError> {
33 match compression {
34 Compression::Tar => Ok(Self::tar(archive)),
35 Compression::TarZstd => Self::tar_zstd(archive),
36 Compression::Zip => Self::zip(archive),
37 }
38 }
39
40 pub fn tar(archive: T) -> Self {
41 let archive = tar::Archive::new(archive);
42 Self(ArchiveInner::Tar(archive))
43 }
44
45 pub fn tar_zstd(archive: T) -> Result<Self, TmcError> {
46 let archive = zstd::Decoder::new(archive).map_err(TmcError::ZstdRead)?;
47 let archive = tar::Archive::new(archive);
48 Ok(Self(ArchiveInner::TarZstd(archive)))
49 }
50
51 pub fn zip(archive: T) -> Result<Self, TmcError> {
52 let archive = zip::ZipArchive::new(archive)?;
53 Ok(Self(ArchiveInner::Zip(archive)))
54 }
55
56 pub fn extract(self, target_directory: &Path) -> Result<(), TmcError> {
57 match self {
58 Self(ArchiveInner::Tar(mut tar)) => {
59 tar.unpack(target_directory).map_err(TmcError::TarRead)?
60 }
61 Self(ArchiveInner::TarZstd(mut zstd)) => {
62 zstd.unpack(target_directory).map_err(TmcError::TarRead)?
63 }
64 Self(ArchiveInner::Zip(mut zip)) => zip.extract(target_directory)?,
65 Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
66 }
67 Ok(())
68 }
69
70 pub fn iter(&mut self) -> Result<ArchiveIterator<'_, T>, TmcError> {
71 self.reset()?;
72 match self {
73 Self(ArchiveInner::Tar(archive)) => {
74 let iter =
75 ArchiveIterator::Tar(archive.entries_with_seek().map_err(TmcError::TarRead)?);
76 Ok(iter)
77 }
78 Self(ArchiveInner::TarZstd(archive)) => {
79 let iter = ArchiveIterator::TarZstd(archive.entries().map_err(TmcError::TarRead)?);
80 Ok(iter)
81 }
82 Self(ArchiveInner::Zip(archive)) => Ok(ArchiveIterator::Zip(0, archive)),
83 Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
84 }
85 }
86
87 pub fn by_path(&mut self, path: &str) -> Result<Entry<'_, T>, TmcError> {
88 self.reset()?;
89 match self {
90 Self(ArchiveInner::Tar(archive)) => {
91 for entry in archive.entries_with_seek().map_err(TmcError::TarRead)? {
92 let entry = entry.map_err(TmcError::TarRead)?;
93 if entry.path().map_err(TmcError::TarRead)? == Path::new(path) {
94 return Ok(Entry::Tar(entry));
95 }
96 }
97 Err(TmcError::TarRead(std::io::Error::other(format!(
98 "Could not find {path} in tar"
99 ))))
100 }
101 Self(ArchiveInner::TarZstd(archive)) => {
102 for entry in archive.entries().map_err(TmcError::TarRead)? {
103 let entry = entry.map_err(TmcError::TarRead)?;
104 if entry.path().map_err(TmcError::TarRead)? == Path::new(path) {
105 return Ok(Entry::TarZstd(entry));
106 }
107 }
108 Err(TmcError::TarRead(std::io::Error::other(format!(
109 "Could not find {path} in tar"
110 ))))
111 }
112 Self(ArchiveInner::Zip(archive)) => {
113 archive.by_name(path).map(Entry::Zip).map_err(Into::into)
114 }
115 Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
116 }
117 }
118
119 pub fn compression(&self) -> Compression {
120 match self {
121 Self(ArchiveInner::Tar(_)) => Compression::Tar,
122 Self(ArchiveInner::TarZstd(_)) => Compression::TarZstd,
123 Self(ArchiveInner::Zip(_)) => Compression::Zip,
124 Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
125 }
126 }
127
128 pub fn into_inner(self) -> T {
129 match self {
130 Self(ArchiveInner::Tar(archive)) => archive.into_inner(),
131 Self(ArchiveInner::TarZstd(archive)) => archive.into_inner().finish().into_inner(),
132 Self(ArchiveInner::Zip(archive)) => archive.into_inner(),
133 Self(ArchiveInner::Empty) => unreachable!("This is a bug."),
134 }
135 }
136
137 fn reset(&mut self) -> Result<(), TmcError> {
140 let mut swap = ArchiveInner::Empty;
141 std::mem::swap(&mut self.0, &mut swap);
142 let mut swap = match swap {
143 ArchiveInner::Tar(archive) => {
144 let mut inner = archive.into_inner();
145 inner
146 .seek(std::io::SeekFrom::Start(0))
147 .map_err(TmcError::Seek)?;
148 ArchiveInner::Tar(tar::Archive::new(inner))
149 }
150 ArchiveInner::TarZstd(archive) => {
151 let mut inner = archive.into_inner().finish().into_inner();
152 inner
153 .seek(std::io::SeekFrom::Start(0))
154 .map_err(TmcError::Seek)?;
155 let decoder = zstd::Decoder::new(inner).map_err(TmcError::ZstdRead)?;
156 ArchiveInner::TarZstd(tar::Archive::new(decoder))
157 }
158 ArchiveInner::Zip(_) => {
159 swap
161 }
162 ArchiveInner::Empty => unreachable!("This is a bug."),
163 };
164 std::mem::swap(&mut self.0, &mut swap);
166 Ok(())
167 }
168}
169
170pub enum ArchiveIterator<'a, T: Read + Seek> {
171 Tar(tar::Entries<'a, T>),
172 TarZstd(tar::Entries<'a, zstd::Decoder<'static, BufReader<T>>>),
173 Zip(usize, &'a mut zip::ZipArchive<T>),
174}
175
176impl<T: Read + Seek> ArchiveIterator<'_, T> {
177 pub fn with_next<U, F: FnMut(Entry<'_, T>) -> Result<ControlFlow<Option<U>>, TmcError>>(
179 &mut self,
180 mut f: F,
181 ) -> Result<ControlFlow<Option<U>>, TmcError> {
182 match self {
183 Self::Tar(iter) => {
184 let next = iter
185 .next()
186 .map(|e| e.map(Entry::Tar))
187 .transpose()
188 .map_err(TmcError::TarRead)?;
189 if let Some(next) = next {
190 let res = f(next)?;
191 Ok(res)
192 } else {
193 Ok(Break(None))
194 }
195 }
196 Self::TarZstd(iter) => {
197 let next = iter
198 .next()
199 .map(|e| e.map(Entry::TarZstd))
200 .transpose()
201 .map_err(TmcError::TarRead)?;
202 if let Some(next) = next {
203 let res = f(next)?;
204 Ok(res)
205 } else {
206 Ok(Break(None))
207 }
208 }
209 Self::Zip(i, archive) => {
210 if *i < archive.len() {
211 let next = archive.by_index(*i)?;
212 *i += 1;
213 let res = f(Entry::Zip(next))?;
214 Ok(res)
215 } else {
216 Ok(Break(None))
217 }
218 }
219 }
220 }
221}
222
223pub enum Entry<'a, T: Read> {
224 Tar(tar::Entry<'a, T>),
225 TarZstd(tar::Entry<'a, zstd::Decoder<'static, BufReader<T>>>),
226 Zip(zip::read::ZipFile<'a, T>),
227}
228
229impl<T: Read> Entry<'_, T> {
230 pub fn path(&self) -> Result<PathBuf, TmcError> {
231 match self {
232 Self::Tar(entry) => {
233 let name = entry.path().map_err(TmcError::TarRead)?.into_owned();
234 Ok(name)
235 }
236 Self::TarZstd(entry) => {
237 let name = entry.path().map_err(TmcError::TarRead)?.into_owned();
238 Ok(name)
239 }
240 Self::Zip(entry) => {
241 let name = entry
242 .enclosed_name()
243 .ok_or_else(|| TmcError::ZipName(entry.name().to_string()))?
244 .to_path_buf();
245 Ok(name)
246 }
247 }
248 }
249
250 pub fn is_dir(&self) -> bool {
251 match self {
252 Self::Tar(entry) => matches!(entry.header().entry_type(), tar::EntryType::Directory),
253 Self::TarZstd(entry) => {
254 matches!(entry.header().entry_type(), tar::EntryType::Directory)
255 }
256 Self::Zip(entry) => entry.is_dir(),
257 }
258 }
259
260 pub fn is_file(&self) -> bool {
261 match self {
262 Self::Tar(entry) => matches!(entry.header().entry_type(), tar::EntryType::Regular),
263 Self::TarZstd(entry) => {
264 matches!(entry.header().entry_type(), tar::EntryType::Regular)
265 }
266 Self::Zip(entry) => entry.is_file(),
267 }
268 }
269}
270
271impl<T: Read> Read for Entry<'_, T> {
272 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
273 match self {
274 Self::Tar(archive) => archive.read(buf),
275 Self::TarZstd(archive) => archive.read(buf),
276 Self::Zip(archive) => archive.read(buf),
277 }
278 }
279}
280
281#[derive(Debug, Clone, Copy, Deserialize)]
283#[cfg_attr(feature = "ts-rs", derive(ts_rs::TS))]
284pub enum Compression {
285 #[serde(rename = "tar")]
287 Tar,
288 #[serde(rename = "zip")]
290 Zip,
291 #[serde(rename = "zstd")]
293 TarZstd,
294}
295
296impl Compression {
297 pub fn compress(self, path: &Path, hash: bool) -> Result<(Vec<u8>, Option<Hash>), TmcError> {
298 let mut builder = ArchiveBuilder::new(Cursor::new(Vec::new()), self, None, true, hash);
299 walk_dir_for_compression(path, |entry, relative_path| {
300 if entry.path().is_dir() {
301 builder.add_directory(entry.path(), relative_path)?;
302 } else if entry.path().is_file() {
303 builder.add_file(entry.path(), relative_path)?;
304 }
305 Ok(())
306 })?;
307 let (cursor, hash) = builder.finish()?;
308 Ok((cursor.into_inner(), hash))
309 }
310}
311
312fn walk_dir_for_compression(
313 root: &Path,
314 mut f: impl FnMut(&walkdir::DirEntry, &str) -> Result<(), TmcError>,
315) -> Result<(), TmcError> {
316 let parent = root.parent().map(PathBuf::from).unwrap_or_default();
317 for entry in WalkDir::new(root)
318 .sort_by_file_name()
319 .into_iter()
320 .filter_entry(|e| e.file_name() != file_util::LOCK_FILE_NAME)
322 {
323 let entry = entry?;
324 let stripped = entry
325 .path()
326 .strip_prefix(&parent)
327 .expect("entries are within parent");
328 let path_str = stripped
329 .to_str()
330 .ok_or_else(|| TmcError::InvalidUtf8(stripped.to_path_buf()))?;
331 f(&entry, path_str)?;
332 }
333 Ok(())
334}
335
336impl Display for Compression {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 match self {
339 Self::Tar => write!(f, "tar"),
340 Self::Zip => write!(f, "zip"),
341 Self::TarZstd => write!(f, "zstd"),
342 }
343 }
344}
345
346impl FromStr for Compression {
347 type Err = &'static str;
348
349 fn from_str(s: &str) -> Result<Self, Self::Err> {
350 let format = match s {
351 "tar" => Compression::Tar,
352 "zip" => Compression::Zip,
353 "zstd" => Compression::TarZstd,
354 _ => return Err("invalid format"),
355 };
356 Ok(format)
357 }
358}
359
360pub struct ArchiveBuilder<W: Write + Seek> {
361 size_limit_b: Option<usize>,
362 size_limit_mb: Option<u32>,
363 size_total_b: usize,
364 hasher: Option<Hasher>,
365 kind: Kind<W>,
366}
367
368enum Kind<W: Write + Seek> {
369 Tar {
370 builder: Builder<W>,
371 },
372 TarZstd {
373 writer: W,
374 builder: Builder<Cursor<Vec<u8>>>,
375 },
376 Zip {
377 builder: Box<ZipWriter<W>>,
378 deterministic: bool,
379 },
380}
381
382impl<W: Write + Seek> ArchiveBuilder<W> {
383 pub fn new(
384 writer: W,
385 compression: Compression,
386 size_limit_mb: Option<u32>,
387 deterministic: bool,
388 hash: bool,
389 ) -> Self {
390 let size_limit_b = size_limit_mb.map(|slmb| {
391 usize::try_from(slmb)
392 .unwrap_or(usize::MAX)
393 .saturating_mul(1000 * 1000)
394 });
395 let hasher = if hash { Some(Hasher::new()) } else { None };
396 let kind = match compression {
397 Compression::Tar => {
398 let mut builder = Builder::new(writer);
399 if deterministic {
400 builder.mode(tar::HeaderMode::Deterministic);
401 }
402 Kind::Tar { builder }
403 }
404 Compression::TarZstd => {
405 let mut builder = Builder::new(Cursor::new(vec![]));
406 if deterministic {
407 builder.mode(tar::HeaderMode::Deterministic);
408 }
409 Kind::TarZstd { writer, builder }
410 }
411 Compression::Zip => Kind::Zip {
412 builder: Box::new(ZipWriter::new(writer)),
413 deterministic,
414 },
415 };
416 Self {
417 size_limit_b,
418 size_limit_mb,
419 size_total_b: 0,
420 hasher,
421 kind,
422 }
423 }
424
425 pub fn add_directory(&mut self, source: &Path, path_in_archive: &str) -> Result<(), TmcError> {
427 log::trace!("adding directory {path_in_archive}");
428 self.hash(path_in_archive.as_bytes());
429 match &mut self.kind {
430 Kind::Tar { builder } => {
431 builder
432 .append_dir(path_in_archive, source)
433 .map_err(TmcError::TarWrite)?;
434 }
435 Kind::TarZstd { builder, .. } => {
436 builder
437 .append_dir(path_in_archive, source)
438 .map_err(TmcError::TarWrite)?;
439 }
440 Kind::Zip {
441 builder,
442 deterministic,
443 } => builder.add_directory(path_in_archive, zip_file_options(*deterministic))?,
444 }
445 Ok(())
446 }
447
448 pub fn add_file(&mut self, source: &Path, path_in_archive: &str) -> Result<(), TmcError> {
449 log::trace!("writing file {} as {}", source.display(), path_in_archive);
450 self.hash(path_in_archive.as_bytes());
451 let bytes = file_util::read_file(source)?;
452 self.size_total_b += bytes.len();
453 if let (Some(size_limit_b), Some(size_limit_mb)) = (self.size_limit_b, self.size_limit_mb) {
454 if self.size_total_b > size_limit_b {
455 return Err(TmcError::ArchiveSizeLimitExceeded {
456 limit: size_limit_mb,
457 });
458 }
459 }
460 self.hash(&bytes);
461 match &mut self.kind {
462 Kind::Tar { builder } => builder
463 .append_path_with_name(source, path_in_archive)
464 .map_err(TmcError::TarWrite)?,
465 Kind::TarZstd { builder, .. } => builder
466 .append_path_with_name(source, path_in_archive)
467 .map_err(TmcError::TarWrite)?,
468 Kind::Zip {
469 builder,
470 deterministic,
471 } => {
472 builder.start_file(path_in_archive, zip_file_options(*deterministic))?;
473 builder
474 .write_all(&bytes)
475 .map_err(|e| TmcError::ZipWrite(source.into(), e))?;
476 }
477 }
478 Ok(())
479 }
480
481 pub fn finish(self) -> Result<(W, Option<Hash>), TmcError> {
482 let res = match self.kind {
483 Kind::Tar { builder } => builder.into_inner().map_err(TmcError::TarWrite)?,
484 Kind::TarZstd {
485 mut writer,
486 builder,
487 } => {
488 let tar_data = builder.into_inner().map_err(TmcError::TarWrite)?;
489 zstd::stream::copy_encode(tar_data.get_ref().as_slice(), &mut writer, 0)
490 .map_err(TmcError::ZstdWrite)?;
491 writer
492 }
493 Kind::Zip { builder, .. } => builder.finish()?,
494 };
495 let hash = self.hasher.map(|h| h.finalize());
496 Ok((res, hash))
497 }
498
499 fn hash(&mut self, input: &[u8]) {
500 self.hasher.as_mut().map(|h| h.update(input));
501 }
502}
503
504fn zip_file_options(deterministic: bool) -> SimpleFileOptions {
505 let file_options = SimpleFileOptions::default().unix_permissions(0o755);
506 if deterministic {
507 file_options.last_modified_time(
508 DateTime::from_date_and_time(2023, 1, 1, 0, 0, 0).expect("known to work"),
509 )
510 } else {
511 file_options
512 }
513}
514
#[cfg(test)]
mod test {
    use super::*;
    use tempfile::NamedTempFile;

    /// A 1 MB limit accepts exactly 1000 * 1000 bytes and rejects the next byte.
    #[test]
    fn exceeding_file_limit_causes_error() {
        let target = Cursor::new(Vec::new());
        let mut builder = ArchiveBuilder::new(target, Compression::Tar, Some(1), true, true);

        // First file: exactly at the limit, which must still be accepted.
        let mut at_limit = NamedTempFile::new().unwrap();
        let payload = "a".as_bytes().repeat(1000 * 1000);
        at_limit.write_all(payload.as_slice()).unwrap();
        builder
            .add_file(at_limit.path(), "file")
            .expect("should not be over size limit");

        // Second file: one more byte pushes the running total over the limit.
        let mut one_more = NamedTempFile::new().unwrap();
        one_more.write_all("a".as_bytes()).unwrap();
        let outcome = builder.add_file(one_more.path(), "file");
        assert!(outcome.is_err(), "should be over size limit");
    }
}
546}