diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 00625373..ef2a4436 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -28,7 +28,7 @@ pub struct TraitInput<'a> { /// Preprocessed `contract` macro input for non-trait impl block pub struct ImplInput<'a> { - attributes: &'a ContractArgs, + attributes: Option<&'a ContractArgs>, error: ContractErrorAttr, item: &'a ItemImpl, generics: Vec<&'a GenericParam>, @@ -144,7 +144,7 @@ impl<'a> TraitInput<'a> { } impl<'a> ImplInput<'a> { - pub fn new(attributes: &'a ContractArgs, item: &'a ItemImpl) -> Self { + pub fn new(item: &'a ItemImpl) -> Self { let generics = item.generics.params.iter().collect(); let parsed_attrs = ParsedSylviaAttributes::new(item.attrs.iter()); let error = parsed_attrs.error_attrs.unwrap_or_default(); @@ -153,7 +153,7 @@ impl<'a> ImplInput<'a> { let interfaces = Interfaces::new(item); Self { - attributes, + attributes: None, item, generics, error, @@ -163,6 +163,13 @@ impl<'a> ImplInput<'a> { } } + pub fn new_with_module(attributes: &'a ContractArgs, item: &'a ItemImpl) -> Self { + Self { + attributes: Some(attributes), + ..Self::new(item) + } + } + pub fn process(&self) -> TokenStream { match is_trait(self.item) { true => self.process_interface(), @@ -273,7 +280,7 @@ impl<'a> ImplInput<'a> { } fn emit_querier_for_bound_impl(&self) -> TokenStream { - let contract_module = self.attributes.module.as_ref(); + let contract_module = self.attributes.map(|contract_args| &contract_args.module); let variants_args = MsgVariants::::new(self.item.as_variants(), MsgType::Query, &[], &None); let associated_types = ImplAssociatedTypes::new(self.item); @@ -299,11 +306,17 @@ impl<'a> ImplInput<'a> { interfaces, .. } = self; - let contract_module = self.attributes.module.as_ref(); let generic_params = &self.generics; - if is_trait(item) { - ImplMtHelpers::new(item, generic_params, custom, interfaces, &contract_module).emit() + if let Some(contract_module) = self.attributes { + ImplMtHelpers::new( + item, + generic_params, + custom, + interfaces, + &contract_module.module, + ) + .emit() } else { ContractMtHelpers::new(item, generic_params, custom, override_entry_points.clone()) .emit() diff --git a/sylvia-derive/src/lib.rs b/sylvia-derive/src/lib.rs index 25a5a955..37588f1f 100644 --- a/sylvia-derive/src/lib.rs +++ b/sylvia-derive/src/lib.rs @@ -237,10 +237,13 @@ pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream { fn contract_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 { fn inner(attr: TokenStream2, item: TokenStream2) -> syn::Result { - let attrs: parser::ContractArgs = parse2(attr)?; let input: ItemImpl = parse2(item)?; - - let expanded = ImplInput::new(&attrs, &input).process(); + let expanded = if input.trait_.is_some() { + let attrs: parser::ContractArgs = parse2(attr)?; + ImplInput::new_with_module(&attrs, &input).process() + } else { + ImplInput::new(&input).process() + }; let input = StripInput.fold_item_impl(input); Ok(quote! { diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index 1b3bf67e..f05130c4 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -631,7 +631,7 @@ pub struct ImplMtHelpers<'a> { exec_variants: MsgVariants<'a, GenericParam>, query_variants: MsgVariants<'a, GenericParam>, where_clause: &'a Option, - contract_module: &'a Option<&'a Path>, + contract_module: &'a Path, contract_name: &'a Ident, } @@ -641,7 +641,7 @@ impl<'a> ImplMtHelpers<'a> { generic_params: &'a [&'a GenericParam], custom: &'a Custom, interfaces: &'a Interfaces, - contract_module: &'a Option<&'a Path>, + contract_module: &'a Path, ) -> Self { let where_clause = &source.generics.where_clause; let exec_variants = MsgVariants::new( @@ -766,10 +766,7 @@ impl<'a> ImplMtHelpers<'a> { &associated_items, ); - let contract_module = match contract_module { - Some(contract_module) => quote! { #contract_module :: }, - None => quote! {}, - }; + let contract_module = quote! { #contract_module :: }; let contract_proxy = Ident::new(&format!("{}Proxy", contract_name), contract_name.span()); let where_predicates = where_clause diff --git a/sylvia-derive/src/parser/contract.rs b/sylvia-derive/src/parser/contract.rs index 73155633..81481380 100644 --- a/sylvia-derive/src/parser/contract.rs +++ b/sylvia-derive/src/parser/contract.rs @@ -1,38 +1,35 @@ +use proc_macro_error::abort; use syn::{ - parse::{Error, Nothing, Parse, ParseStream}, - Ident, Path, Result, Token, + parse::{Nothing, Parse, ParseStream}, + Error, Ident, Path, Result, Token, }; /// Parsed arguments for `contract` macro pub struct ContractArgs { /// Module in which contract impl block is defined. /// Used only while implementing `Interface` on `Contract`. - pub module: Option, + pub module: Path, } impl Parse for ContractArgs { fn parse(input: ParseStream) -> Result { - let mut module = None; - - while !input.is_empty() { - let attr: Ident = input.parse()?; + let maybe_module = input.parse().and_then(|attr: Ident| -> Result { let _: Token![=] = input.parse()?; - if attr == "module" { - module = Some(input.parse()?); + input.parse() } else { - return Err(Error::new(attr.span(), "expected `module`")); - } - - if input.peek(Token![,]) { - let _: Token![,] = input.parse()?; - } else if !input.is_empty() { - return Err(input.error("Unexpected token, comma expected")); + Err(Error::new(attr.span(), "Missing `module` attribute")) } - } - + }); + let module: Path = match maybe_module { + Ok(module) => module, + Err(e) => abort!( + e.span(), "The module path needs to be provided `#[contract(module=path::to::contract)`."; + note = "Implementing interface on a contract requires to point the path to the contract structure."; + note = "Parsing error: {}", e + ), + }; let _: Nothing = input.parse()?; - Ok(ContractArgs { module }) } }