Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion check_diff/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions check_diff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tempfile = "3"
walkdir = "2.5.0"
diffy = "0.4.0"
crossbeam-channel = "0.5.15"
259 changes: 197 additions & 62 deletions check_diff/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::env;
use std::fmt::{Debug, Display};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::str::FromStr;
use tracing::{debug, error, info, trace};
use std::sync::{Arc, Mutex};
use tempfile::tempdir;
use tracing::{debug, info, trace, warn};
use walkdir::WalkDir;

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -411,6 +414,46 @@ fn create_config_arg<T: AsRef<str>>(configs: Option<&[T]>) -> Cow<'static, str>

Cow::Owned(result)
}

pub struct Repository<P> {
/// Name of the repository
name: String,
/// Path to the repository on the local file system
dir_path: P,
}

impl<P> Repository<P> {
/// Initialize a new Repository
pub fn new(git_url: &str, dir_path: P) -> Self {
let name = get_repo_name(git_url).to_string();
Self { name, dir_path }
}

/// Get the `name` of the repository
pub fn name(&self) -> &str {
&self.name
}

/// Get the absolute path to where this repository was cloned
pub fn path(&self) -> &Path
where
P: AsRef<Path>,
{
self.dir_path.as_ref()
}

/// Get the relative path of a file contained in this repository
pub fn relative_path<'f, F>(&self, file: &'f F) -> &'f Path
where
P: AsRef<Path>,
F: AsRef<Path>,
{
file.as_ref()
.strip_prefix(self.dir_path.as_ref())
.unwrap_or(file.as_ref())
}
}

/// Clone a git repository
///
/// Parameters:
Expand Down Expand Up @@ -641,76 +684,111 @@ pub fn search_for_rs_files(repo: &Path) -> impl Iterator<Item = PathBuf> {
})
}

/// Encapsulate the logic used to clone repositories for the diff check
pub fn clone_repositories_for_diff_check(
repositories: &[&str],
) -> Vec<Repository<tempfile::TempDir>> {
// Use a Hashmap to deduplicate any repositories
let map = Arc::new(Mutex::new(HashMap::new()));

std::thread::scope(|s| {
for url in repositories {
let map = Arc::clone(&map);

s.spawn(move || {
let repo_name = get_repo_name(url);
info!("Processing repo: {repo_name}");
let Ok(tmp_dir) = tempdir() else {
warn!(
"Failed to create a tempdir for {}. Can't check formatting diff for {}",
&url, repo_name
);
return;
};

let Ok(_) = clone_git_repo(url, tmp_dir.path()) else {
warn!(
"Failed to clone repo {}. Can't check formatting diff for {}",
&url, repo_name
);
return;
};

let repo = Repository::new(url, tmp_dir);
map.lock().unwrap().insert(repo_name.to_string(), repo);
});
}
});

let map = match Arc::into_inner(map)
.expect("All other threads are done")
.into_inner()
{
Ok(map) => map,
Err(e) => e.into_inner(),
};

map.into_values().collect()
}

