Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec-trait-impl/crates/spec-trait-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ fn main() {
let s2 = S(1u8);

// ZST - Foo
spec! { zst.foo(1u8); ZST; [MyType] } // -> "Foo impl ZST where T is MyType"
spec! { zst.foo(1u8); ZST; [u8]; u8 = MyType } // -> "Foo impl ZST where T is MyType"
spec! { zst.foo(vec![1i32]); ZST; [Vec<i32>]; Vec<i32> = MyVecAlias } // -> "Foo impl ZST where T is MyVecAlias"
spec! { zst.foo(vec![1u8]); ZST; [Vec<u8>]; u8 = MyType } // -> "Foo impl ZST where T is Vec<u8>"
Expand Down
13 changes: 13 additions & 0 deletions spec-trait-impl/crates/spec-trait-macro/src/annotations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ use syn::{Error, Expr, Ident, Lit, Token, Type, bracketed, parenthesized, token}

#[derive(Debug, PartialEq, Clone)]
pub enum Annotation {
/// `<type>: <trait1> + <trait2> + ...`
Trait(String /* type */, Vec<String> /* traits */),
/// `<type> = <alias>`
Alias(String /* type */, String /* alias */),
/// `<type>: '<lifetime>`
Lifetime(String /* type */, String /* lifetime */),
}

Expand Down Expand Up @@ -79,6 +82,10 @@ impl Parse for AnnotationBody {
}
}

