diff --git a/crates/artifacts/solc/src/ast/misc.rs b/crates/artifacts/solc/src/ast/misc.rs
index 6ec3187b..7144ddc5 100644
--- a/crates/artifacts/solc/src/ast/misc.rs
+++ b/crates/artifacts/solc/src/ast/misc.rs
@@ -4,7 +4,7 @@ use std::{fmt, fmt::Write, str::FromStr};
 /// Represents the source location of a node: `<start byte>:<length>:<source index>`.
 ///
 /// The `start`, `length` and `index` can be -1 which is represented as `None`
-#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub struct SourceLocation {
     pub start: Option<usize>,
     pub length: Option<usize>,
diff --git a/crates/compilers/src/cache.rs b/crates/compilers/src/cache.rs
index 0d5d1613..8541a8b0 100644
--- a/crates/compilers/src/cache.rs
+++ b/crates/compilers/src/cache.rs
@@ -4,6 +4,7 @@ use crate::{
     buildinfo::RawBuildInfo,
     compilers::{Compiler, CompilerSettings, Language},
     output::Builds,
+    preprocessor::interface_representation_hash,
     resolver::GraphEdges,
     ArtifactFile, ArtifactOutput, Artifacts, ArtifactsMap, Graph, OutputContext, Project,
     ProjectPaths, ProjectPathsConfig, SourceCompilationKind,
@@ -405,6 +406,8 @@ pub struct CacheEntry {
     pub last_modification_date: u64,
     /// hash to identify whether the content of the file changed
     pub content_hash: String,
+    /// hash of the interface representation of the file, if it's a source file
+    pub interface_repr_hash: Option<String>,
     /// identifier name see [`foundry_compilers_core::utils::source_name()`]
     pub source_name: PathBuf,
     /// what config was set when compiling this file
@@ -620,9 +623,18 @@ pub(crate) struct ArtifactsCacheInner<'a, T: ArtifactOutput, C: Compiler> {
 
     /// The file hashes.
     pub content_hashes: HashMap<PathBuf, String>,
+
+    /// The interface representations for source files.
+    pub interface_repr_hashes: HashMap<PathBuf, String>,
 }
 
 impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
+    /// Whether the given file is a source file or a test/script file.
+    fn is_source_file(&self, file: &Path) -> bool {
+        !file.starts_with(&self.project.paths.tests)
+            && !file.starts_with(&self.project.paths.scripts)
+    }
+
     /// Creates a new cache entry for the file
     fn create_cache_entry(&mut self, file: PathBuf, source: &Source) {
         let imports = self
@@ -632,10 +644,14 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
             .map(|import| strip_prefix(import, self.project.root()).into())
             .collect();
 
+        let interface_repr_hash =
+            self.is_source_file(&file).then(|| interface_representation_hash(source));
+
         let entry = CacheEntry {
            last_modification_date: CacheEntry::<C::Settings>::read_last_modification_date(&file)
                 .unwrap_or_default(),
             content_hash: source.content_hash(),
+            interface_repr_hash,
             source_name: strip_prefix(&file, self.project.root()).into(),
             compiler_settings: self.project.settings.clone(),
             imports,
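
Note: the cache now tracks two hashes per entry. `content_hash` answers "did this file change at all?", while `interface_repr_hash` answers "did its externally visible surface change?". A self-contained sketch of the distinction, with stand-in types and a plain `DefaultHasher` instead of the crate's md5-based hashes:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};

/// Stand-in for ProjectPathsConfig: only the dirs needed for classification.
struct Paths {
    tests: PathBuf,
    scripts: PathBuf,
}

/// Mirrors `is_source_file` above: anything outside tests/ and scripts/ is a source.
fn is_source_file(paths: &Paths, file: &Path) -> bool {
    !file.starts_with(&paths.tests) && !file.starts_with(&paths.scripts)
}

fn hash(s: &str) -> u64 {
    let mut h = DefaultHasher::new();
    s.hash(&mut h);
    h.finish()
}

fn main() {
    let paths = Paths { tests: "test".into(), scripts: "script".into() };
    assert!(is_source_file(&paths, Path::new("src/Counter.sol")));
    assert!(!is_source_file(&paths, Path::new("test/Counter.t.sol")));

    // A body-only edit changes the content hash but not the interface:
    let v1 = "contract C { function f() external { return; } }";
    let v2 = "contract C { function f() external { /* changed */ } }";
    // Crude stand-in for interface_representation: keep text up to the body.
    let iface = |s: &str| s.split('{').take(2).collect::<String>();
    assert_ne!(hash(v1), hash(v2)); // content hash: the file itself must recompile
    assert_eq!(hash(&iface(v1)), hash(&iface(v2))); // interface hash: importers may stay cached
}
```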
@@ -730,62 +746,77 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
             return true;
         }
 
-        false
-    }
-
-    // Walks over all cache entires, detects dirty files and removes them from cache.
-    fn find_and_remove_dirty(&mut self) {
-        fn populate_dirty_files<D>(
-            file: &Path,
-            dirty_files: &mut HashSet<PathBuf>,
-            edges: &GraphEdges<D>,
-        ) {
-            for file in edges.importers(file) {
-                // If file is marked as dirty we either have already visited it or it was marked as
-                // dirty initially and will be visited at some point later.
-                if !dirty_files.contains(file) {
-                    dirty_files.insert(file.to_path_buf());
-                    populate_dirty_files(file, dirty_files, edges);
+        // If any requested extra files are missing for any artifact, mark source as dirty to
+        // generate them
+        for artifacts in self.cached_artifacts.values() {
+            for artifacts in artifacts.values() {
+                for artifact_file in artifacts {
+                    if self.project.artifacts_handler().is_dirty(artifact_file).unwrap_or(true) {
+                        return true;
+                    }
                 }
             }
         }
 
-        // Iterate over existing cache entries.
-        let files = self.cache.files.keys().cloned().collect::<Vec<_>>();
+        false
+    }
 
+    // Walks over all cache entries, detects dirty files and removes them from cache.
+    fn find_and_remove_dirty(&mut self) {
         let mut sources = Sources::new();
 
-        // Read all sources, marking entries as dirty on I/O errors.
-        for file in &files {
-            let Ok(source) = Source::read(file) else {
-                self.dirty_sources.insert(file.clone());
+        // Read all sources, removing entries on I/O errors.
+        for file in self.cache.files.keys().cloned().collect::<Vec<_>>() {
+            let Ok(source) = Source::read(&file) else {
+                self.cache.files.remove(&file);
                 continue;
             };
             sources.insert(file.clone(), source);
         }
 
-        // Build a temporary graph for walking imports. We need this because `self.edges`
-        // only contains graph data for in-scope sources but we are operating on cache entries.
-        if let Ok(graph) = Graph::<C::ParsedSource>::resolve_sources(&self.project.paths, sources) {
-            let (sources, edges) = graph.into_sources();
+        // Calculate content hashes for later comparison.
+        self.fill_hashes(&sources);
 
-            // Calculate content hashes for later comparison.
-            self.fill_hashes(&sources);
+        // Pre-add all sources that are guaranteed to be dirty
+        for file in self.cache.files.keys() {
+            if self.is_dirty_impl(file, false) {
+                self.dirty_sources.insert(file.clone());
+            }
+        }
 
-            // Pre-add all sources that are guaranteed to be dirty
-            for file in sources.keys() {
-                if self.is_dirty_impl(file) {
+        // Build a temporary graph for populating cache. We want to ensure that we preserve all
+        // just removed entries with updated data. We need a separate graph for this because
+        // `self.edges` only contains graph data for in-scope sources but we are operating on
+        // cache entries.
+        let Ok(graph) = Graph::<C::ParsedSource>::resolve_sources(&self.project.paths, sources)
+        else {
+            // Purge all sources on graph resolution error.
+            self.cache.files.clear();
+            return;
+        };
+
+        let (sources, edges) = graph.into_sources();
+
+        // Mark sources as dirty based on their imports
+        for file in sources.keys() {
+            if self.dirty_sources.contains(file) {
+                continue;
+            }
+            let is_src = self.is_source_file(file);
+            for import in edges.imports(file) {
+                // Any source file importing a dirty source file is dirty.
+                if is_src && self.dirty_sources.contains(import) {
+                    self.dirty_sources.insert(file.clone());
+                    break;
+                // For non-src files we mark them as dirty only if they import a dirty non-src
+                // file or a src file whose interface representation changed.
+                } else if !is_src
+                    && self.dirty_sources.contains(import)
+                    && (!self.is_source_file(import) || self.is_dirty_impl(import, true))
+                {
                     self.dirty_sources.insert(file.clone());
                 }
             }
-
-            // Perform DFS to find direct/indirect importers of dirty files.
-            for file in self.dirty_sources.clone().iter() {
-                populate_dirty_files(file, &mut self.dirty_sources, &edges);
-            }
-        } else {
-            // Purge all sources on graph resolution error.
-            self.dirty_sources.extend(files);
         }
 
         // Remove all dirty files from cache.
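
Note: the loop above encodes an asymmetric invalidation rule: src files are poisoned by any dirty src import, while test/script files ignore a dirty src import unless its interface representation changed. A compact model of just that decision, with hypothetical names (`interface_changed` stands in for `is_dirty_impl(import, true)`):

```rust
use std::collections::{HashMap, HashSet};

/// Decides whether `file` becomes dirty from its direct imports.
fn dirty_from_imports(
    file: &str,
    is_src: impl Fn(&str) -> bool + Copy,
    imports: &HashMap<&str, Vec<&str>>,
    dirty: &HashSet<&str>,
    interface_changed: &HashSet<&str>,
) -> bool {
    imports.get(file).into_iter().flatten().copied().any(|import| {
        dirty.contains(import)
            && (is_src(file) || !is_src(import) || interface_changed.contains(import))
    })
}

fn main() {
    let is_src = |f: &str| f.starts_with("src/");
    let imports: HashMap<_, _> = [
        ("src/Vault.sol", vec!["src/Token.sol"]),
        ("test/Token.t.sol", vec!["src/Token.sol"]),
    ]
    .into();
    // Token.sol changed, but only inside a function body:
    let dirty: HashSet<_> = ["src/Token.sol"].into();
    let iface_changed: HashSet<&str> = HashSet::new();
    // A src importer must recompile; the test keeps its cached artifacts.
    assert!(dirty_from_imports("src/Vault.sol", is_src, &imports, &dirty, &iface_changed));
    assert!(!dirty_from_imports("test/Token.t.sol", is_src, &imports, &dirty, &iface_changed));
}
```

A body-only change to `src/Token.sol` thus recompiles `src/Vault.sol` but leaves `test/Token.t.sol` untouched, which is the point of this PR.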
@@ -793,22 +824,43 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
             debug!("removing dirty file from cache: {}", file.display());
             self.cache.remove(file);
         }
-    }
 
-    fn is_dirty_impl(&self, file: &Path) -> bool {
-        let Some(hash) = self.content_hashes.get(file) else {
-            trace!("missing content hash");
-            return true;
-        };
+        // Create new entries for all source files
+        for (file, source) in sources {
+            if self.cache.files.contains_key(&file) {
+                continue;
+            }
+            self.create_cache_entry(file.clone(), &source);
+        }
+    }
 
+    fn is_dirty_impl(&self, file: &Path, use_interface_repr: bool) -> bool {
         let Some(entry) = self.cache.entry(file) else {
             trace!("missing cache entry");
             return true;
         };
 
-        if entry.content_hash != *hash {
-            trace!("content hash changed");
-            return true;
+        if use_interface_repr {
+            let Some(interface_hash) = self.interface_repr_hashes.get(file) else {
+                trace!("missing interface hash");
+                return true;
+            };
+
+            if entry.interface_repr_hash.as_ref().map_or(true, |h| h != interface_hash) {
+                trace!("interface hash changed");
+                return true;
+            };
+        } else {
+            let Some(content_hash) = self.content_hashes.get(file) else {
+                trace!("missing content hash");
+                return true;
+            };
+
+            if entry.content_hash != *content_hash {
+                trace!("content hash changed");
+                return true;
+            }
         }
 
         if !self.project.settings.can_use_cached(&entry.compiler_settings) {
@@ -816,18 +868,6 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
             return true;
         }
 
-        // If any requested extra files are missing for any artifact, mark source as dirty to
-        // generate them
-        for artifacts in self.cached_artifacts.values() {
-            for artifacts in artifacts.values() {
-                for artifact_file in artifacts {
-                    if self.project.artifacts_handler().is_dirty(artifact_file).unwrap_or(true) {
-                        return true;
-                    }
-                }
-            }
-        }
-
         // all things match, can be reused
         false
     }
@@ -838,6 +878,14 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCacheInner<'a, T, C> {
             if let hash_map::Entry::Vacant(entry) = self.content_hashes.entry(file.clone()) {
                 entry.insert(source.content_hash());
             }
+            // Fill interface representation hashes for source files
+            if self.is_source_file(&file) {
+                if let hash_map::Entry::Vacant(entry) =
+                    self.interface_repr_hashes.entry(file.clone())
+                {
+                    entry.insert(interface_representation_hash(&source));
+                }
+            }
         }
     }
 }
@@ -921,6 +969,7 @@ impl<'a, T: ArtifactOutput, C: Compiler> ArtifactsCache<'a, T, C> {
             dirty_sources: Default::default(),
             content_hashes: Default::default(),
             sources_in_scope: Default::default(),
+            interface_repr_hashes: Default::default(),
         };
 
         ArtifactsCache::Cached(cache)
diff --git a/crates/compilers/src/compile/project.rs b/crates/compilers/src/compile/project.rs
index b76c47a5..6ea0738b 100644
--- a/crates/compilers/src/compile/project.rs
+++ b/crates/compilers/src/compile/project.rs
@@ -109,16 +109,27 @@ use crate::{
     output::{AggregatedCompilerOutput, Builds},
     report,
     resolver::GraphEdges,
-    ArtifactOutput, CompilerSettings, Graph, Project, ProjectCompileOutput, Sources,
+    ArtifactOutput, CompilerSettings, Graph, Project, ProjectCompileOutput, ProjectPathsConfig,
+    Sources,
 };
 use foundry_compilers_core::error::Result;
 use rayon::prelude::*;
 use semver::Version;
-use std::{collections::HashMap, path::PathBuf, time::Instant};
+use std::{collections::HashMap, fmt::Debug, path::PathBuf, time::Instant};
 
 /// A set of different Solc installations with their version and the sources to be compiled
 pub(crate) type VersionedSources<L> = HashMap<L, Vec<(Version, Sources)>>;
 
+/// Invoked before the actual compiler invocation and can override the input.
+pub trait Preprocessor<C: Compiler>: Debug {
+    fn preprocess(
+        &self,
+        compiler: &C,
+        input: C::Input,
+        paths: &ProjectPathsConfig<C::Language>,
+    ) -> Result<C::Input>;
+}
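
Note: a sketch of implementing this hook for the solc compiler. The trait is defined in `compile/project.rs`; the exact import paths below are assumptions, since the public re-export is not shown in this diff:

```rust
use foundry_compilers::{
    compile::project::Preprocessor, // assumed path to the trait added above
    solc::{SolcCompiler, SolcVersionedInput},
    ProjectPathsConfig, Result,
};
use foundry_compilers_artifacts::SolcLanguage;

/// A do-nothing preprocessor that just logs which files reach solc.
#[derive(Debug)]
struct LoggingPreprocessor;

impl Preprocessor<SolcCompiler> for LoggingPreprocessor {
    fn preprocess(
        &self,
        _compiler: &SolcCompiler,
        input: SolcVersionedInput,
        _paths: &ProjectPathsConfig<SolcLanguage>,
    ) -> Result<SolcVersionedInput> {
        for (path, _) in &input.input.sources {
            eprintln!("about to compile {}", path.display());
        }
        // Returning the input unchanged keeps compilation behavior identical.
        Ok(input)
    }
}
```

Attaching it would look like `ProjectCompiler::new(project)?.with_preprocessor(LoggingPreprocessor).compile()`, per the builder method added below.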
+
 #[derive(Debug)]
 pub struct ProjectCompiler<'a, T: ArtifactOutput, C: Compiler> {
     /// Contains the relationship of the source files and their imports
@@ -126,6 +137,8 @@ pub struct ProjectCompiler<'a, T: ArtifactOutput, C: Compiler> {
     project: &'a Project<T, C>,
     /// how to compile all the sources
     sources: CompilerSources,
+    /// Optional preprocessor
+    preprocessor: Option<Box<dyn Preprocessor<C>>>,
 }
 
 impl<'a, T: ArtifactOutput, C: Compiler> ProjectCompiler<'a, T, C> {
@@ -160,7 +173,11 @@ impl<'a, T: ArtifactOutput, C: Compiler> ProjectCompiler<'a, T, C> {
             sources,
         };
 
-        Ok(Self { edges, project, sources })
+        Ok(Self { edges, project, sources, preprocessor: None })
+    }
+
+    pub fn with_preprocessor(self, preprocessor: impl Preprocessor<C> + 'static) -> Self {
+        Self { preprocessor: Some(Box::new(preprocessor)), ..self }
     }
 
     /// Compiles all the sources of the `Project` in the appropriate mode
@@ -197,7 +214,7 @@ impl<'a, T: ArtifactOutput, C: Compiler> ProjectCompiler<'a, T, C> {
     /// - check cache
     fn preprocess(self) -> Result<PreprocessedState<'a, T, C>> {
         trace!("preprocessing");
-        let Self { edges, project, mut sources } = self;
+        let Self { edges, project, mut sources, preprocessor } = self;
 
         // convert paths on windows to ensure consistency with the `CompilerOutput` `solc` emits,
         // which is unix style `/`
@@ -207,7 +224,7 @@ impl<'a, T: ArtifactOutput, C: Compiler> ProjectCompiler<'a, T, C> {
         // retain and compile only dirty sources and all their imports
         sources.filter(&mut cache);
 
-        Ok(PreprocessedState { sources, cache })
+        Ok(PreprocessedState { sources, cache, preprocessor })
     }
 }
 
@@ -221,15 +238,18 @@ struct PreprocessedState<'a, T: ArtifactOutput, C: Compiler> {
 
     /// Cache that holds `CacheEntry` objects if caching is enabled and the project is recompiled
     cache: ArtifactsCache<'a, T, C>,
+
+    /// Optional preprocessor
+    preprocessor: Option<Box<dyn Preprocessor<C>>>,
 }
 
 impl<'a, T: ArtifactOutput, C: Compiler> PreprocessedState<'a, T, C> {
     /// advance to the next state by compiling all sources
     fn compile(self) -> Result<CompiledState<'a, T, C>> {
         trace!("compiling");
-        let PreprocessedState { sources, mut cache } = self;
+        let PreprocessedState { sources, mut cache, preprocessor } = self;
 
-        let mut output = sources.compile(&mut cache)?;
+        let mut output = sources.compile(&mut cache, preprocessor)?;
 
         // source paths get stripped before handing them over to solc, so solc never uses absolute
         // paths, instead `--base-path <root dir>` is set. this way any metadata that's derived from
@@ -410,6 +430,7 @@ impl<L: Language> CompilerSources<L> {
     fn compile<C: Compiler<Language = L>, T: ArtifactOutput>(
         self,
         cache: &mut ArtifactsCache<'_, T, C>,
+        preprocessor: Option<Box<dyn Preprocessor<C>>>,
     ) -> Result<AggregatedCompilerOutput<C>> {
         let project = cache.project();
         let graph = cache.graph();
@@ -456,6 +477,10 @@ impl<L: Language> CompilerSources<L> {
                 input.strip_prefix(project.paths.root.as_path());
 
+                if let Some(preprocessor) = preprocessor.as_ref() {
+                    input = preprocessor.preprocess(&project.compiler, input, &project.paths)?;
+                }
+
                 jobs.push((input, actually_dirty));
             }
         }
diff --git a/crates/compilers/src/flatten.rs b/crates/compilers/src/flatten.rs
index 45da3902..df4f5f87 100644
--- a/crates/compilers/src/flatten.rs
+++ b/crates/compilers/src/flatten.rs
@@ -17,9 +17,10 @@ use foundry_compilers_core::{
 };
 use itertools::Itertools;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::{BTreeSet, HashMap, HashSet},
     hash::Hash,
     path::{Path, PathBuf},
+    sync::Arc,
 };
 use visitor::Walk;
 
@@ -95,7 +96,7 @@ impl Visitor for ReferencesCollector {
     }
 
     fn visit_external_assembly_reference(&mut self, reference: &ExternalInlineAssemblyReference) {
-        let mut src = reference.src.clone();
+        let mut src = reference.src;
 
         // If suffix is used in assembly reference (e.g. value.slot), it will be included into src.
        // However, we are only interested in the referenced name, thus we strip the `.<suffix>` part.
@@ -112,47 +113,32 @@ impl Visitor for ReferencesCollector {
 
 /// Updates to be applied to the sources.
 /// source_path -> (start, end, new_value)
-type Updates = HashMap<PathBuf, BTreeSet<(usize, usize, String)>>;
+pub type Updates = HashMap<PathBuf, BTreeSet<(usize, usize, String)>>;
 
-pub struct FlatteningResult<'a> {
+pub struct FlatteningResult {
     /// Updated sources in the order they should be written to the output file.
     sources: Vec<String>,
     /// Pragmas that should be present in the target file.
     pragmas: Vec<String>,
     /// License identifier that should be present in the target file.
-    license: Option<&'a str>,
+    license: Option<String>,
 }
 
-impl<'a> FlatteningResult<'a> {
+impl FlatteningResult {
     fn new(
-        flattener: &Flattener,
-        mut updates: Updates,
+        mut flattener: Flattener,
+        updates: Updates,
         pragmas: Vec<String>,
-        license: Option<&'a str>,
+        license: Option<String>,
     ) -> Self {
-        let mut sources = Vec::new();
-
-        for path in &flattener.ordered_sources {
-            let mut content = flattener.sources.get(path).unwrap().content.as_bytes().to_vec();
-            let mut offset: isize = 0;
-            if let Some(updates) = updates.remove(path) {
-                let mut updates = updates.iter().collect::<Vec<_>>();
-                updates.sort_by_key(|(start, _, _)| *start);
-                for (start, end, new_value) in updates {
-                    let start = (*start as isize + offset) as usize;
-                    let end = (*end as isize + offset) as usize;
-
-                    content.splice(start..end, new_value.bytes());
-                    offset += new_value.len() as isize - (end - start) as isize;
-                }
-            }
-            let content = format!(
-                "// {}\n{}",
-                path.strip_prefix(&flattener.project_root).unwrap_or(path).display(),
-                String::from_utf8(content).unwrap()
-            );
-            sources.push(content);
-        }
+        apply_updates(&mut flattener.sources, updates);
+
+        let sources = flattener
+            .ordered_sources
+            .iter()
+            .map(|path| flattener.sources.remove(path).unwrap().content)
+            .map(Arc::unwrap_or_clone)
+            .collect();
 
         Self { sources, pragmas, license }
     }
@@ -274,9 +260,10 @@ impl Flattener {
     /// 3. Remove all imports.
     /// 4. Remove all pragmas except for the ones in the target file.
     /// 5. Remove all license identifiers except for the one in the target file.
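
Note: each numbered step above contributes `(start, end, replacement)` triples to an `Updates` map keyed by path. Because a `BTreeSet` of such triples iterates in `start` order, they can be applied in one left-to-right pass with a running byte offset; this is the same bookkeeping as the `apply_updates` function added at the bottom of this file. A standalone sketch:

```rust
use std::collections::BTreeSet;

/// Applies (start, end, replacement) triples to `content` in start order,
/// shifting later ranges by the accumulated size delta.
fn apply(content: &str, updates: BTreeSet<(usize, usize, String)>) -> String {
    let mut bytes = content.as_bytes().to_vec();
    let mut offset: isize = 0;
    for (start, end, new_value) in updates {
        let start = (start as isize + offset) as usize;
        let end = (end as isize + offset) as usize;
        bytes.splice(start..end, new_value.bytes());
        offset += new_value.len() as isize - (end - start) as isize;
    }
    String::from_utf8(bytes).unwrap()
}

fn main() {
    let src = "import \"./A.sol\";\ncontract B {}";
    let mut updates = BTreeSet::new();
    updates.insert((0, 17, String::new())); // drop the import
    updates.insert((18, 28, "contract C".to_string())); // rename B to C
    assert_eq!(apply(src, updates), "\ncontract C {}");
}
```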
-    pub fn flatten(&self) -> String {
+    pub fn flatten(self) -> String {
         let mut updates = Updates::new();
 
+        self.append_filenames(&mut updates);
         let top_level_names = self.rename_top_level_definitions(&mut updates);
         self.rename_contract_level_types_references(&top_level_names, &mut updates);
         self.remove_qualified_imports(&mut updates);
@@ -289,15 +276,26 @@ impl Flattener {
         self.flatten_result(updates, target_pragmas, target_license).get_flattened_target()
     }
 
-    fn flatten_result<'a>(
-        &'a self,
+    fn flatten_result(
+        self,
         updates: Updates,
         target_pragmas: Vec<String>,
-        target_license: Option<&'a str>,
-    ) -> FlatteningResult<'_> {
+        target_license: Option<String>,
+    ) -> FlatteningResult {
         FlatteningResult::new(self, updates, target_pragmas, target_license)
     }
 
+    /// Appends a comment with the file name to the beginning of each source.
+    fn append_filenames(&self, updates: &mut Updates) {
+        for path in &self.ordered_sources {
+            updates.entry(path.clone()).or_default().insert((
+                0,
+                0,
+                format!("// {}\n", path.strip_prefix(&self.project_root).unwrap_or(path).display()),
+            ));
+        }
+    }
+
     /// Finds and goes over all references to file-level definitions and updates them to match
     /// definition name. This is needed for two reasons:
     /// 1. We want to rename all aliased or qualified imports.
@@ -752,14 +750,14 @@ impl Flattener {
 
     /// Removes all license identifiers from all sources. Returns license identifier from target
     /// file, if any.
-    fn process_licenses(&self, updates: &mut Updates) -> Option<&str> {
+    fn process_licenses(&self, updates: &mut Updates) -> Option<String> {
         let mut target_license = None;
 
         for loc in &self.collect_licenses() {
             if loc.path == self.target {
                 let license_line = self.read_location(loc);
                 let license_start = license_line.find("SPDX-License-Identifier:").unwrap();
-                target_license = Some(license_line[license_start..].trim());
+                target_license = Some(license_line[license_start..].trim().to_string());
             }
 
             updates.entry(loc.path.clone()).or_default().insert((
                 loc.start,
@@ -887,3 +885,21 @@ pub fn combine_version_pragmas(pragmas: Vec<&str>) -> Option<String> {
 
     None
 }
+
+pub fn apply_updates(sources: &mut Sources, mut updates: Updates) {
+    for (path, source) in sources {
+        if let Some(updates) = updates.remove(path) {
+            let mut offset = 0;
+            let mut content = source.content.as_bytes().to_vec();
+            for (start, end, new_value) in updates {
+                let start = (start as isize + offset) as usize;
+                let end = (end as isize + offset) as usize;
+
+                content.splice(start..end, new_value.bytes());
+                offset += new_value.len() as isize - (end - start) as isize;
+            }
+
+            source.content = Arc::new(String::from_utf8_lossy(&content).to_string());
+        }
+    }
+}
diff --git a/crates/compilers/src/lib.rs b/crates/compilers/src/lib.rs
index fc2f8c3a..1eebbfe9 100644
--- a/crates/compilers/src/lib.rs
+++ b/crates/compilers/src/lib.rs
@@ -24,6 +24,8 @@ pub use resolver::Graph;
 pub mod compilers;
 pub use compilers::*;
 
+pub mod preprocessor;
+
 mod compile;
 pub use compile::{
     output::{AggregatedCompilerOutput, ProjectCompileOutput},
diff --git a/crates/compilers/src/preprocessor.rs b/crates/compilers/src/preprocessor.rs
new file mode 100644
index 00000000..d79f6619
--- /dev/null
+++ b/crates/compilers/src/preprocessor.rs
@@ -0,0 +1,614 @@
+use super::project::Preprocessor;
+use crate::{
+    flatten::{apply_updates, Updates},
+    multi::{MultiCompiler, MultiCompilerInput, MultiCompilerLanguage},
+    solc::{SolcCompiler, SolcVersionedInput},
+    Compiler, ProjectPathsConfig, Result, SolcError,
+};
+use alloy_primitives::hex;
+use foundry_compilers_artifacts::{
+    ast::SourceLocation,
+    output_selection::OutputSelection,
+    visitor::{Visitor, Walk},
+    ContractDefinitionPart, Expression, FunctionCall, FunctionKind, MemberAccess, NewExpression,
+    ParameterList, SolcLanguage, Source, SourceUnit, SourceUnitPart, Sources, TypeName,
+};
+use foundry_compilers_core::utils;
+use itertools::Itertools;
+use md5::Digest;
+use solang_parser::{diagnostics::Diagnostic, helpers::CodeLocation, pt};
+use std::{
+    collections::{BTreeMap, BTreeSet},
+    path::{Path, PathBuf},
+};
+
+/// Removes parts of the contract which do not alter its interface:
+///   - Internal functions
+///   - External functions bodies
+///
+/// Preserves all libraries and interfaces.
+pub(crate) fn interface_representation(content: &str) -> Result<String, Vec<Diagnostic>> {
+    let (source_unit, _) = solang_parser::parse(content, 0)?;
+    let mut locs_to_remove = Vec::new();
+
+    for part in source_unit.0 {
+        if let pt::SourceUnitPart::ContractDefinition(contract) = part {
+            if matches!(contract.ty, pt::ContractTy::Interface(_) | pt::ContractTy::Library(_)) {
+                continue;
+            }
+            for part in contract.parts {
+                if let pt::ContractPart::FunctionDefinition(func) = part {
+                    let is_exposed = func.ty == pt::FunctionTy::Function
+                        && func.attributes.iter().any(|attr| {
+                            matches!(
+                                attr,
+                                pt::FunctionAttribute::Visibility(
+                                    pt::Visibility::External(_) | pt::Visibility::Public(_)
+                                )
+                            )
+                        })
+                        || matches!(
+                            func.ty,
+                            pt::FunctionTy::Constructor
+                                | pt::FunctionTy::Fallback
+                                | pt::FunctionTy::Receive
+                        );
+
+                    if !is_exposed {
+                        locs_to_remove.push(func.loc);
+                    }
+
+                    if let Some(ref body) = func.body {
+                        locs_to_remove.push(body.loc());
+                    }
+                }
+            }
+        }
+    }
+
+    let mut content = content.to_string();
+    let mut offset = 0;
+
+    for loc in locs_to_remove {
+        let start = loc.start() - offset;
+        let end = loc.end() - offset;
+
+        content.replace_range(start..end, "");
+        offset += end - start;
+    }
+
+    let content = content.replace("\n", "");
+    Ok(utils::RE_TWO_OR_MORE_SPACES.replace_all(&content, "").to_string())
+}
+
+/// Computes hash of [`interface_representation`] of the source.
+pub(crate) fn interface_representation_hash(source: &Source) -> String {
+    let Ok(repr) = interface_representation(&source.content) else { return source.content_hash() };
+    let mut hasher = md5::Md5::new();
+    hasher.update(&repr);
+    let result = hasher.finalize();
+    hex::encode(result)
+}
+
+#[derive(Debug)]
+pub struct ItemLocation {
+    start: usize,
+    end: usize,
+}
+
+impl ItemLocation {
+    fn try_from_loc(loc: SourceLocation) -> Option<Self> {
+        Some(Self { start: loc.start?, end: loc.start? + loc.length? })
+    }
+}
+
+/// Checks if the given path is a test/script file.
+fn is_test_or_script<L>(path: &Path, paths: &ProjectPathsConfig<L>) -> bool {
+    let test_dir = paths.tests.strip_prefix(&paths.root).unwrap_or(&paths.root);
+    let script_dir = paths.scripts.strip_prefix(&paths.root).unwrap_or(&paths.root);
+    path.starts_with(test_dir) || path.starts_with(script_dir)
+}
+
+/// Kind of the bytecode dependency.
+#[derive(Debug)]
+enum BytecodeDependencyKind {
+    /// `type(Contract).creationCode`
+    CreationCode,
+    /// `new Contract`
+    New(ItemLocation, String),
+}
+
+/// Represents a single bytecode dependency.
+#[derive(Debug)]
+struct BytecodeDependency {
+    kind: BytecodeDependencyKind,
+    loc: ItemLocation,
+    referenced_contract: usize,
+}
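
Note: `try_from_loc` is where solc's `-1` sentinels, already decoded to `None` by the `Option` fields (and the new `Copy` derive) in misc.rs above, reject AST nodes without a real source range. A std-only sketch:

```rust
/// Mirror of the SourceLocation/ItemLocation pair above ("start:length:file",
/// where -1 decodes to None).
#[derive(Clone, Copy, Debug)]
struct SourceLocation {
    start: Option<usize>,
    length: Option<usize>,
}

struct ItemLocation {
    start: usize,
    end: usize,
}

fn try_from_loc(loc: SourceLocation) -> Option<ItemLocation> {
    Some(ItemLocation { start: loc.start?, end: loc.start? + loc.length? })
}

fn main() {
    // "135:10:0" => bytes 135..145 of the source file
    let loc = SourceLocation { start: Some(135), length: Some(10) };
    let item = try_from_loc(loc).unwrap();
    assert_eq!((item.start, item.end), (135, 145));
    // A compiler-generated node ("-1:-1:-1") has no location and yields None.
    let generated = SourceLocation { start: None, length: None };
    assert!(try_from_loc(generated).is_none());
}
```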
+
+/// Walks over AST and collects [`BytecodeDependency`]s.
+#[derive(Debug)]
+struct BytecodeDependencyCollector<'a> {
+    /// Source code of the current file.
+    source: &'a str,
+    /// Collected bytecode dependencies.
+    dependencies: Vec<BytecodeDependency>,
+    /// Total number of `new`/`creationCode` sites seen, used to verify collection.
+    total_count: usize,
+}
+
+impl BytecodeDependencyCollector<'_> {
+    fn new(source: &str) -> BytecodeDependencyCollector<'_> {
+        BytecodeDependencyCollector { source, dependencies: Vec::new(), total_count: 0 }
+    }
+}
+
+impl Visitor for BytecodeDependencyCollector<'_> {
+    fn visit_new_expression(&mut self, expr: &NewExpression) {
+        if let TypeName::UserDefinedTypeName(_) = &expr.type_name {
+            self.total_count += 1;
+        }
+    }
+
+    fn visit_function_call(&mut self, call: &FunctionCall) {
+        let (new_loc, expr) = match &call.expression {
+            Expression::NewExpression(expr) => (expr.src, expr),
+            Expression::FunctionCallOptions(expr) => {
+                if let Expression::NewExpression(new_expr) = &expr.expression {
+                    (expr.src, new_expr)
+                } else {
+                    return;
+                }
+            }
+            _ => return,
+        };
+
+        let TypeName::UserDefinedTypeName(type_name) = &expr.type_name else { return };
+
+        let Some(loc) = ItemLocation::try_from_loc(call.src) else { return };
+        let Some(name_loc) = ItemLocation::try_from_loc(type_name.src) else { return };
+        let Some(new_loc) = ItemLocation::try_from_loc(new_loc) else { return };
+        let name = &self.source[name_loc.start..name_loc.end];
+
+        self.dependencies.push(BytecodeDependency {
+            kind: BytecodeDependencyKind::New(new_loc, name.to_string()),
+            loc,
+            referenced_contract: type_name.referenced_declaration as usize,
+        });
+    }
+
+    fn visit_member_access(&mut self, access: &MemberAccess) {
+        if access.member_name != "creationCode" {
+            return;
+        }
+        self.total_count += 1;
+
+        let Expression::FunctionCall(call) = &access.expression else { return };
+
+        let Expression::Identifier(ident) = &call.expression else { return };
+
+        if ident.name != "type" {
+            return;
+        }
+
+        let Some(Expression::Identifier(ident)) = call.arguments.first() else { return };
+
+        let Some(referenced) = ident.referenced_declaration else { return };
+
+        let Some(loc) = ItemLocation::try_from_loc(access.src) else { return };
+
+        self.dependencies.push(BytecodeDependency {
+            kind: BytecodeDependencyKind::CreationCode,
+            loc,
+            referenced_contract: referenced as usize,
+        });
+    }
+}
+
+/// Keeps data about a single contract definition.
+struct ContractData<'a> {
+    /// AST id of the contract.
+    ast_id: usize,
+    /// Path of the source file.
+    path: &'a Path,
+    /// Name of the contract.
+    name: &'a str,
+    /// Constructor parameters.
+    constructor_params: Option<&'a ParameterList>,
+    /// Reference to source code.
+    src: &'a str,
+    /// Artifact string to pass into cheatcodes.
+    artifact: String,
+}
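
Note: `visit_function_call` above has to peel one extra layer when call options are present, since `new C{salt: s}(...)` wraps the `NewExpression` in a `FunctionCallOptions` node. A toy model of that match with a hypothetical enum (not the real solc AST types):

```rust
#[derive(Debug)]
enum Expr {
    New { type_name: String },
    CallOptions(Box<Expr>),
    Other,
}

/// Returns the instantiated type for a call's callee, unwrapping call options.
fn referenced_type(call_callee: &Expr) -> Option<&str> {
    match call_callee {
        Expr::New { type_name } => Some(type_name),
        Expr::CallOptions(inner) => match inner.as_ref() {
            Expr::New { type_name } => Some(type_name),
            _ => None,
        },
        _ => None,
    }
}

fn main() {
    let plain = Expr::New { type_name: "ERC20".into() }; // new ERC20(...)
    let with_opts = Expr::CallOptions(Box::new(Expr::New {
        type_name: "ERC20".into(), // new ERC20{salt: s}(...)
    }));
    assert_eq!(referenced_type(&plain), Some("ERC20"));
    assert_eq!(referenced_type(&with_opts), Some("ERC20"));
    assert_eq!(referenced_type(&Expr::Other), None);
}
```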
+
+impl ContractData<'_> {
+    /// If contract has a non-empty constructor, generates a helper source file for it containing
+    /// a helper to encode constructor arguments.
+    ///
+    /// This is needed because current preprocessing wraps the arguments, leaving them unchanged.
+    /// This allows us to handle nested new expressions correctly. However, this requires us to
+    /// have a way to wrap both named and unnamed arguments, i.e. you can't do
+    /// `abi.encode({arg: val})`.
+    ///
+    /// This function produces a helper struct + a helper function to encode the arguments. The
+    /// struct is defined in scope of an abstract contract inheriting the contract containing the
+    /// constructor. This is done as a hack to allow us to inherit the same scope of definitions.
+    ///
+    /// The resulting helper looks like this:
+    /// ```solidity
+    /// import "lib/openzeppelin-contracts/contracts/token/ERC20.sol";
+    ///
+    /// abstract contract DeployHelper335 is ERC20 {
+    ///     struct ConstructorArgs {
+    ///         string name;
+    ///         string symbol;
+    ///     }
+    /// }
+    ///
+    /// function encodeArgs335(DeployHelper335.ConstructorArgs memory args) pure returns (bytes memory) {
+    ///     return abi.encode(args.name, args.symbol);
+    /// }
+    /// ```
+    ///
+    /// Example usage:
+    /// ```solidity
+    /// new ERC20(name, symbol)
+    /// ```
+    /// becomes
+    /// ```solidity
+    /// vm.deployCode("artifact path", encodeArgs335(DeployHelper335.ConstructorArgs(name, symbol)))
+    /// ```
+    /// With named arguments:
+    /// ```solidity
+    /// new ERC20({name: name, symbol: symbol})
+    /// ```
+    /// becomes
+    /// ```solidity
+    /// vm.deployCode("artifact path", encodeArgs335(DeployHelper335.ConstructorArgs({name: name, symbol: symbol})))
+    /// ```
+    pub fn build_helper(&self) -> Result<Option<String>> {
+        let Self { ast_id, path, name, constructor_params, src, artifact } = self;
+
+        let Some(params) = constructor_params else { return Ok(None) };
+
+        let struct_fields = params
+            .parameters
+            .iter()
+            .filter_map(|param| {
+                let loc = ItemLocation::try_from_loc(param.src)?;
+                Some(src[loc.start..loc.end].replace(" memory ", " ").replace(" calldata ", " "))
+            })
+            .join("; ");
+
+        let abi_encode_args =
+            params.parameters.iter().map(|param| format!("args.{}", param.name)).join(", ");
+
+        let vm_interface_name = format!("VmContractHelper{}", ast_id);
+        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
+
+        let helper = format!(
+            r#"
+pragma solidity >=0.4.0;
+
+import "{path}";
+
+abstract contract DeployHelper{ast_id} is {name} {{
+    struct ConstructorArgs {{
+        {struct_fields};
+    }}
+}}
+
+function encodeArgs{ast_id}(DeployHelper{ast_id}.ConstructorArgs memory args) pure returns (bytes memory) {{
+    return abi.encode({abi_encode_args});
+}}
+
+function deployCode{ast_id}(DeployHelper{ast_id}.ConstructorArgs memory args) returns({name}) {{
+    return {name}(payable({vm}.deployCode("{artifact}", encodeArgs{ast_id}(args))));
+}}
+
+interface {vm_interface_name} {{
+    function deployCode(string memory _artifact, bytes memory _data) external returns (address);
+    function deployCode(string memory _artifact) external returns (address);
+    function getCode(string memory _artifact) external returns (bytes memory);
+}}
+        "#,
+            path = path.display(),
+        );
+
+        Ok(Some(helper))
+    }
+}
+
+/// Receives a set of source files along with their ASTs and removes bytecode dependencies from
+/// contracts by replacing them with cheatcode invocations.
+struct BytecodeDependencyOptimizer<'a> {
+    asts: BTreeMap<PathBuf, SourceUnit>,
+    paths: &'a ProjectPathsConfig<SolcLanguage>,
+    sources: &'a mut Sources,
+}
+
+impl BytecodeDependencyOptimizer<'_> {
+    fn new<'a>(
+        asts: BTreeMap<PathBuf, SourceUnit>,
+        paths: &'a ProjectPathsConfig<SolcLanguage>,
+        sources: &'a mut Sources,
+    ) -> BytecodeDependencyOptimizer<'a> {
+        BytecodeDependencyOptimizer { asts, paths, sources }
+    }
+
+    fn process(self) -> Result<()> {
+        let mut updates = Updates::default();
+
+        let contracts = self.collect_contracts();
+        let additional_sources = self.create_deploy_helpers(&contracts)?;
+        self.remove_bytecode_dependencies(&contracts, &mut updates)?;
+
+        self.sources.extend(additional_sources);
+
+        apply_updates(self.sources, updates);
+
+        Ok(())
+    }
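
Note: a trimmed sketch of the string assembly `build_helper` performs, with illustrative inputs. The real code derives the field list from the constructor's `ParameterList` and also emits the `deployCode` helper and vm interface:

```rust
/// Assembles a DeployHelper source for a contract with AST id `ast_id`.
/// Field strings are assumed pre-stripped of `memory`/`calldata`, as above.
fn build_helper(ast_id: usize, name: &str, path: &str, fields: &[&str]) -> String {
    let struct_fields = fields.join("; ");
    let abi_encode_args = fields
        .iter()
        .map(|f| format!("args.{}", f.rsplit(' ').next().unwrap()))
        .collect::<Vec<_>>()
        .join(", ");
    format!(
        "pragma solidity >=0.4.0;\n\
         import \"{path}\";\n\
         abstract contract DeployHelper{ast_id} is {name} {{\n\
             struct ConstructorArgs {{ {struct_fields}; }}\n\
         }}\n\
         function encodeArgs{ast_id}(DeployHelper{ast_id}.ConstructorArgs memory args) \
         pure returns (bytes memory) {{ return abi.encode({abi_encode_args}); }}\n"
    )
}

fn main() {
    let helper = build_helper(335, "ERC20", "src/ERC20.sol", &["string name", "string symbol"]);
    assert!(helper.contains("abstract contract DeployHelper335 is ERC20"));
    assert!(helper.contains("abi.encode(args.name, args.symbol)"));
}
```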
+
+    /// Collects a mapping from contract AST id to [ContractData] for all contracts defined in
+    /// src/ files.
+    fn collect_contracts(&self) -> BTreeMap<usize, ContractData<'_>> {
+        let mut contracts = BTreeMap::new();
+
+        for (path, ast) in &self.asts {
+            let src = self.sources.get(path).unwrap().content.as_str();
+
+            if is_test_or_script(path, self.paths) {
+                continue;
+            }
+
+            for node in &ast.nodes {
+                if let SourceUnitPart::ContractDefinition(contract) = node {
+                    let artifact = format!("{}:{}", path.display(), contract.name);
+                    let constructor = contract.nodes.iter().find_map(|node| {
+                        let ContractDefinitionPart::FunctionDefinition(func) = node else {
+                            return None;
+                        };
+                        if *func.kind() != FunctionKind::Constructor {
+                            return None;
+                        }
+
+                        Some(func)
+                    });
+
+                    contracts.insert(
+                        contract.id,
+                        ContractData {
+                            artifact,
+                            constructor_params: constructor
+                                .map(|constructor| &constructor.parameters)
+                                .filter(|params| !params.parameters.is_empty()),
+                            src,
+                            ast_id: contract.id,
+                            path,
+                            name: &contract.name,
+                        },
+                    );
+                }
+            }
+        }
+
+        contracts
+    }
+
+    /// Creates helper libraries for contracts with a non-empty constructor.
+    ///
+    /// See [`ContractData::build_helper`] for more details.
+    fn create_deploy_helpers(
+        &self,
+        contracts: &BTreeMap<usize, ContractData<'_>>,
+    ) -> Result<Sources> {
+        let mut new_sources = Sources::new();
+        for (id, contract) in contracts {
+            if let Some(code) = contract.build_helper()? {
+                let path = format!("foundry-pp/DeployHelper{}.sol", id);
+                new_sources.insert(path.into(), Source::new(code));
+            }
+        }
+
+        Ok(new_sources)
+    }
+
+    /// Goes over all test/script files and replaces bytecode dependencies with cheatcode
+    /// invocations.
+    fn remove_bytecode_dependencies(
+        &self,
+        contracts: &BTreeMap<usize, ContractData<'_>>,
+        updates: &mut Updates,
+    ) -> Result<()> {
+        for (path, ast) in &self.asts {
+            if !is_test_or_script(path, self.paths) {
+                continue;
+            }
+            let src = self.sources.get(path).unwrap().content.as_str();
+
+            if src.is_empty() {
+                continue;
+            }
+
+            let updates = updates.entry(path.clone()).or_default();
+            let mut used_helpers = BTreeSet::new();
+
+            let mut collector = BytecodeDependencyCollector::new(src);
+            ast.walk(&mut collector);
+
+            // It is possible to write weird expressions which we won't catch,
+            // e.g. (((new Contract)))() is valid syntax.
+            //
+            // We need to ensure that we've collected all dependencies that are in the contract.
+            if collector.dependencies.len() != collector.total_count {
+                return Err(SolcError::msg(format!(
+                    "failed to collect all bytecode dependencies for {}",
+                    path.display()
+                )));
+            }
+
+            let vm_interface_name = format!("VmContractHelper{}", ast.id);
+            let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
+
+            for dep in collector.dependencies {
+                let Some(ContractData { artifact, constructor_params, .. }) =
+                    contracts.get(&dep.referenced_contract)
+                else {
+                    continue;
+                };
+                match dep.kind {
+                    BytecodeDependencyKind::CreationCode => {
+                        // for creation code we need to just call getCode
+                        updates.insert((
+                            dep.loc.start,
+                            dep.loc.end,
+                            format!("{vm}.getCode(\"{artifact}\")"),
+                        ));
+                    }
+                    BytecodeDependencyKind::New(new_loc, name) => {
+                        if constructor_params.is_none() {
+                            // if there's no constructor, we can just call deployCode with one
+                            // argument
+                            updates.insert((
+                                dep.loc.start,
+                                dep.loc.end,
+                                format!("{name}(payable({vm}.deployCode(\"{artifact}\")))"),
+                            ));
+                        } else {
+                            // if there's a constructor, we use our helper
+                            used_helpers.insert(dep.referenced_contract);
+                            updates.insert((
+                                new_loc.start,
+                                new_loc.end,
+                                format!(
+                                    "deployCode{id}(DeployHelper{id}.ConstructorArgs",
+                                    id = dep.referenced_contract
+                                ),
+                            ));
+                            updates.insert((dep.loc.end, dep.loc.end, ")".to_string()));
+                        }
+                    }
+                };
+            }
+            let helper_imports = used_helpers.into_iter().map(|id| {
+                format!(
+                    "import {{DeployHelper{id}, encodeArgs{id}, deployCode{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
+                )
+            }).join("\n");
+            updates.insert((
+                src.len(),
+                src.len(),
+                format!(
+                    r#"
+{helper_imports}
+
+interface {vm_interface_name} {{
+    function deployCode(string memory _artifact, bytes memory _data) external returns (address);
+    function deployCode(string memory _artifact) external returns (address);
+    function getCode(string memory _artifact) external returns (bytes memory);
+}}"#
+                ),
+            ));
+        }
+
+        Ok(())
+    }
+}
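
Note: the net effect of the two `updates.insert` calls in the `New` arm, shown as plain string surgery. The ranges are what `ItemLocation` would report; the contract name and AST id are illustrative:

```rust
fn main() {
    let src = "contract T { function t() public { new ERC20(a, b); } }";
    // new_loc covers `new ERC20`, dep.loc covers the whole call expression.
    let new_start = src.find("new ERC20").unwrap();
    let new_end = new_start + "new ERC20".len();
    let call_end = src.find(");").unwrap() + 1;
    let id = 335; // AST id of the referenced contract

    let mut out = String::new();
    out.push_str(&src[..new_start]);
    // First update: replace `new ERC20` with the helper call prefix.
    out.push_str(&format!("deployCode{id}(DeployHelper{id}.ConstructorArgs"));
    // The original argument list is left in place, wrapped as struct fields.
    out.push_str(&src[new_end..call_end]);
    // Second update: close the helper call at dep.loc.end.
    out.push(')');
    out.push_str(&src[call_end..]);

    assert_eq!(
        out,
        "contract T { function t() public { deployCode335(DeployHelper335.ConstructorArgs(a, b)); } }"
    );
}
```

Leaving the argument list untouched is what lets nested `new` expressions inside the arguments be rewritten independently, as the `build_helper` doc above explains.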
+
+#[derive(Debug)]
+pub struct TestOptimizerPreprocessor;
+
+impl Preprocessor<SolcCompiler> for TestOptimizerPreprocessor {
+    fn preprocess(
+        &self,
+        solc: &SolcCompiler,
+        mut input: SolcVersionedInput,
+        paths: &ProjectPathsConfig<SolcLanguage>,
+    ) -> Result<SolcVersionedInput> {
+        // Skip if we are not compiling any tests or scripts. Avoids unnecessary solc invocation
+        // and AST parsing.
+        if input.input.sources.iter().all(|(path, _)| !is_test_or_script(path, paths)) {
+            return Ok(input);
+        }
+
+        let prev_output_selection = std::mem::replace(
+            &mut input.input.settings.output_selection,
+            OutputSelection::ast_output_selection(),
+        );
+        let output = solc.compile(&input)?;
+
+        input.input.settings.output_selection = prev_output_selection;
+
+        if let Some(e) = output.errors.iter().find(|e| e.severity.is_error()) {
+            return Err(SolcError::msg(e));
+        }
+
+        let asts = output
+            .sources
+            .into_iter()
+            .filter_map(|(path, source)| {
+                if !input.input.sources.contains_key(&path) {
+                    return None;
+                }
+
+                Some((|| {
+                    let ast = source.ast.ok_or_else(|| SolcError::msg("missing AST"))?;
+                    let ast: SourceUnit = serde_json::from_str(&serde_json::to_string(&ast)?)?;
+                    Ok((path, ast))
+                })())
+            })
+            .collect::<Result<BTreeMap<_, _>>>()?;
+
+        BytecodeDependencyOptimizer::new(asts, paths, &mut input.input.sources).process()?;
+
+        Ok(input)
+    }
+}
+
+impl Preprocessor<MultiCompiler> for TestOptimizerPreprocessor {
+    fn preprocess(
+        &self,
+        compiler: &MultiCompiler,
+        input: <MultiCompiler as Compiler>::Input,
+        paths: &ProjectPathsConfig<MultiCompilerLanguage>,
+    ) -> Result<<MultiCompiler as Compiler>::Input> {
+        match input {
+            MultiCompilerInput::Solc(input) => {
+                if let Some(solc) = &compiler.solc {
+                    let paths = paths.clone().with_language::<SolcLanguage>();
+                    let input = self.preprocess(solc, input, &paths)?;
+                    Ok(MultiCompilerInput::Solc(input))
+                } else {
+                    Ok(MultiCompilerInput::Solc(input))
+                }
+            }
+            MultiCompilerInput::Vyper(input) => Ok(MultiCompilerInput::Vyper(input)),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_interface_representation() {
+        let content = r#"
+library Lib {
+    function libFn() internal {
+        // logic to keep
+    }
+}
+contract A {
+    function a() external {}
+    function b() public {}
+    function c() internal {
+        // logic logic logic
+    }
+    function d() private {}
+    function e() external {
+        // logic logic logic
+    }
+}"#;
+
+        let result = interface_representation(content).unwrap();
+        assert_eq!(
+            result,
+            r#"library Lib {function libFn() internal {// logic to keep}}contract A {function a() externalfunction b() publicfunction e() external }"#
+        );
+    }
+}
diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs
index aa9e5221..db78648c 100644
--- a/crates/core/src/error.rs
+++ b/crates/core/src/error.rs
@@ -5,7 +5,7 @@ use std::{
 };
 use thiserror::Error;
 
-pub type Result<T> = std::result::Result<T, SolcError>;
+pub type Result<T, E = SolcError> = std::result::Result<T, E>;
 
 #[allow(unused_macros)]
 #[macro_export]
@@ -70,6 +70,9 @@ pub enum SolcError {
     #[error("no artifact found for `{}:{}`", .0.display(), .1)]
     ArtifactNotFound(PathBuf, String),
 
+    #[error(transparent)]
+    Fmt(#[from] std::fmt::Error),
+
     #[cfg(feature = "project-util")]
     #[error(transparent)]
     FsExtra(#[from] fs_extra::error::Error),
diff --git a/crates/core/src/utils.rs b/crates/core/src/utils.rs
index 877d2d9f..17e80775 100644
--- a/crates/core/src/utils.rs
+++ b/crates/core/src/utils.rs
@@ -42,6 +42,9 @@ pub static RE_SOL_SDPX_LICENSE_IDENTIFIER: Lazy<Regex> =
 /// A regex used to remove extra lines in flattened files
 pub static RE_THREE_OR_MORE_NEWLINES: Lazy<Regex> = Lazy::new(|| Regex::new("\n{3,}").unwrap());
 
+/// A regex used to collapse runs of two or more spaces in interface representations
+pub static RE_TWO_OR_MORE_SPACES: Lazy<Regex> = Lazy::new(|| Regex::new(" {2,}").unwrap());
+
 /// A regex that matches version pragma in a Vyper
 pub static RE_VYPER_VERSION: Lazy<Regex> =
     Lazy::new(|| Regex::new(r"#(?:pragma version|@version)\s+(?P<version>.+)").unwrap());
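
Note: why the expected string in the test above reads `externalfunction`: newline removal happens first, then `RE_TWO_OR_MORE_SPACES` deletes runs of two or more spaces outright, so only single spaces survive. A sketch assuming the `regex` and `once_cell` dependencies that utils.rs already uses:

```rust
use once_cell::sync::Lazy;
use regex::Regex;

static RE_TWO_OR_MORE_SPACES: Lazy<Regex> = Lazy::new(|| Regex::new(" {2,}").unwrap());

fn main() {
    // Two spaces sneak in where a removed function body used to sit.
    let stripped = "contract A {function a() external  }".replace('\n', "");
    let normalized = RE_TWO_OR_MORE_SPACES.replace_all(&stripped, "");
    // The double space is deleted entirely; single spaces are untouched.
    assert_eq!(normalized, "contract A {function a() external}");
}
```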