diff --git a/src/modules/function/invocation_utils.rs b/src/modules/function/invocation_utils.rs index 769ccce4..ed8dbdc7 100644 --- a/src/modules/function/invocation_utils.rs +++ b/src/modules/function/invocation_utils.rs @@ -39,6 +39,17 @@ fn run_function_with_args(meta: &mut ParserMetadata, mut fun: FunctionDecl, args // Check if the function argument types match if fun.is_args_typed { for (index, (arg_name, arg_type, given_type)) in izip!(fun.arg_names.iter(), fun.arg_types.iter(), args.iter()).enumerate() { + + // union type matching works differently in functions + if let Type::Union(union) = &arg_type { + if ! union.has(given_type) { + let fun_name = &fun.name; + let ordinal = ordinal_number(index); + return error!(meta, tok, format!("{ordinal} argument of function '{fun_name}' does not allow '{given_type}' in '{arg_type}'")) + } + continue; + } + if !given_type.is_allowed_in(arg_type) { let fun_name = &fun.name; let ordinal = ordinal_number(index); diff --git a/src/modules/types.rs b/src/modules/types.rs index 825b8a1f..f93464d1 100644 --- a/src/modules/types.rs +++ b/src/modules/types.rs @@ -1,14 +1,19 @@ use std::fmt::Display; use heraclitus_compiler::prelude::*; +use itertools::Itertools; +use union::UnionType; use crate::utils::ParserMetadata; -#[derive(Debug, Clone, PartialEq, Eq, Default)] +mod union; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] pub enum Type { #[default] Null, Text, Bool, Num, + Union(UnionType), Array(Box), Failable(Box), Generic @@ -54,7 +59,12 @@ impl Display for Type { write!(f, "[{}]", t) }, Type::Failable(t) => write!(f, "{}?", t), - Type::Generic => write!(f, "Generic") + Type::Generic => write!(f, "Generic"), + + Type::Union(types) => { + let types: &Vec = types.into(); + write!(f, "{}", types.iter().map(|x| format!("{x}")).join(" | ")) + } } } } @@ -70,10 +80,8 @@ pub fn parse_type(meta: &mut ParserMetadata) -> Result { .map_err(|_| Failure::Loud(Message::new_err_at_token(meta, tok).message("Expected a data type"))) } -// Tries to parse the type - if it fails, it fails quietly -pub fn try_parse_type(meta: &mut ParserMetadata) -> Result { - let tok = meta.get_current_token(); - let res = match tok.clone() { +fn parse_type_tok(meta: &mut ParserMetadata, tok: Option) -> Result { + match tok.clone() { Some(matched_token) => { match matched_token.word.as_ref() { "Text" => { @@ -134,10 +142,35 @@ pub fn try_parse_type(meta: &mut ParserMetadata) -> Result { None => { Err(Failure::Quiet(PositionInfo::at_eof(meta))) } - }; + } +} +fn parse_one_type(meta: &mut ParserMetadata, tok: Option) -> Result { + let res = parse_type_tok(meta, tok)?; if token(meta, "?").is_ok() { - return res.map(|t| Type::Failable(Box::new(t))) + return Ok(Type::Failable(Box::new(res))) + } + Ok(res) +} + +// Tries to parse the type - if it fails, it fails quietly +pub fn try_parse_type(meta: &mut ParserMetadata) -> Result { + let tok = meta.get_current_token(); + let res = parse_one_type(meta, tok); + + if token(meta, "|").is_ok() { + // is union type + let mut unioned = vec![ res? ]; + loop { + match parse_one_type(meta, meta.get_current_token()) { + Err(err) => return Err(err), + Ok(t) => unioned.push(t) + }; + if token(meta, "|").is_err() { + break; + } + } + return Ok(Type::Union(unioned.into())) } res diff --git a/src/modules/types/union.rs b/src/modules/types/union.rs new file mode 100644 index 00000000..cd2b22e6 --- /dev/null +++ b/src/modules/types/union.rs @@ -0,0 +1,39 @@ +use super::Type; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct UnionType(pub Vec); + +impl UnionType { + pub fn has(&self, other: &Type) -> bool { + self.0.iter().find(|x| **x == *other).is_some() + } +} + +impl Into> for UnionType { + fn into(self) -> Vec { + self.0 + } +} + +impl <'a> Into<&'a Vec> for &'a UnionType { + fn into(self) -> &'a Vec { + &self.0 + } +} + +impl From> for UnionType { + fn from(value: Vec) -> Self { + let mut value = value; + value.sort(); + if value.len() < 2 { + unreachable!("A union type must have at least two elements") + } + for typ in &value { + if let Type::Union(_) = typ { + unreachable!("Union types cannot be nested") + } + } + + Self(value) + } +} diff --git a/src/tests/errors.rs b/src/tests/errors.rs index dadb73d3..404890a9 100644 --- a/src/tests/errors.rs +++ b/src/tests/errors.rs @@ -1,5 +1,7 @@ use super::test_amber; +mod unions; + #[test] #[should_panic(expected = "ERROR: Return type does not match function return type")] fn function_with_wrong_typed_return() { diff --git a/src/tests/errors/unions.rs b/src/tests/errors/unions.rs new file mode 100644 index 00000000..19639ffb --- /dev/null +++ b/src/tests/errors/unions.rs @@ -0,0 +1,32 @@ +use crate::tests::test_amber; + +#[test] +#[should_panic(expected = "ERROR: 1st argument 'param' of function 'abc' expects type 'Text | Null', but 'Num' was given")] +fn invalid_union_type_eq_normal_type() { + let code = r#" + fun abc(param: Text | Null) {} + abc("") + abc(123) + "#; + test_amber(code, ""); +} + +#[test] +#[should_panic(expected = "ERROR: 1st argument 'param' of function 'abc' expects type 'Text | Null', but 'Num | [Text]' was given")] +fn invalid_two_unions() { + let code = r#" + fun abc(param: Text | Null) {} + abc(123 as Num | [Text]) + "#; + test_amber(code, ""); +} + +#[test] +#[should_panic(expected = "ERROR: 1st argument 'param' of function 'abc' expects type 'Text | Num | Text? | Num? | [Null]', but 'Null' was given")] +fn big_union() { + let code = r#" + fun abc(param: Text | Num | Text? | Num? | [Null]) {} + abc(null) + "#; + test_amber(code, ""); +} \ No newline at end of file diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 1b62c2fb..d8d73155 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -8,6 +8,7 @@ use std::process::{Command, Stdio}; pub mod cli; pub mod errors; pub mod extra; +pub mod types; pub mod postprocessor; mod stdlib; mod validity; diff --git a/src/tests/types/eq.rs b/src/tests/types/eq.rs new file mode 100644 index 00000000..a19cb1ea --- /dev/null +++ b/src/tests/types/eq.rs @@ -0,0 +1,34 @@ +use crate::modules::types::Type; + +#[test] +fn normal_types_eq() { + let types = vec![Type::Null, Type::Text, Type::Bool, Type::Num, Type::Generic]; + for typ in types { + assert_eq!(typ, typ, "{typ} and {typ} must be equal!"); + } +} + +#[test] +fn two_different_normal_types() { + assert_ne!(Type::Null, Type::Bool); +} + +#[test] +fn normal_and_failable_type() { + assert_ne!(Type::Failable(Box::new(Type::Text)), Type::Text, "Text? and Text must not be equal!") +} + +#[test] +fn array_and_normal_type() { + assert_ne!(Type::Array(Box::new(Type::Bool)), Type::Bool); +} + +#[test] +fn array_and_array_of_failables() { + assert_ne!(Type::Array(Box::new(Type::Bool)), Type::Array(Box::new(Type::Failable(Box::new(Type::Bool))))); +} + +#[test] +fn nested_array_normal_array_with_failable() { + assert_ne!(Type::Array(Box::new(Type::Array(Box::new(Type::Bool)))), Type::Failable(Box::new(Type::Bool))); +} \ No newline at end of file diff --git a/src/tests/types/mod.rs b/src/tests/types/mod.rs new file mode 100644 index 00000000..40991751 --- /dev/null +++ b/src/tests/types/mod.rs @@ -0,0 +1,2 @@ +pub mod union; +pub mod eq; diff --git a/src/tests/types/union.rs b/src/tests/types/union.rs new file mode 100644 index 00000000..5c961d3f --- /dev/null +++ b/src/tests/types/union.rs @@ -0,0 +1,73 @@ +use crate::modules::types::Type; + +#[test] +fn partially_overlapping_types() { + let one = Type::Union(vec![Type::Text, Type::Num]); + let two = Type::Union(vec![Type::Num, Type::Null]); + + assert_ne!(one, two, "Text | Num must not be equal to Num | Null!") +} + +#[test] +fn overlapping_types() { + let one = Type::Union(vec![Type::Text, Type::Num]); + let two = Type::Union(vec![Type::Text, Type::Num, Type::Null]); + + assert_eq!(one, two, "Text | Num must be equal to Text | Num | Null!") +} + +#[test] +fn same_union() { + let one = Type::Union(vec![Type::Text, Type::Num]); + let two = Type::Union(vec![Type::Text, Type::Num]); + + assert_eq!(one, two, "Text | Num must be equal to Text | Num!") +} + +#[test] +fn empty_union() { + let one = Type::Union(vec![]); + let two = Type::Union(vec![]); + + assert_eq!(one, two, "If one of unions is empty, it must always be equal to another") +} + +#[test] +fn empty_and_normal_union() { + let one = Type::Union(vec![Type::Text, Type::Num]); + let two = Type::Union(vec![]); + + assert_eq!(one, two, "If one of unions is empty, it must always be equal to another") +} + +#[test] +fn empty_union_and_normal_type() { + let one = Type::Union(vec![]); + let two = Type::Text; + + assert_ne!(one, two, "An empty union and one type are not equal") +} + +#[test] +fn big_union() { + let one = Type::Union(vec![Type::Text, Type::Text, Type::Text, Type::Text, Type::Text, Type::Text, Type::Text, Type::Num]); + let two = Type::Union(vec![Type::Text, Type::Num]); + + assert_eq!(one, two, "Text | Text | ... | Text | Num and Text | Num must be equal!") +} + +#[test] +fn normal_and_union() { + let one = Type::Text; + let two = Type::Union(vec![Type::Text, Type::Null]); + + assert_eq!(one, two, "Text and Text | Null must be equal!"); +} + +#[test] +fn normal_not_in_union() { + let one = Type::Text; + let two = Type::Union(vec![Type::Num, Type::Null]); + + assert_ne!(one, two, "Text and Num | Null must not be equal!"); +} diff --git a/src/tests/validity/function_with_union_types.ab b/src/tests/validity/function_with_union_types.ab new file mode 100644 index 00000000..d0797888 --- /dev/null +++ b/src/tests/validity/function_with_union_types.ab @@ -0,0 +1,10 @@ +// Output +// abc +// 123 + +fun check(thing: Text | Num): Null { + echo thing +} + +check("abc") +check(123) \ No newline at end of file diff --git a/src/tests/validity/union_types.ab b/src/tests/validity/union_types.ab new file mode 100644 index 00000000..387c0d27 --- /dev/null +++ b/src/tests/validity/union_types.ab @@ -0,0 +1,7 @@ +// Output +// 123 + +let thingy = "abc" as Text | Num; +thingy = 123; + +echo thingy; \ No newline at end of file diff --git a/src/tests/validity/union_types_if.ab b/src/tests/validity/union_types_if.ab new file mode 100644 index 00000000..5a913a5a --- /dev/null +++ b/src/tests/validity/union_types_if.ab @@ -0,0 +1,19 @@ +// Output +// is text +// abc +// is num +// 123 + +fun check(thing: Text | Num): Null { + if thing is Text { + echo "is text" + echo thing + } + if thing is Num { + echo "is num" + echo thing + } +} + +check("abc") +check(123) \ No newline at end of file