diff --git a/spec-trait-impl/crates/spec-trait-bin/src/main.rs b/spec-trait-impl/crates/spec-trait-bin/src/main.rs index 0409e63..5294b91 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/main.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/main.rs @@ -280,6 +280,7 @@ fn main() { let s2 = S(1u8); // ZST - Foo + spec! { zst.foo(1u8); ZST; [MyType] } // -> "Foo impl ZST where T is MyType" spec! { zst.foo(1u8); ZST; [u8]; u8 = MyType } // -> "Foo impl ZST where T is MyType" spec! { zst.foo(vec![1i32]); ZST; [Vec]; Vec = MyVecAlias } // -> "Foo impl ZST where T is MyVecAlias" spec! { zst.foo(vec![1u8]); ZST; [Vec]; u8 = MyType } // -> "Foo impl ZST where T is Vec" diff --git a/spec-trait-impl/crates/spec-trait-macro/src/annotations.rs b/spec-trait-impl/crates/spec-trait-macro/src/annotations.rs index a41f51e..b1c6e55 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/annotations.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/annotations.rs @@ -7,8 +7,11 @@ use syn::{Error, Expr, Ident, Lit, Token, Type, bracketed, parenthesized, token} #[derive(Debug, PartialEq, Clone)] pub enum Annotation { + /// `: + + ...` Trait(String /* type */, Vec /* traits */), + /// ` = ` Alias(String /* type */, String /* alias */), + /// `: '` Lifetime(String /* type */, String /* lifetime */), } @@ -79,6 +82,10 @@ impl Parse for AnnotationBody { } } +/// Parse the method call of the form `variable.function(args...)` +/// # Example: +/// - `x.my_method(1u8, "abc")` -> `("x", "my_method", ["1u8", "abc"])` +/// - `var.foo()` -> `("var", "foo", [])` fn parse_call(input: ParseStream) -> Result<(String, String, Vec), Error> { let var = if input.peek(Ident) { to_string(&input.parse::()?) @@ -104,6 +111,11 @@ fn parse_call(input: ParseStream) -> Result<(String, String, Vec), Error Ok((var, fn_.to_string(), args.iter().map(to_string).collect())) } +/// Parse the variable type and argument types +/// # Example: +/// - `MyType; [u8, &str]` -> `("MyType", ["u8", "&str"])` +/// - `MyType; []` -> `("MyType", [])` +/// - `MyType` -> `("MyType", [])` fn parse_types(input: ParseStream) -> Result<(String, Vec), Error> { let var_type: Type = input.parse()?; @@ -131,6 +143,7 @@ fn parse_types(input: ParseStream) -> Result<(String, Vec), Error> { Ok((to_string(&var_type), args_types)) } +/// Parse the annotations fn parse_annotations(input: ParseStream) -> Result, Error> { input .parse_terminated(Annotations::parse, Token![;]) diff --git a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs index aab55a0..6067b94 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs @@ -34,6 +34,7 @@ impl TryFrom<(&Vec, &Vec, &AnnotationBody)> for SpecBody { .filter_map(|impl_| { let trait_ = traits.iter().find(|tr| tr.name == impl_.trait_name)?; + // ensure that the generics in the trait and impl match if get_generics_types::>(&trait_.generics).len() != get_generics_types::>(&impl_.trait_generics).len() || get_generics_lifetimes::>(&trait_.generics).len() @@ -99,7 +100,7 @@ fn get_constraints(default: SpecBody) -> Option { None => Some(default), // from when macro Some(cond) => { - let var = VarBody::from(&default); // TODO: handle conflicting vars + let var = VarBody::from(&default); // TODO: handle conflicting vars (early return None) let (satisfied, constraints) = satisfies_condition(cond, &var, &default.constraints); if satisfied { @@ -113,6 +114,8 @@ fn get_constraints(default: SpecBody) -> Option { } } +/// checks if the given condition is satisfied by the given var info and constraints +/// if satisfied, it returns the updated constraints fn satisfies_condition( condition: &WhenCondition, var: &VarBody, @@ -262,6 +265,8 @@ fn satisfies_condition( } } +/// fills the constraints for each generic parameter from trait and specialized trait, +/// based on the constraints from impl and type fn fill_trait_constraints(constraints: &Constraints, spec: &SpecBody) -> Constraints { let mut constraints = constraints.clone(); constraints.from_trait = get_trait_constraints(&spec.trait_, &spec.impl_, &constraints); @@ -278,6 +283,7 @@ fn fill_trait_constraints(constraints: &Constraints, spec: &SpecBody) -> Constra constraints } +/// gets the constraints for each generic parameter from trait based on the constraints from impl fn get_trait_constraints( trait_: &TraitBody, impl_: &ImplBody, @@ -324,6 +330,7 @@ impl From<&SpecBody> for TokenStream { } } +/// gets the types for the generics in the spec body pub fn get_types_for_generics(spec: &SpecBody) -> TokenStream { let trait_body = spec .trait_ @@ -344,6 +351,7 @@ pub fn get_types_for_generics(spec: &SpecBody) -> TokenStream { } } +/// gets the type for a generic parameter based on the constraints fn get_type(generic: &str, constraints: &Constraints) -> String { constraints .from_impl diff --git a/spec-trait-impl/crates/spec-trait-macro/src/vars.rs b/spec-trait-impl/crates/spec-trait-macro/src/vars.rs index 085ff6a..8860c34 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/vars.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/vars.rs @@ -48,6 +48,7 @@ impl From<&SpecBody> for VarBody { } } +/// Get type aliases from annotations pub fn get_type_aliases(ann: &[Annotation]) -> Aliases { let mut aliases = Aliases::new(); @@ -63,6 +64,7 @@ pub fn get_type_aliases(ann: &[Annotation]) -> Aliases { aliases } +/// Get variable information for each generic in the impl fn get_vars( ann: &AnnotationBody, impl_: &ImplBody, @@ -81,7 +83,7 @@ fn get_vars( ); match trait_.get_corresponding_generic(&str_to_generics(&impl_.trait_generics), g) { - // get type + // get type from trait (generic already existed before specialization) Some(trait_generic) => { let from_trait = get_generic_constraints_from_trait( &trait_generic, @@ -94,7 +96,7 @@ fn get_vars( from_trait.into_iter().chain(from_type).collect::>() } - // get from specialized instead + // get from specialized trait (generic was added during specialization) None => { let trait_generic = trait_ .specialized @@ -136,7 +138,7 @@ fn get_vars( /** Get the parameter types from a trait function. # Example - `fn foo(&self, x: T, y: u32);` returns `vec!["T", "u32"]` + - `fn foo(&self, x: T, y: u32);` -> `vec!["T", "u32"]` */ fn get_param_types(trait_fn: &TraitItemFn) -> Vec { trait_fn diff --git a/spec-trait-impl/crates/spec-trait-order/src/aliases.rs b/spec-trait-impl/crates/spec-trait-order/src/aliases.rs index b1f65d8..e9acc48 100644 --- a/spec-trait-impl/crates/spec-trait-order/src/aliases.rs +++ b/spec-trait-impl/crates/spec-trait-order/src/aliases.rs @@ -4,6 +4,7 @@ use syn::{Item, Path, UseTree}; const MACRO_PACKAGE: &str = "spec_trait_macro"; const MACRO_NAME: &str = "when"; +/// Collect all aliases for the `when` macro from the given items. pub fn collect_when_aliases(items: &[Item]) -> HashSet { let mut set = HashSet::new(); @@ -16,6 +17,7 @@ pub fn collect_when_aliases(items: &[Item]) -> HashSet { set } +/// Recursively traverse the `UseTree` to find aliases for the `when` macro. fn collect_aliases_from_tree(tree: &UseTree, prefix_spec: bool, set: &mut HashSet) { match tree { // `use spec_trait_macro::...` @@ -50,6 +52,8 @@ fn collect_aliases_from_tree(tree: &UseTree, prefix_spec: bool, set: &mut HashSe } } +/// Check if the given path corresponds to the `when` macro, +/// considering both direct usage and aliases collected earlier. pub fn is_when_macro(path: &Path, when_aliases: &HashSet) -> bool { // `when` imported directly or via alias when_aliases.contains(&path.segments.last().unwrap().ident.to_string()) || diff --git a/spec-trait-impl/crates/spec-trait-order/src/crates.rs b/spec-trait-impl/crates/spec-trait-order/src/crates.rs index 55983b0..3af6ffa 100644 --- a/spec-trait-impl/crates/spec-trait-order/src/crates.rs +++ b/spec-trait-impl/crates/spec-trait-order/src/crates.rs @@ -6,8 +6,11 @@ use std::path::{Path, PathBuf}; #[derive(Debug)] pub struct Crate { + /// The name of the crate pub name: String, + /// The parsed content of the crate pub content: CrateCache, + /// The list of source files paths in the crate #[cfg(test)] files: Vec, } @@ -27,6 +30,7 @@ pub fn get_crates(dir: &Path) -> Vec { .collect() } +/// Extract crate information from the `[package]` section of Cargo.toml fn get_crate_from_package(value: &toml::Value, dir: &Path) -> Option { let package = value.get("package")?; let name = package.get("name")?.as_str()?; @@ -43,6 +47,8 @@ fn get_crate_from_package(value: &toml::Value, dir: &Path) -> Option { }) } +/// Extract information from the `[workspace]` members section of Cargo.toml +/// and for each member, gather crate information fn get_crates_from_workspace_members(value: &toml::Value, dir: &Path) -> Vec { let members = value .get("workspace") diff --git a/spec-trait-impl/crates/spec-trait-utils/src/cache.rs b/spec-trait-impl/crates/spec-trait-utils/src/cache.rs index ac6ddaa..b7e36f2 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/cache.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/cache.rs @@ -41,6 +41,7 @@ pub fn write_cache(cache: &CrateCache, crate_name: Option) { write_top_level_cache(&top_level_cache); } +/// Initialize an empty cache pub fn reset() { let empty_cache = Cache::new(); write_top_level_cache(&empty_cache); @@ -65,11 +66,13 @@ pub fn add_impl(imp: ImplBody) { write_cache(&cache, None); } +/// Get the trait with the given name pub fn get_trait_by_name(trait_name: &str) -> Option { let cache = read_cache(None); cache.traits.into_iter().find(|tr| tr.name == trait_name) } +/// Get all traits that have a function with the given name and number of arguments pub fn get_traits_by_fn(fn_name: &str, args_len: usize) -> Vec { let cache = read_cache(None); cache @@ -79,6 +82,7 @@ pub fn get_traits_by_fn(fn_name: &str, args_len: usize) -> Vec { .collect() } +/// Get all impls for the given type that implement any of the given traits pub fn get_impls_by_type_and_traits( type_name: &str, traits: &[TraitBody], diff --git a/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs b/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs index 1a0808a..da65180 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs @@ -9,13 +9,18 @@ use syn::{Error, Ident, Token, parenthesized}; #[derive(Serialize, Deserialize, Debug, Clone, Eq)] pub enum WhenCondition { + /// ` = ` Type( String, /* generic */ String, /* type (without lifetime) */ ), + /// `: + + ...` Trait(String /* generic */, Vec /* traits */), + /// `all(, , ...)` All(Vec), + /// `any(, , ...)` Any(Vec), + /// `not()` Not(Box), } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs index 034f4c5..18e8f6f 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/impls.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/impls.rs @@ -229,6 +229,8 @@ impl ImplBody { self.impl_generics = to_string(&generics); } + // WhenCondition::Any should never appear here + // WhenCondition::Not can't be applied directly _ => {} } } @@ -239,6 +241,7 @@ impl ImplBody { for trait `TraitName` and impl `impl TraitName for MyType` - trait_generic = A -> trait_generic = T - trait_generic = B -> trait_generic = U + - trait_generic = C -> None */ pub fn get_corresponding_generic( &self, diff --git a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs index 99f0885..356685b 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/specialize.rs @@ -22,6 +22,9 @@ pub trait Specializable { fn handle_items_visit Visit<'a>>(&self, visitor: &mut V); } +/// Filters the given conditions to return only those that are assignable. +/// If there are multiple type conditions for the same generic, only the most specific one is kept. +/// If there are conflicting types that are not mutually assignable, they are not included. pub fn get_assignable_conditions( conditions: &[WhenCondition], generics: &str, @@ -38,6 +41,7 @@ pub fn get_assignable_conditions( && !type_assignable(other_t, t, generics, &Aliases::default()) }); + // if there are no conflicts, keep only the most specific type if diff_types || !most_specific { None } else { @@ -49,6 +53,8 @@ pub fn get_assignable_conditions( .collect() } +/// Returns the list of types assigned to the given generic in the conditions, +/// ordered from least specific to most specific. fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition]) -> Vec { let mut types = conditions .iter() @@ -57,6 +63,7 @@ fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition] _ => None, }) .collect::>(); + // `Vec<_>` is less specific than `Vec` types.sort_by_key(|t| t.replace("_", "").len()); types } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs index 45ee2a4..07e57b4 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs @@ -242,11 +242,12 @@ impl TraitBody { } /** - get the generic in the trait corresponding to the impl_generic in the impl - # Example: - for trait `TraitName` and impl `impl TraitName for MyType` - - impl_generic = T -> trait_generic = A - - impl_generic = U -> trait_generic = B + get the generic in the trait corresponding to the impl_generic in the impl + # Example: + for trait `TraitName` and impl `impl TraitName for MyType` + - impl_generic = T -> trait_generic = A + - impl_generic = U -> trait_generic = B + - impl_generic = C -> None */ pub fn get_corresponding_generic( &self, diff --git a/spec-trait-inst/src/analysis/sti_analysis/compare.rs b/spec-trait-inst/src/analysis/sti_analysis/compare.rs index 60cb2b2..2c4eec5 100644 --- a/spec-trait-inst/src/analysis/sti_analysis/compare.rs +++ b/spec-trait-inst/src/analysis/sti_analysis/compare.rs @@ -17,6 +17,7 @@ pub struct STIComparison<'tcx, 'a> { pub struct SpecializablePair<'a> { pub fn_a: &'a VisitedFn, pub fn_b: &'a VisitedFn, + /// The similarity score between the two functions pub similarity: f64, } @@ -24,6 +25,7 @@ pub struct SpecializablePair<'a> { pub struct SpecializableTraitPair<'a> { pub trait_a: &'a VisitedTrait, pub trait_b: &'a VisitedTrait, + /// The similarity score between the two traits pub similarity: f64, } @@ -49,8 +51,11 @@ impl Display for SpecializableTraitPair<'_> { #[derive(Debug)] struct TraitInfo<'a> { + /// The trait tr: &'a VisitedTrait, + /// Number of function items in the trait fn_items: usize, + /// Number of specializable function items in the trait specializable_fn_items: usize, } @@ -94,7 +99,7 @@ impl<'tcx, 'a> STIComparison<'tcx, 'a> { specializable } - /// Get pairs of specializable functions along with their similarity score. + /// Get pairs of specializable traits along with their similarity score. pub fn get_specializable_traits( &self, specializable_fn_pairs: &Vec<&SpecializablePair<'a>>, @@ -142,6 +147,7 @@ impl<'tcx, 'a> STIComparison<'tcx, 'a> { .and_then(|tr| trait_info.get(&tr)), ) { if specializable_trait_pair_heuristic(trait_a, trait_b) { + // similarity is the minimum similarity among all function pairs if let Some(similarity) = traits.get_mut(&key(trait_a, trait_b)) { *similarity = (*similarity).min(pair.similarity); } else { @@ -170,7 +176,8 @@ impl<'tcx, 'a> STIComparison<'tcx, 'a> { let max_size = a_size.max(b_size) as f64; let min_size = a_size.min(b_size) as f64; - // we would at least need to remove (max_size - min_size) nodes + // return early if the size difference is too big, + // since we would at least need to remove (max_size - min_size) nodes if min_size / max_size < THRESHOLD { return None; } @@ -190,6 +197,7 @@ impl<'tcx, 'a> STIComparison<'tcx, 'a> { .iter() .filter_map(|visited_fn| { let tree = BodyTree::new(self.analyzer, &visited_fn.body); + // only keep functions with non-empty bodies if tree.root.children.is_empty() { None } else { @@ -201,6 +209,8 @@ impl<'tcx, 'a> STIComparison<'tcx, 'a> { } /// A heuristic to quickly determine if two functions could be specializable. +/// This checks the number of arguments, the similarity of names, +/// the parent trait (if any), and the self type (if any). fn specializable_fn_heuristic(fn_a: &VisitedFn, fn_b: &VisitedFn) -> bool { fn_a.args.len() == fn_b.args.len() && names_similar(&fn_a.name, &fn_b.name) @@ -217,11 +227,14 @@ fn specializable_fn_heuristic(fn_a: &VisitedFn, fn_b: &VisitedFn) -> bool { } /// A heuristic to quickly determine if two traits could be specializable. +/// This checks the number of items and the similarity of names. fn specializable_trait_heuristic(tr_a: &VisitedTrait, tr_b: &VisitedTrait) -> bool { tr_a.items == tr_b.items && names_similar(&tr_a.name, &tr_b.name) } /// A heuristic to quickly determine if two traits could be specializable. +/// This builds upon `specializable_trait_heuristic` and also checks +/// the number of function items and specializable function items. fn specializable_trait_pair_heuristic<'a>(tr_a: &TraitInfo<'a>, tr_b: &TraitInfo<'a>) -> bool { let max_diff = |a: usize, b: usize| a.abs_diff(b) <= 1; specializable_trait_heuristic(tr_a.tr, tr_b.tr) @@ -234,6 +247,8 @@ fn specializable_trait_pair_heuristic<'a>(tr_a: &TraitInfo<'a>, tr_b: &TraitInfo } /// A heuristic to quickly determine if two functions could be already specializable. +/// This builds upon `specializable_fn_heuristic` and also checks that the self types +/// are the same and that there is at least one differing argument type that is not generic. pub fn already_specializable_heuristic(pair: &SpecializablePair) -> bool { specializable_fn_heuristic(pair.fn_a, pair.fn_b) && pair.fn_a.self_type == pair.fn_b.self_type @@ -249,6 +264,8 @@ pub fn already_specializable_heuristic(pair: &SpecializablePair) -> bool { } /// A heuristic to quickly determine if two functions are already specialized. +/// This builds upon `already_specializable_heuristic` and also checks that +/// the function IDs are the same and that they belong to the same trait (if any). pub fn already_specialized_heuristic(pair: &SpecializablePair) -> bool { already_specializable_heuristic(pair) && pair.fn_a.id == pair.fn_b.id @@ -260,6 +277,7 @@ pub fn already_specialized_heuristic(pair: &SpecializablePair) -> bool { }) } +/// Check if a type string contains any of the given generics. fn contains_generics(ty: &str, generics: &[String]) -> bool { generics.iter().filter(|gen| !gen.contains("$")).any(|gen| { if gen.starts_with("'") { diff --git a/spec-trait-inst/src/analysis/sti_analysis/distance.rs b/spec-trait-inst/src/analysis/sti_analysis/distance.rs index adfbba3..7e86e76 100644 --- a/spec-trait-inst/src/analysis/sti_analysis/distance.rs +++ b/spec-trait-inst/src/analysis/sti_analysis/distance.rs @@ -30,6 +30,7 @@ type TreeDistanceTable = HashMap<(usize, usize), usize>; type ForestDistanceTable = Vec>; /// Computes the similarity between two trees based on their edit distance. +/// The similarity is defined as 1 - (distance / max_size), where max_size is the size of the larger tree. pub fn compute_similarity(a: &Node, b: &Node) -> f64 { let dist = compute_tree_distance(a, b) as f64; let max_size = a.size().max(b.size()) as f64; @@ -96,7 +97,7 @@ fn post_order_and_leftmost<'a, T: Eq + Debug>(root: &'a Node) -> PostOrder<'a PostOrder { nodes, leftmost } } -/// Collect nodes in post order. +/// Collect nodes in post order (children first, then node). fn collect_post_order<'a, T: Eq + Debug>(n: &'a Node, out: &mut Vec<&'a Node>) { for c in &n.children { collect_post_order(c, out); @@ -142,6 +143,7 @@ fn compute_key_root_distance( let left_b_i = post_b.leftmost[post_b_i]; // leftmost of current index in B // both node_a and node_b are roots of their respective (sub)trees in this forest + // so we can compute the tree distance directly if left_a_i == kr_a_leftmost && left_b_i == kr_b_leftmost { let node_a = post_a.nodes[post_a_i]; // node at current index in A let node_b = post_b.nodes[post_b_i]; // node at current index in B @@ -170,6 +172,7 @@ fn compute_key_root_distance( } /// Initializes the forest distance table with base cases. +/// The first row and column represent the cost of converting to/from an empty tree. fn init_forest_distance(n: usize, m: usize) -> ForestDistanceTable { let mut fd = vec![vec![0; m + 1]; n + 1]; fd.iter_mut().enumerate().for_each(|(i, row)| row[0] = i); // deletion diff --git a/spec-trait-inst/src/analysis/sti_analysis/mod.rs b/spec-trait-inst/src/analysis/sti_analysis/mod.rs index 43f2a01..ebab924 100644 --- a/spec-trait-inst/src/analysis/sti_analysis/mod.rs +++ b/spec-trait-inst/src/analysis/sti_analysis/mod.rs @@ -7,6 +7,7 @@ use super::Analyzer; use compare::{STIComparison, SpecializablePair, SpecializableTraitPair}; use itertools::Itertools; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::{ItemId, ItemKind}; use std::{cell::Cell, time::Duration}; use sti_visitor::{STIVisitor, VisitedFn, VisitedTrait}; @@ -27,7 +28,7 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { } } - fn visit_item(&self, visitor: &mut STIVisitor<'tcx, 'a>, item_id: &rustc_hir::ItemId) { + fn visit_item(&self, visitor: &mut STIVisitor<'tcx, 'a>, item_id: &ItemId) { let hir_id = self .analyzer .tcx @@ -35,13 +36,14 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { let item = self.analyzer.tcx.hir_item(*item_id); visitor.visit_with_hir_id_and_item(hir_id, item); - if let rustc_hir::ItemKind::Mod(_, module) = &item.kind { + if let ItemKind::Mod(_, module) = &item.kind { for sub in module.item_ids { self.visit_item(visitor, sub); } } } + /// Analyze a group of specializable function pairs to extract trait pairs and unique functions and traits. fn analyze_group( &self, comparison: &STIComparison<'tcx, 'a>, @@ -72,6 +74,12 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { ) } + /// For each group (specializable, already specializable, already specialized, newly specializable, + /// already specializable but not specialized), returns: + /// - number of specializable function pairs + /// - number of specializable trait pairs + /// - number of specializable functions + /// - number of specializable traits fn get_groups( &self, comparison: &STIComparison<'tcx, 'a>, @@ -145,6 +153,8 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { ] } + /// Analyze the visited functions for specialization targets. + /// Prints the results if the corresponding CLI flags are set. fn analyze(&self, visitor: &STIVisitor<'tcx, 'a>) { let comparison = STIComparison::new(self.analyzer, visitor); let specializable_pairs = comparison.get_specializable_fns(); @@ -168,7 +178,7 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { } if self.analyzer.cli_args.print_spec_data { - // groups + // groups: total, bare fns, trait fns, trait impl fns, inherent impl fns, macro fns let is_bare_fn = |f: &VisitedFn| f.parent_trait.is_none() && f.self_type.is_none(); let specializable_pairs_from_bare_fns = specializable_pairs .iter() @@ -237,7 +247,7 @@ impl<'tcx, 'a> STIAnalysis<'tcx, 'a> { .filter(|f| is_macro_fn(f)) .count(); - // groups values + // groups values: total, specializable, already specializable, already specialized, newly specializable, already specializable but not specialized let base_group = self.get_groups(&comparison, &specializable_pairs.iter().collect()); let bare_fn_group = self.get_groups(&comparison, &specializable_pairs_from_bare_fns); let trait_group = self.get_groups(&comparison, &specializable_pairs_from_traits); diff --git a/spec-trait-inst/src/analysis/sti_analysis/sti_visitor.rs b/spec-trait-inst/src/analysis/sti_analysis/sti_visitor.rs index 235f499..5f5f048 100644 --- a/spec-trait-inst/src/analysis/sti_analysis/sti_visitor.rs +++ b/spec-trait-inst/src/analysis/sti_analysis/sti_visitor.rs @@ -22,13 +22,19 @@ pub struct VisitedTrait { #[derive(Debug)] pub struct VisitedFn { + /// Name of the function pub name: String, + /// DefId of the function pub id: DefId, + /// Argument types as Strings pub args: Vec, + /// BodyId of the function pub body: BodyId, /// all the generics that might be used in the fn (both fn generics and impl / trait generics) pub generics: Vec, + /// If the function is part of a trait impl or a trait, this contains the trait info pub parent_trait: Option, + /// If the function is part of a trait impl or an inherent impl, this contains the self type info pub self_type: Option, } @@ -168,7 +174,6 @@ impl<'tcx> NestedFilter<'tcx> for NestedFilterAll { const INTRA: bool = true; } -// NOTE(bruzzone): `visit_ty_unambig` and `visit_const_arg_unambig` are defined in VisitorExt, so we need to import it. impl<'tcx> Visitor<'tcx> for STIVisitor<'tcx, '_> { type NestedFilter = NestedFilterAll; @@ -185,6 +190,7 @@ impl<'tcx> Visitor<'tcx> for STIVisitor<'tcx, '_> { } fn visit_item(&mut self, item: &'tcx Item) { + // bare function if let ItemKind::Fn { ident, sig, @@ -202,6 +208,7 @@ impl<'tcx> Visitor<'tcx> for STIVisitor<'tcx, '_> { } fn visit_impl_item(&mut self, item: &'tcx ImplItem) { + // trait impl or inherent impl function if let ImplItemKind::Fn(sig, body) = item.kind { let impl_id = self .analyzer @@ -239,6 +246,7 @@ impl<'tcx> Visitor<'tcx> for STIVisitor<'tcx, '_> { } fn visit_trait_item(&mut self, item: &'tcx TraitItem) { + // trait function with default implementation if let TraitItemKind::Fn(sig, TraitFn::Provided(body)) = item.kind { log::trace!("Visiting trait fn: {}", item.ident); diff --git a/spec-trait-inst/src/analysis/sti_analysis/tree.rs b/spec-trait-inst/src/analysis/sti_analysis/tree.rs index 8eff9f2..97d4be7 100644 --- a/spec-trait-inst/src/analysis/sti_analysis/tree.rs +++ b/spec-trait-inst/src/analysis/sti_analysis/tree.rs @@ -116,6 +116,7 @@ pub struct BodyTree<'tcx, 'a> { } impl<'tcx, 'a> BodyTree<'tcx, 'a> { + /// Create a new BodyTree by visiting the body with the given BodyId pub fn new(analyzer: &'a Analyzer<'tcx>, id: &BodyId) -> Self { let mut tree = Self { analyzer, @@ -128,6 +129,7 @@ impl<'tcx, 'a> BodyTree<'tcx, 'a> { tree.visit_body_expr(id); + // if the root has a single Block child, promote its children to root if tree.root.children.len() == 1 && tree.root.children[0].label == NodeKind::Block { tree.root = tree.root.children.remove(0); }