/// Calculates the number of errors when running the compiled binary and the feature binary on the
/// repo specified with the specific configs.
pub fn check_diff<P: AsRef<Path>>(
pub fn check_diff_for_file<'repo, P: AsRef<Path>, F: AsRef<Path>>(
runners: &CheckDiffRunners<impl CodeFormatter, impl CodeFormatter>,
repo: P,
repo_url: &str,
) -> u8 {
let mut errors: u8 = 0;
let repo = repo.as_ref();
let iter = search_for_rs_files(repo);
for file in iter {
let relative_path = file.strip_prefix(repo).unwrap_or(&file);
let repo_name = get_repo_name(repo_url);

trace!(
"Formatting '{0}' file {0}/{1}",
repo_name,
relative_path.display()
);

match runners.create_diff(file.as_path()) {
Ok(diff) => {
if !diff.is_empty() {
error!(
"Diff found in '{0}' when formatting {0}/{1}\n{2}",
repo_name,
relative_path.display(),
diff,
);
errors = errors.saturating_add(1);
} else {
trace!(
"No diff found in '{0}' when formatting {0}/{1}",
repo_name,
relative_path.display(),
)
}
}
Err(CreateDiffError::MainRustfmtFailed(e)) => {
debug!(
"`main` rustfmt failed to format {}/{}\n{:?}",
repo_name,
relative_path.display(),
e,
);
continue;
}
Err(CreateDiffError::FeatureRustfmtFailed(e)) => {
debug!(
"`feature` rustfmt failed to format {}/{}\n{:?}",
repo_name,
relative_path.display(),
e,
);
continue;
}
Err(CreateDiffError::BothRustfmtFailed { src, feature }) => {
debug!(
"Both rustfmt binaries failed to format {}/{}\n{:?}\n{:?}",
repo: &'repo Repository<P>,
file: F,
) -> Result<(), (Diff, F, &'repo Repository<P>)> {
let relative_path = repo.relative_path(&file);
let repo_name = repo.name();

trace!(
"Formatting '{0}' file {0}/{1}",
repo_name,
relative_path.display()
);

match runners.create_diff(file.as_ref()) {
Ok(diff) => {
if !diff.is_empty() {
Err((diff, file, repo))
} else {
trace!(
"No diff found in '{0}' when formatting {0}/{1}",
repo_name,
relative_path.display(),
src,
feature,
);
continue;
Ok(())
}
}
Err(CreateDiffError::MainRustfmtFailed(e)) => {
debug!(
"`main` rustfmt failed to format {}/{}\n{:?}",
repo_name,
relative_path.display(),
e,
);
Ok(())
}
Err(CreateDiffError::FeatureRustfmtFailed(e)) => {
debug!(
"`feature` rustfmt failed to format {}/{}\n{:?}",
repo_name,
relative_path.display(),
e,
);
Ok(())
}
Err(CreateDiffError::BothRustfmtFailed { src, feature }) => {
debug!(
"Both rustfmt binaries failed to format {}/{}\n{:?}\n{:?}",
repo_name,
relative_path.display(),
src,
feature,
);
Ok(())
}
}

errors
}

/// parse out the repository name from a GitHub Repository name.
Expand All @@ -721,3 +799,60 @@ pub fn get_repo_name(git_url: &str) -> &str {
.unwrap_or(("", strip_git_prefix));
repo_name
}

pub fn check_diff<'repo, P, F, M>(
runners: &CheckDiffRunners<F, M>,
repositories: &'repo [Repository<P>],
worker_threads: std::num::NonZeroU8,
) -> Vec<(Diff, PathBuf, &'repo Repository<P>)>
where
P: AsRef<Path> + Sync + Send,
F: CodeFormatter + Sync,
M: CodeFormatter + Sync,
{
let (tx, rx) = crossbeam_channel::unbounded();

let errors = std::thread::scope(|s| {
// Spawn producer threads that find files to check
for repo in repositories.iter() {
let tx = tx.clone();
s.spawn(move || {
for file in search_for_rs_files(repo.path()) {
let _ = tx.send((file, repo));
}
});
}

// Drop the first `tx` we created. Now there's exactly one `tx` per producer thread so when
// each producer thread finishes the receiving threads will start to get Err(RecvError)
// when calling `rx.recv()` and they'll know to stop processing files.
// When all scoped threads end we'll know we're done with processing and we can return
// any errors we found to the caller.
drop(tx);

let errors = Arc::new(Mutex::new(Vec::with_capacity(10)));

// spawn receiver threads used to process all files:
for _ in 0..u8::from(worker_threads) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remark: I was about to ask why was 10 was picked specifically, this answers my question 😆

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just needed a placeholder until I added the option in a later commit. Sorry if the commit by commit review is confusing when I make changes in future commits.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries, that is perfectly fine :)

let errors = Arc::clone(&errors);
let rx = rx.clone();
s.spawn(move || {
while let Ok((file, repo)) = rx.recv() {
if let Err(e) = check_diff_for_file(runners, repo, file) {
// Push errors to report on later
errors.lock().unwrap().push(e);
}
}
});
}
errors
});

match Arc::into_inner(errors)
.expect("All other threads are done")
.into_inner()
{
Ok(e) => e,
Err(e) => e.into_inner(),
}
}
Loading
Loading