diff --git a/shank-idl/src/idl_instruction.rs b/shank-idl/src/idl_instruction.rs index c2e692e..ea833e2 100644 --- a/shank-idl/src/idl_instruction.rs +++ b/shank-idl/src/idl_instruction.rs @@ -5,6 +5,7 @@ use heck::MixedCase; use serde::{Deserialize, Serialize}; use shank_macro_impl::instruction::{ Instruction, InstructionAccount, InstructionVariant, + InstructionVariantFields, }; use crate::{idl_field::IdlField, idl_type::IdlType}; @@ -49,27 +50,48 @@ impl TryFrom for IdlInstruction { fn try_from(variant: InstructionVariant) -> Result { let InstructionVariant { ident, - field_ty, + field_tys, accounts, discriminant, } = variant; let name = ident.to_string(); - let args: Vec = if let Some(field_ty) = field_ty { - let name = if field_ty.kind.is_custom() { - field_ty.ident.to_string().to_mixed_case() - } else { - "instructionArgs".to_string() - }; - let ty = IdlType::try_from(field_ty)?; - vec![IdlField { - name, - ty, - attrs: None, - }] - } else { - vec![] + let parsed_idl_fields: Result, Error> = match field_tys { + InstructionVariantFields::Named(args) => { + let mut parsed: Vec = vec![]; + for (field_name, field_ty) in args.iter() { + let ty = IdlType::try_from(field_ty.clone())?; + parsed.push(IdlField { + name: field_name.to_mixed_case(), + ty, + attrs: None, + }) + } + Ok(parsed) + } + InstructionVariantFields::Unnamed(args) => { + let mut parsed: Vec = vec![]; + for (index, field_ty) in args.iter().enumerate() { + let name = if args.len() == 1 { + if field_ty.kind.is_custom() { + field_ty.ident.to_string().to_mixed_case() + } else { + "args".to_string() + } + } else { + format!("arg{}", index).to_string() + }; + let ty = IdlType::try_from(field_ty.clone())?; + parsed.push(IdlField { + name, + ty, + attrs: None, + }) + } + Ok(parsed) + } }; + let args: Vec = parsed_idl_fields?; let accounts = accounts.into_iter().map(IdlAccountItem::from).collect(); ensure!( diff --git a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_args.json b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_args.json index 307a6b7..32a3093 100644 --- a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_args.json +++ b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_args.json @@ -37,7 +37,7 @@ ], "args": [ { - "name": "instructionArgs", + "name": "args", "type": { "option": "u8" } diff --git a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.json b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.json new file mode 100644 index 0000000..530927d --- /dev/null +++ b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.json @@ -0,0 +1,43 @@ +{ + "version": "", + "name": "", + "instructions": [ + { + "name": "CloseThing", + "accounts": [ + { + "name": "creator", + "isMut": false, + "isSigner": true + } + ], + "args": [ + { + "name": "arg0", + "type": { + "option": "u8" + } + }, + { + "name": "arg1", + "type": { + "defined": "ComplexArgs" + } + }, + { + "name": "arg2", + "type": { + "defined": "ComplexArgs" + } + } + ], + "discriminant": { + "type": "u8", + "value": 0 + } + } + ], + "metadata": { + "origin": "shank" + } +} diff --git a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.rs b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.rs new file mode 100644 index 0000000..0a2da3b --- /dev/null +++ b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_multiple_args.rs @@ -0,0 +1,5 @@ +#[derive(ShankInstruction)] +pub enum Instruction { + #[account(0, name = "creator", sig)] + CloseThing(Option, ComplexArgs, ComplexArgs), +} diff --git a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.json b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.json new file mode 100644 index 0000000..bfb731d --- /dev/null +++ b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.json @@ -0,0 +1,64 @@ +{ + "version": "", + "name": "", + "instructions": [ + { + "name": "CreateThing", + "accounts": [ + { + "name": "creator", + "isMut": false, + "isSigner": true + }, + { + "name": "thing", + "isMut": true, + "isSigner": false + } + ], + "args": [ + { + "name": "someArgs", + "type": { + "defined": "SomeArgs" + } + }, + { + "name": "otherArgs", + "type": { + "defined": "OtherArgs" + } + } + ], + "discriminant": { + "type": "u8", + "value": 0 + } + }, + { + "name": "CloseThing", + "accounts": [ + { + "name": "creator", + "isMut": false, + "isSigner": true + } + ], + "args": [ + { + "name": "args", + "type": { + "option": "u8" + } + } + ], + "discriminant": { + "type": "u8", + "value": 1 + } + } + ], + "metadata": { + "origin": "shank" + } +} diff --git a/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.rs b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.rs new file mode 100644 index 0000000..2679c90 --- /dev/null +++ b/shank-idl/tests/fixtures/instructions/single_file/instruction_with_struct_args.rs @@ -0,0 +1,11 @@ +#[derive(ShankInstruction)] +pub enum Instruction { + #[account(0, name = "creator", sig)] + #[account(1, name = "thing", mut)] + CreateThing { + some_args: SomeArgs, + other_args: OtherArgs, + }, + #[account(0, name = "creator", sig)] + CloseThing(Option), +} diff --git a/shank-idl/tests/instructions.rs b/shank-idl/tests/instructions.rs index e62f8a0..071f390 100644 --- a/shank-idl/tests/instructions.rs +++ b/shank-idl/tests/instructions.rs @@ -43,6 +43,40 @@ fn instruction_from_single_file_with_args() { assert_eq!(idl, expected_idl); } +#[test] +fn instruction_from_single_file_with_struct_args() { + let file = fixtures_dir() + .join("single_file") + .join("instruction_with_struct_args.rs"); + let idl = parse_file(&file, &ParseIdlConfig::optional_program_address()) + .expect("Parsing should not fail") + .expect("File contains IDL"); + + let expected_idl: Idl = serde_json::from_str(include_str!( + "./fixtures/instructions/single_file/instruction_with_struct_args.json" + )) + .unwrap(); + + assert_eq!(idl, expected_idl); +} + +#[test] +fn instruction_from_single_file_with_multiple_args() { + let file = fixtures_dir() + .join("single_file") + .join("instruction_with_multiple_args.rs"); + let idl = parse_file(&file, &ParseIdlConfig::optional_program_address()) + .expect("Parsing should not fail") + .expect("File contains IDL"); + + let expected_idl: Idl = serde_json::from_str(include_str!( + "./fixtures/instructions/single_file/instruction_with_multiple_args.json" + )) + .unwrap(); + + assert_eq!(idl, expected_idl); +} + #[test] fn instruction_from_single_file_with_optional_account() { let file = fixtures_dir() diff --git a/shank-macro-impl/src/instruction/instruction.rs b/shank-macro-impl/src/instruction/instruction.rs index fd5a30a..5010de7 100644 --- a/shank-macro-impl/src/instruction/instruction.rs +++ b/shank-macro-impl/src/instruction/instruction.rs @@ -70,13 +70,19 @@ impl TryFrom<&ParsedEnum> for Instruction { } } +#[derive(Debug)] +pub enum InstructionVariantFields { + Unnamed(Vec), + Named(Vec<(String, RustType)>), +} + // ----------------- // Instruction Variant // ----------------- #[derive(Debug)] pub struct InstructionVariant { pub ident: Ident, - pub field_ty: Option, + pub field_tys: InstructionVariantFields, pub accounts: Vec, pub discriminant: usize, } @@ -93,19 +99,35 @@ impl TryFrom<&ParsedEnumVariant> for InstructionVariant { .. } = variant; - if fields.len() > 1 { - return Err(ParseError::new_spanned( - fields.get(1).map(|x| &x.rust_type.ident), - "An Instruction can only have one arg field", - )); - } - let field_ty = fields.first().map(|x| x.rust_type.clone()); + let field_tys: InstructionVariantFields = if fields.len() > 0 { + // Determine if the InstructionType is tuple or struct variant + let field = fields.get(0).unwrap(); + match &field.ident { + Some(_) => InstructionVariantFields::Named( + fields + .iter() + .map(|x| { + ( + x.ident.as_ref().unwrap().to_string(), + x.rust_type.clone(), + ) + }) + .collect(), + ), + None => InstructionVariantFields::Unnamed( + fields.iter().map(|x| x.rust_type.clone()).collect(), + ), + } + } else { + InstructionVariantFields::Unnamed(vec![]) + }; + let attrs: &[Attribute] = attrs.as_ref(); let accounts: InstructionAccounts = attrs.try_into()?; Ok(Self { ident: ident.clone(), - field_ty, + field_tys, accounts: accounts.0, discriminant: *discriminant, }) diff --git a/shank-macro-impl/src/instruction/instruction_test.rs b/shank-macro-impl/src/instruction/instruction_test.rs index 219843d..7f17c76 100644 --- a/shank-macro-impl/src/instruction/instruction_test.rs +++ b/shank-macro-impl/src/instruction/instruction_test.rs @@ -2,7 +2,10 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{ItemEnum, Result as ParseResult}; -use crate::types::{Primitive, RustType}; +use crate::{ + instruction::InstructionVariantFields, + types::{Primitive, RustType}, +}; use super::instruction::{Instruction, InstructionVariant}; @@ -16,12 +19,12 @@ fn assert_instruction_variant( variant: &InstructionVariant, name: &str, expected_discriminant: usize, - expected_field_ty: Option, + expected_field_tys: &Vec, accounts_len: usize, ) { let InstructionVariant { ident, - field_ty, + field_tys, accounts, discriminant, } = variant; @@ -29,7 +32,34 @@ fn assert_instruction_variant( assert_eq!(ident.to_string(), name); assert_eq!(discriminant, &expected_discriminant, "discriminant"); assert_eq!(accounts.len(), accounts_len, "accounts"); - assert_eq!(field_ty, &expected_field_ty, "field type"); + match field_tys { + InstructionVariantFields::Named(field_tys) => { + assert_eq!( + field_tys.len(), + expected_field_tys.len(), + "fields size" + ); + for field_idx in 0..expected_field_tys.len() { + let (_field_name, field_ty) = field_tys.get(field_idx).unwrap(); + let expected_field_ty = + expected_field_tys.get(field_idx).unwrap(); + assert_eq!(field_ty, expected_field_ty, "field type"); + } + } + InstructionVariantFields::Unnamed(field_tys) => { + assert_eq!( + field_tys.len(), + expected_field_tys.len(), + "fields size" + ); + for field_idx in 0..expected_field_tys.len() { + let field_ty = field_tys.get(field_idx).unwrap(); + let expected_field_ty = + expected_field_tys.get(field_idx).unwrap(); + assert_eq!(field_ty, expected_field_ty, "field type"); + } + } + } } #[test] @@ -62,8 +92,20 @@ fn parse_c_style_instruction() { "non-optional account of second variant" ); - assert_instruction_variant(&parsed.variants[0], "CreateThing", 0, None, 2); - assert_instruction_variant(&parsed.variants[1], "CloseThing", 1, None, 1); + assert_instruction_variant( + &parsed.variants[0], + "CreateThing", + 0, + &vec![], + 2, + ); + assert_instruction_variant( + &parsed.variants[1], + "CloseThing", + 1, + &vec![], + 1, + ); } #[test] @@ -82,12 +124,18 @@ fn parse_custom_field_variant_instruction() { assert_eq!(parsed.ident.to_string(), "Instruction", "enum ident"); assert_eq!(parsed.variants.len(), 2, "variants"); - assert_instruction_variant(&parsed.variants[0], "CreateThing", 0, None, 0); + assert_instruction_variant( + &parsed.variants[0], + "CreateThing", + 0, + &vec![], + 0, + ); assert_instruction_variant( &parsed.variants[1], "CloseThing", 1, - Some(RustType::owned_custom_value("CloseArgs", "CloseArgs")), + &vec![RustType::owned_custom_value("CloseArgs", "CloseArgs")], 1, ); } @@ -109,12 +157,18 @@ fn parse_u8_field_variant_instruction() { assert_eq!(parsed.ident.to_string(), "Instruction", "enum ident"); assert_eq!(parsed.variants.len(), 2, "variants"); - assert_instruction_variant(&parsed.variants[0], "CreateThing", 0, None, 1); + assert_instruction_variant( + &parsed.variants[0], + "CreateThing", + 0, + &vec![], + 1, + ); assert_instruction_variant( &parsed.variants[1], "CloseThing", 1, - Some(RustType::owned_primitive("u8", Primitive::U8)), + &vec![RustType::owned_primitive("u8", Primitive::U8)], 1, ); } diff --git a/shank-macro-impl/src/types/resolve_rust_ty.rs b/shank-macro-impl/src/types/resolve_rust_ty.rs index cc1f405..e5135d0 100644 --- a/shank-macro-impl/src/types/resolve_rust_ty.rs +++ b/shank-macro-impl/src/types/resolve_rust_ty.rs @@ -2,8 +2,9 @@ use std::{convert::TryFrom, ops::Deref}; use quote::format_ident; use syn::{ - spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprLit, GenericArgument, Ident, Lit, - Path, PathArguments, PathSegment, Type, TypeArray, TypePath, TypeTuple, + spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprLit, + GenericArgument, Ident, Lit, Path, PathArguments, PathSegment, Type, + TypeArray, TypePath, TypeTuple, }; use super::{Composite, ParsedReference, Primitive, TypeKind, Value}; @@ -50,16 +51,28 @@ impl RustType { context: RustTypeContext::Default, } } - pub fn owned_primitive>(ident: T, primitive: Primitive) -> Self { + pub fn owned_primitive>( + ident: T, + primitive: Primitive, + ) -> Self { RustType::owned(ident, TypeKind::Primitive(primitive)) } pub fn owned_string>(ident: T) -> Self { RustType::owned(ident, TypeKind::Value(Value::String)) } - pub fn owned_custom_value>(ident: T, value: &str) -> Self { - RustType::owned(ident, TypeKind::Value(Value::Custom(value.to_string()))) + pub fn owned_custom_value>( + ident: T, + value: &str, + ) -> Self { + RustType::owned( + ident, + TypeKind::Value(Value::Custom(value.to_string())), + ) } - pub fn owned_vec_primitive>(ident: T, primitive: Primitive) -> Self { + pub fn owned_vec_primitive>( + ident: T, + primitive: Primitive, + ) -> Self { RustType::owned( ident, TypeKind::Composite( @@ -83,7 +96,10 @@ impl RustType { ) } - pub fn owned_option_primitive>(ident: T, primitive: Primitive) -> Self { + pub fn owned_option_primitive>( + ident: T, + primitive: Primitive, + ) -> Self { RustType::owned( ident, TypeKind::Composite( @@ -133,13 +149,18 @@ fn len_from_expr(expr: &Expr) -> ParseResult { } } -pub fn resolve_rust_ty(ty: &Type, context: RustTypeContext) -> ParseResult { +pub fn resolve_rust_ty( + ty: &Type, + context: RustTypeContext, +) -> ParseResult { let (ty, reference) = match ty { Type::Reference(r) => { let pr = ParsedReference::from(r); (r.elem.as_ref(), pr) } - Type::Array(_) | Type::Path(_) | Type::Tuple(_) => (ty, ParsedReference::Owned), + Type::Array(_) | Type::Path(_) | Type::Tuple(_) => { + (ty, ParsedReference::Owned) + } ty => { eprintln!("{:#?}", ty); return Err(ParseError::new( @@ -156,7 +177,9 @@ pub fn resolve_rust_ty(ty: &Type, context: RustTypeContext) -> ParseResult { let (inner_ident, inner_kind) = match elem.deref() { - Type::Path(TypePath { path, .. }) => ident_and_kind_from_path(path), + Type::Path(TypePath { path, .. }) => { + ident_and_kind_from_path(path) + } _ => { return Err(ParseError::new( ty.span(), @@ -171,7 +194,8 @@ pub fn resolve_rust_ty(ty: &Type, context: RustTypeContext) -> ParseResult { @@ -257,30 +281,52 @@ fn ident_to_kind(ident: &Ident, arguments: &PathArguments) -> TypeKind { } // Composite Types - PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => { + PathArguments::AngleBracketed(AngleBracketedGenericArguments { + args, + .. + }) => { match args.len() { // ----------------- // Single Type Parameter // ----------------- 1 => match &args[0] { GenericArgument::Type(ty) => match ident_str.as_str() { - "Vec" => match resolve_rust_ty(ty, RustTypeContext::CollectionItem) { - Ok(inner) => TypeKind::Composite(Composite::Vec, vec![inner]), - Err(_) => TypeKind::Composite(Composite::Vec, vec![]), - }, - "Option" => match resolve_rust_ty(ty, RustTypeContext::OptionItem) { - Ok(inner) => TypeKind::Composite(Composite::Option, vec![inner]), - Err(_) => TypeKind::Composite(Composite::Option, vec![]), + "Vec" => match resolve_rust_ty( + ty, + RustTypeContext::CollectionItem, + ) { + Ok(inner) => { + TypeKind::Composite(Composite::Vec, vec![inner]) + } + Err(_) => { + TypeKind::Composite(Composite::Vec, vec![]) + } }, - _ => match resolve_rust_ty(ty, RustTypeContext::CustomItem) { + "Option" => match resolve_rust_ty( + ty, + RustTypeContext::OptionItem, + ) { Ok(inner) => TypeKind::Composite( - Composite::Custom(ident_str.clone()), + Composite::Option, vec![inner], ), Err(_) => { - TypeKind::Composite(Composite::Custom(ident_str.clone()), vec![]) + TypeKind::Composite(Composite::Option, vec![]) } }, + _ => match resolve_rust_ty( + ty, + RustTypeContext::CustomItem, + ) { + Ok(inner) => TypeKind::Composite( + Composite::Custom(ident_str.clone()), + vec![inner], + ), + Err(_) => TypeKind::Composite( + Composite::Custom(ident_str.clone()), + vec![], + ), + }, }, _ => TypeKind::Unknown, }, @@ -288,35 +334,44 @@ fn ident_to_kind(ident: &Ident, arguments: &PathArguments) -> TypeKind { // Two Type Parameters // ----------------- 2 => match (&args[0], &args[1]) { - (GenericArgument::Type(ty1), GenericArgument::Type(ty2)) => { - match ident_str.as_str() { - ident if ident == "HashMap" || ident == "BTreeMap" => { - let inners = match ( - resolve_rust_ty(ty1, RustTypeContext::CollectionItem), - resolve_rust_ty(ty2, RustTypeContext::CollectionItem), - ) { - (Ok(inner1), Ok(inner2)) => vec![inner1, inner2], - (Ok(inner1), Err(_)) => vec![inner1], - (Err(_), Ok(inner2)) => vec![inner2], - (Err(_), Err(_)) => vec![], - }; + ( + GenericArgument::Type(ty1), + GenericArgument::Type(ty2), + ) => match ident_str.as_str() { + ident if ident == "HashMap" || ident == "BTreeMap" => { + let inners = match ( + resolve_rust_ty( + ty1, + RustTypeContext::CollectionItem, + ), + resolve_rust_ty( + ty2, + RustTypeContext::CollectionItem, + ), + ) { + (Ok(inner1), Ok(inner2)) => { + vec![inner1, inner2] + } + (Ok(inner1), Err(_)) => vec![inner1], + (Err(_), Ok(inner2)) => vec![inner2], + (Err(_), Err(_)) => vec![], + }; - let composite = if ident == "HashMap" { - Composite::HashMap - } else { - Composite::BTreeMap - }; - TypeKind::Composite(composite, inners) - } - _ => { - eprintln!("ident: {:#?}, args: {:#?}", ident, args); - todo!( + let composite = if ident == "HashMap" { + Composite::HashMap + } else { + Composite::BTreeMap + }; + TypeKind::Composite(composite, inners) + } + _ => { + eprintln!("ident: {:#?}, args: {:#?}", ident, args); + todo!( "Not yet handling custom angle bracketed types with {} type parameters", args.len() ) - } } - } + }, _ => TypeKind::Unknown, }, _ => {