/// Parse the method call of the form `variable.function(args...)`
/// # Example:
/// - `x.my_method(1u8, "abc")` -> `("x", "my_method", ["1u8", "abc"])`
/// - `var.foo()` -> `("var", "foo", [])`
fn parse_call(input: ParseStream) -> Result<(String, String, Vec<String>), Error> {
let var = if input.peek(Ident) {
to_string(&input.parse::<Ident>()?)
Expand All @@ -104,6 +111,11 @@ fn parse_call(input: ParseStream) -> Result<(String, String, Vec<String>), Error
Ok((var, fn_.to_string(), args.iter().map(to_string).collect()))
}

/// Parse the variable type and argument types
/// # Example:
/// - `MyType; [u8, &str]` -> `("MyType", ["u8", "&str"])`
/// - `MyType; []` -> `("MyType", [])`
/// - `MyType` -> `("MyType", [])`
fn parse_types(input: ParseStream) -> Result<(String, Vec<String>), Error> {
let var_type: Type = input.parse()?;

Expand Down Expand Up @@ -131,6 +143,7 @@ fn parse_types(input: ParseStream) -> Result<(String, Vec<String>), Error> {
Ok((to_string(&var_type), args_types))
}

/// Parse the annotations
fn parse_annotations(input: ParseStream) -> Result<Vec<Annotation>, Error> {
input
.parse_terminated(Annotations::parse, Token![;])
Expand Down
10 changes: 9 additions & 1 deletion spec-trait-impl/crates/spec-trait-macro/src/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl TryFrom<(&Vec<ImplBody>, &Vec<TraitBody>, &AnnotationBody)> for SpecBody {
.filter_map(|impl_| {
let trait_ = traits.iter().find(|tr| tr.name == impl_.trait_name)?;

// ensure that the generics in the trait and impl match
if get_generics_types::<Vec<_>>(&trait_.generics).len()
!= get_generics_types::<Vec<_>>(&impl_.trait_generics).len()
|| get_generics_lifetimes::<Vec<_>>(&trait_.generics).len()
Expand Down Expand Up @@ -99,7 +100,7 @@ fn get_constraints(default: SpecBody) -> Option<SpecBody> {
None => Some(default),
// from when macro
Some(cond) => {
let var = VarBody::from(&default); // TODO: handle conflicting vars
let var = VarBody::from(&default); // TODO: handle conflicting vars (early return None)
let (satisfied, constraints) = satisfies_condition(cond, &var, &default.constraints);

if satisfied {
Expand All @@ -113,6 +114,8 @@ fn get_constraints(default: SpecBody) -> Option<SpecBody> {
}
}

/// checks if the given condition is satisfied by the given var info and constraints
/// if satisfied, it returns the updated constraints
fn satisfies_condition(
condition: &WhenCondition,
var: &VarBody,
Expand Down Expand Up @@ -262,6 +265,8 @@ fn satisfies_condition(
}
}

/// fills the constraints for each generic parameter from trait and specialized trait,
/// based on the constraints from impl and type
fn fill_trait_constraints(constraints: &Constraints, spec: &SpecBody) -> Constraints {
let mut constraints = constraints.clone();
constraints.from_trait = get_trait_constraints(&spec.trait_, &spec.impl_, &constraints);
Expand All @@ -278,6 +283,7 @@ fn fill_trait_constraints(constraints: &Constraints, spec: &SpecBody) -> Constra
constraints
}

/// gets the constraints for each generic parameter from trait based on the constraints from impl
fn get_trait_constraints(
trait_: &TraitBody,
impl_: &ImplBody,
Expand Down Expand Up @@ -324,6 +330,7 @@ impl From<&SpecBody> for TokenStream {
}
}

/// gets the types for the generics in the spec body
pub fn get_types_for_generics(spec: &SpecBody) -> TokenStream {
let trait_body = spec
.trait_
Expand All @@ -344,6 +351,7 @@ pub fn get_types_for_generics(spec: &SpecBody) -> TokenStream {
}
}

/// gets the type for a generic parameter based on the constraints
fn get_type(generic: &str, constraints: &Constraints) -> String {
constraints
.from_impl
Expand Down
8 changes: 5 additions & 3 deletions spec-trait-impl/crates/spec-trait-macro/src/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl From<&SpecBody> for VarBody {
}
}

/// Get type aliases from annotations
pub fn get_type_aliases(ann: &[Annotation]) -> Aliases {
let mut aliases = Aliases::new();

Expand All @@ -63,6 +64,7 @@ pub fn get_type_aliases(ann: &[Annotation]) -> Aliases {
aliases
}

/// Get variable information for each generic in the impl
fn get_vars(
ann: &AnnotationBody,
impl_: &ImplBody,
Expand All @@ -81,7 +83,7 @@ fn get_vars(
);

match trait_.get_corresponding_generic(&str_to_generics(&impl_.trait_generics), g) {
// get type
// get type from trait (generic already existed before specialization)
Some(trait_generic) => {
let from_trait = get_generic_constraints_from_trait(
&trait_generic,
Expand All @@ -94,7 +96,7 @@ fn get_vars(
from_trait.into_iter().chain(from_type).collect::<Vec<_>>()
}

// get from specialized instead
// get from specialized trait (generic was added during specialization)
None => {
let trait_generic = trait_
.specialized
Expand Down Expand Up @@ -136,7 +138,7 @@ fn get_vars(
/**
Get the parameter types from a trait function.
# Example
`fn foo(&self, x: T, y: u32);` returns `vec!["T", "u32"]`
- `fn foo(&self, x: T, y: u32);` -> `vec!["T", "u32"]`
*/
fn get_param_types(trait_fn: &TraitItemFn) -> Vec<String> {
trait_fn
Expand Down
4 changes: 4 additions & 0 deletions spec-trait-impl/crates/spec-trait-order/src/aliases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use syn::{Item, Path, UseTree};
const MACRO_PACKAGE: &str = "spec_trait_macro";
const MACRO_NAME: &str = "when";

/// Collect all aliases for the `when` macro from the given items.
pub fn collect_when_aliases(items: &[Item]) -> HashSet<String> {
let mut set = HashSet::new();

Expand All @@ -16,6 +17,7 @@ pub fn collect_when_aliases(items: &[Item]) -> HashSet<String> {
set
}

/// Recursively traverse the `UseTree` to find aliases for the `when` macro.
fn collect_aliases_from_tree(tree: &UseTree, prefix_spec: bool, set: &mut HashSet<String>) {
match tree {
// `use spec_trait_macro::...`
Expand Down Expand Up @@ -50,6 +52,8 @@ fn collect_aliases_from_tree(tree: &UseTree, prefix_spec: bool, set: &mut HashSe
}
}

/// Check if the given path corresponds to the `when` macro,
/// considering both direct usage and aliases collected earlier.
pub fn is_when_macro(path: &Path, when_aliases: &HashSet<String>) -> bool {
// `when` imported directly or via alias
when_aliases.contains(&path.segments.last().unwrap().ident.to_string()) ||
Expand Down
6 changes: 6 additions & 0 deletions spec-trait-impl/crates/spec-trait-order/src/crates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ use std::path::{Path, PathBuf};

#[derive(Debug)]
pub struct Crate {
/// The name of the crate
pub name: String,
/// The parsed content of the crate
pub content: CrateCache,
/// The list of source files paths in the crate
#[cfg(test)]
files: Vec<PathBuf>,
}
Expand All @@ -27,6 +30,7 @@ pub fn get_crates(dir: &Path) -> Vec<Crate> {
.collect()
}

/// Extract crate information from the `[package]` section of Cargo.toml
fn get_crate_from_package(value: &toml::Value, dir: &Path) -> Option<Crate> {
let package = value.get("package")?;
let name = package.get("name")?.as_str()?;
Expand All @@ -43,6 +47,8 @@ fn get_crate_from_package(value: &toml::Value, dir: &Path) -> Option<Crate> {
})
}

/// Extract information from the `[workspace]` members section of Cargo.toml
/// and for each member, gather crate information
fn get_crates_from_workspace_members(value: &toml::Value, dir: &Path) -> Vec<Crate> {
let members = value
.get("workspace")
Expand Down
4 changes: 4 additions & 0 deletions spec-trait-impl/crates/spec-trait-utils/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub fn write_cache(cache: &CrateCache, crate_name: Option<String>) {
write_top_level_cache(&top_level_cache);
}

/// Initialize an empty cache
pub fn reset() {
let empty_cache = Cache::new();
write_top_level_cache(&empty_cache);
Expand All @@ -65,11 +66,13 @@ pub fn add_impl(imp: ImplBody) {
write_cache(&cache, None);
}

/// Get the trait with the given name
pub fn get_trait_by_name(trait_name: &str) -> Option<TraitBody> {
let cache = read_cache(None);
cache.traits.into_iter().find(|tr| tr.name == trait_name)
}

/// Get all traits that have a function with the given name and number of arguments
pub fn get_traits_by_fn(fn_name: &str, args_len: usize) -> Vec<TraitBody> {
let cache = read_cache(None);
cache
Expand All @@ -79,6 +82,7 @@ pub fn get_traits_by_fn(fn_name: &str, args_len: usize) -> Vec<TraitBody> {
.collect()
}

/// Get all impls for the given type that implement any of the given traits
pub fn get_impls_by_type_and_traits(
type_name: &str,
traits: &[TraitBody],
Expand Down
5 changes: 5 additions & 0 deletions spec-trait-impl/crates/spec-trait-utils/src/conditions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ use syn::{Error, Ident, Token, parenthesized};

#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
pub enum WhenCondition {
/// `<generic> = <type>`
Type(
String, /* generic */
String, /* type (without lifetime) */
),
/// `<generic>: <trait1> + <trait2> + ...`
Trait(String /* generic */, Vec<String> /* traits */),
/// `all(<cond1>, <cond2>, ...)`
All(Vec<WhenCondition>),
/// `any(<cond1>, <cond2>, ...)`
Any(Vec<WhenCondition>),
/// `not(<cond>)`
Not(Box<WhenCondition>),
}

Expand Down
3 changes: 3 additions & 0 deletions spec-trait-impl/crates/spec-trait-utils/src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ impl ImplBody {
self.impl_generics = to_string(&generics);
}

// WhenCondition::Any should never appear here
// WhenCondition::Not can't be applied directly
_ => {}
}
}
Expand All @@ -239,6 +241,7 @@ impl ImplBody {
for trait `TraitName<A, B>` and impl `impl<T, U> TraitName<T, U> for MyType`
- trait_generic = A -> trait_generic = T
- trait_generic = B -> trait_generic = U
- trait_generic = C -> None
*/
pub fn get_corresponding_generic(
&self,
Expand Down
7 changes: 7 additions & 0 deletions spec-trait-impl/crates/spec-trait-utils/src/specialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub trait Specializable {
fn handle_items_visit<V: for<'a> Visit<'a>>(&self, visitor: &mut V);
}

/// Filters the given conditions to return only those that are assignable.
/// If there are multiple type conditions for the same generic, only the most specific one is kept.
/// If there are conflicting types that are not mutually assignable, they are not included.
pub fn get_assignable_conditions(
conditions: &[WhenCondition],
generics: &str,
Expand All @@ -38,6 +41,7 @@ pub fn get_assignable_conditions(
&& !type_assignable(other_t, t, generics, &Aliases::default())
});

// if there are no conflicts, keep only the most specific type
if diff_types || !most_specific {
None
} else {
Expand All @@ -49,6 +53,8 @@ pub fn get_assignable_conditions(
.collect()
}

/// Returns the list of types assigned to the given generic in the conditions,
/// ordered from least specific to most specific.
fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition]) -> Vec<String> {
let mut types = conditions
.iter()
Expand All @@ -57,6 +63,7 @@ fn get_generic_types_from_conditions(generic: &str, conditions: &[WhenCondition]
_ => None,
})
.collect::<Vec<_>>();
// `Vec<_>` is less specific than `Vec<String>`
types.sort_by_key(|t| t.replace("_", "").len());
types
}
Expand Down
11 changes: 6 additions & 5 deletions spec-trait-impl/crates/spec-trait-utils/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ impl TraitBody {
}

/**
get the generic in the trait corresponding to the impl_generic in the impl
# Example:
for trait `TraitName<A, B>` and impl `impl<T, U> TraitName<T, U> for MyType`
- impl_generic = T -> trait_generic = A
- impl_generic = U -> trait_generic = B
get the generic in the trait corresponding to the impl_generic in the impl
# Example:
for trait `TraitName<A, B>` and impl `impl<T, U> TraitName<T, U> for MyType`
- impl_generic = T -> trait_generic = A
- impl_generic = U -> trait_generic = B
- impl_generic = C -> None
*/
pub fn get_corresponding_generic(
&self,
Expand Down
Loading