From 1577927965dcb702efd0f4ca3455752285cfb6f3 Mon Sep 17 00:00:00 2001 From: Jesse Hallett Date: Tue, 19 Nov 2024 11:25:57 -0800 Subject: [PATCH] support more aggregation operators when generating native query configurations (#120) Adds or refines support for these operators: Arithmetic Expression Operators - `$abs` - `$add` - `$divide` - `$multiply` - `$subtract` Array Expression Operators - `$arrayElemAt` Boolean Expression Operators - `$and` - `$not` - `$or` Comparison Expression Operators - `$eq` - `$gt` - `$gte` - `$lt` - `$lte` - `$ne` Set Expression Operators - `$allElementsTrue` - `$anyElementTrue` String Expression Operators - `$split` Trigonometry Expression Operators - `$sin` - `$cos` - `$tan` - `$asin` - `$acos` - `$atan` - `$asinh` - `$acosh` - `$atanh` - `$sinh` - `$cosh` - `$tanh` Accumulators (`$group`, `$bucket`, `$bucketAuto`, `$setWindowFields`) - `$avg` - `$count` - `$max` - `$min` - `$push` - `$sum` Also improves type inference to make all of these operators work. This is work on an in-progress feature that is gated behind a feature flag, `native-query-subcommand` --- Cargo.lock | 1 + crates/cli/Cargo.toml | 1 + .../aggregation-operator-progress.md | 280 +++++++ .../native_query/aggregation_expression.rs | 361 +++++++-- crates/cli/src/native_query/error.rs | 27 +- crates/cli/src/native_query/helpers.rs | 38 +- crates/cli/src/native_query/mod.rs | 171 +---- .../src/native_query/pipeline/match_stage.rs | 9 +- crates/cli/src/native_query/pipeline/mod.rs | 36 +- .../src/native_query/pipeline_type_context.rs | 72 +- crates/cli/src/native_query/tests.rs | 274 +++++++ .../cli/src/native_query/type_constraint.rs | 168 ++++- .../type_solver/constraint_to_type.rs | 40 +- .../cli/src/native_query/type_solver/mod.rs | 66 +- .../src/native_query/type_solver/simplify.rs | 686 +++++++++++------- .../native_query/type_solver/substitute.rs | 100 --- .../query/serialization/tests.txt | 1 + crates/mongodb-support/src/bson_type.rs | 15 +- 
crates/test-helpers/src/configuration.rs | 29 +- 19 files changed, 1698 insertions(+), 677 deletions(-) create mode 100644 crates/cli/src/native_query/aggregation-operator-progress.md create mode 100644 crates/cli/src/native_query/tests.rs delete mode 100644 crates/cli/src/native_query/type_solver/substitute.rs diff --git a/Cargo.lock b/Cargo.lock index 8e7d4980..fd7c146a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1806,6 +1806,7 @@ dependencies = [ "anyhow", "clap", "configuration", + "enum-iterator", "futures-util", "googletest", "indexmap 2.2.6", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index f57e0069..64fcfcad 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -14,6 +14,7 @@ mongodb-support = { path = "../mongodb-support" } anyhow = "1.0.80" clap = { version = "4.5.1", features = ["derive", "env"] } +enum-iterator = "^2.0.0" futures-util = "0.3.28" indexmap = { workspace = true } itertools = { workspace = true } diff --git a/crates/cli/src/native_query/aggregation-operator-progress.md b/crates/cli/src/native_query/aggregation-operator-progress.md new file mode 100644 index 00000000..16a4ef8d --- /dev/null +++ b/crates/cli/src/native_query/aggregation-operator-progress.md @@ -0,0 +1,280 @@ +Arithmetic Expression Operators + +- [x] $abs - Returns the absolute value of a number. +- [x] $add - Adds numbers to return the sum, or adds numbers and a date to return a new date. If adding numbers and a date, treats the numbers as milliseconds. Accepts any number of argument expressions, but at most, one expression can resolve to a date. +- [ ] $ceil - Returns the smallest integer greater than or equal to the specified number. +- [x] $divide - Returns the result of dividing the first number by the second. Accepts two argument expressions. +- [ ] $exp - Raises e to the specified exponent. +- [ ] $floor - Returns the largest integer less than or equal to the specified number. +- [ ] $ln - Calculates the natural log of a number. 
+- [ ] $log - Calculates the log of a number in the specified base. +- [ ] $log10 - Calculates the log base 10 of a number. +- [ ] $mod - Returns the remainder of the first number divided by the second. Accepts two argument expressions. +- [x] $multiply - Multiplies numbers to return the product. Accepts any number of argument expressions. +- [ ] $pow - Raises a number to the specified exponent. +- [ ] $round - Rounds a number to a whole integer or to a specified decimal place. +- [ ] $sqrt - Calculates the square root. +- [x] $subtract - Returns the result of subtracting the second value from the first. If the two values are numbers, return the difference. If the two values are dates, return the difference in milliseconds. If the two values are a date and a number in milliseconds, return the resulting date. Accepts two argument expressions. If the two values are a date and a number, specify the date argument first as it is not meaningful to subtract a date from a number. +- [ ] $trunc - Truncates a number to a whole integer or to a specified decimal place. + +Array Expression Operators + +- [x] $arrayElemAt - Returns the element at the specified array index. +- [ ] $arrayToObject - Converts an array of key value pairs to a document. +- [ ] $concatArrays - Concatenates arrays to return the concatenated array. +- [ ] $filter - Selects a subset of the array to return an array with only the elements that match the filter condition. +- [ ] $firstN - Returns a specified number of elements from the beginning of an array. Distinct from the $firstN accumulator. +- [ ] $in - Returns a boolean indicating whether a specified value is in an array. +- [ ] $indexOfArray - Searches an array for an occurrence of a specified value and returns the array index of the first occurrence. Array indexes start at zero. +- [ ] $isArray - Determines if the operand is an array. Returns a boolean. +- [ ] $lastN - Returns a specified number of elements from the end of an array. 
Distinct from the $lastN accumulator. +- [ ] $map - Applies a subexpression to each element of an array and returns the array of resulting values in order. Accepts named parameters. +- [ ] $maxN - Returns the n largest values in an array. Distinct from the $maxN accumulator. +- [ ] $minN - Returns the n smallest values in an array. Distinct from the $minN accumulator. +- [ ] $objectToArray - Converts a document to an array of documents representing key-value pairs. +- [ ] $range - Outputs an array containing a sequence of integers according to user-defined inputs. +- [ ] $reduce - Applies an expression to each element in an array and combines them into a single value. +- [ ] $reverseArray - Returns an array with the elements in reverse order. +- [ ] $size - Returns the number of elements in the array. Accepts a single expression as argument. +- [ ] $slice - Returns a subset of an array. +- [ ] $sortArray - Sorts the elements of an array. +- [ ] $zip - Merge two arrays together. + +Bitwise Operators + +- [ ] $bitAnd - Returns the result of a bitwise and operation on an array of int or long values. +- [ ] $bitNot - Returns the result of a bitwise not operation on a single argument or an array that contains a single int or long value. +- [ ] $bitOr - Returns the result of a bitwise or operation on an array of int or long values. +- [ ] $bitXor - Returns the result of a bitwise xor (exclusive or) operation on an array of int and long values. + +Boolean Expression Operators + +- [x] $and - Returns true only when all its expressions evaluate to true. Accepts any number of argument expressions. +- [x] $not - Returns the boolean value that is the opposite of its argument expression. Accepts a single argument expression. +- [x] $or - Returns true when any of its expressions evaluates to true. Accepts any number of argument expressions. 
+ +Comparison Expression Operators + +- [ ] $cmp - Returns 0 if the two values are equivalent, 1 if the first value is greater than the second, and -1 if the first value is less than the second. +- [x] $eq - Returns true if the values are equivalent. +- [x] $gt - Returns true if the first value is greater than the second. +- [x] $gte - Returns true if the first value is greater than or equal to the second. +- [x] $lt - Returns true if the first value is less than the second. +- [x] $lte - Returns true if the first value is less than or equal to the second. +- [x] $ne - Returns true if the values are not equivalent. + +Conditional Expression Operators + +- [ ] $cond - A ternary operator that evaluates one expression, and depending on the result, returns the value of one of the other two expressions. Accepts either three expressions in an ordered list or three named parameters. +- [ ] $ifNull - Returns either the non-null result of the first expression or the result of the second expression if the first expression results in a null result. Null result encompasses instances of undefined values or missing fields. Accepts two expressions as arguments. The result of the second expression can be null. +- [ ] $switch - Evaluates a series of case expressions. When it finds an expression which evaluates to true, $switch executes a specified expression and breaks out of the control flow. + +Custom Aggregation Expression Operators + +- [ ] $accumulator - Defines a custom accumulator function. +- [ ] $function - Defines a custom function. + +Data Size Operators + +- [ ] $binarySize - Returns the size of a given string or binary data value's content in bytes. +- [ ] $bsonSize - Returns the size in bytes of a given document (i.e. bsontype Object) when encoded as BSON. + +Date Expression Operators + +- [ ] $dateAdd - Adds a number of time units to a date object. +- [ ] $dateDiff - Returns the difference between two dates. 
+- [ ] $dateFromParts - Constructs a BSON Date object given the date's constituent parts. +- [ ] $dateFromString - Converts a date/time string to a date object. +- [ ] $dateSubtract - Subtracts a number of time units from a date object. +- [ ] $dateToParts - Returns a document containing the constituent parts of a date. +- [ ] $dateToString - Returns the date as a formatted string. +- [ ] $dateTrunc - Truncates a date. +- [ ] $dayOfMonth - Returns the day of the month for a date as a number between 1 and 31. +- [ ] $dayOfWeek - Returns the day of the week for a date as a number between 1 (Sunday) and 7 (Saturday). +- [ ] $dayOfYear - Returns the day of the year for a date as a number between 1 and 366 (leap year). +- [ ] $hour - Returns the hour for a date as a number between 0 and 23. +- [ ] $isoDayOfWeek - Returns the weekday number in ISO 8601 format, ranging from 1 (for Monday) to 7 (for Sunday). +- [ ] $isoWeek - Returns the week number in ISO 8601 format, ranging from 1 to 53. Week numbers start at 1 with the week (Monday through Sunday) that contains the year's first Thursday. +- [ ] $isoWeekYear - Returns the year number in ISO 8601 format. The year starts with the Monday of week 1 (ISO 8601) and ends with the Sunday of the last week (ISO 8601). +- [ ] $millisecond - Returns the milliseconds of a date as a number between 0 and 999. +- [ ] $minute - Returns the minute for a date as a number between 0 and 59. +- [ ] $month - Returns the month for a date as a number between 1 (January) and 12 (December). +- [ ] $second - Returns the seconds for a date as a number between 0 and 60 (leap seconds). +- [ ] $toDate - Converts value to a Date. +- [ ] $week - Returns the week number for a date as a number between 0 (the partial week that precedes the first Sunday of the year) and 53 (leap year). +- [ ] $year - Returns the year for a date as a number (e.g. 2014). 
+ +The following arithmetic operators can take date operands: + +- [ ] $add - Adds numbers and a date to return a new date. If adding numbers and a date, treats the numbers as milliseconds. Accepts any number of argument expressions, but at most, one expression can resolve to a date. +- [ ] $subtract - Returns the result of subtracting the second value from the first. If the two values are dates, return the difference in milliseconds. If the two values are a date and a number in milliseconds, return the resulting date. Accepts two argument expressions. If the two values are a date and a number, specify the date argument first as it is not meaningful to subtract a date from a number. + +Literal Expression Operator + +- [ ] $literal - Return a value without parsing. Use for values that the aggregation pipeline may interpret as an expression. For example, use a $literal expression to a string that starts with a dollar sign ($) to avoid parsing as a field path. + +Miscellaneous Operators + +- [ ] $getField - Returns the value of a specified field from a document. You can use $getField to retrieve the value of fields with names that contain periods (.) or start with dollar signs ($). +- [ ] $rand - Returns a random float between 0 and 1 +- [ ] $sampleRate - Randomly select documents at a given rate. Although the exact number of documents selected varies on each run, the quantity chosen approximates the sample rate expressed as a percentage of the total number of documents. +- [ ] $toHashedIndexKey - Computes and returns the hash of the input expression using the same hash function that MongoDB uses to create a hashed index. + +Object Expression Operators + +- [ ] $mergeObjects - Combines multiple documents into a single document. +- [ ] $objectToArray - Converts a document to an array of documents representing key-value pairs. +- [ ] $setField - Adds, updates, or removes a specified field in a document. 
You can use $setField to add, update, or remove fields with names that contain periods (.) or start with dollar signs ($). + +Set Expression Operators + +- [x] $allElementsTrue - Returns true if no element of a set evaluates to false, otherwise, returns false. Accepts a single argument expression. +- [x] $anyElementTrue - Returns true if any elements of a set evaluate to true; otherwise, returns false. Accepts a single argument expression. +- [ ] $setDifference - Returns a set with elements that appear in the first set but not in the second set; i.e. performs a relative complement of the second set relative to the first. Accepts exactly two argument expressions. +- [ ] $setEquals - Returns true if the input sets have the same distinct elements. Accepts two or more argument expressions. +- [ ] $setIntersection - Returns a set with elements that appear in all of the input sets. Accepts any number of argument expressions. +- [ ] $setIsSubset - Returns true if all elements of the first set appear in the second set, including when the first set equals the second set; i.e. not a strict subset. Accepts exactly two argument expressions. +- [ ] $setUnion - Returns a set with elements that appear in any of the input sets. + +String Expression Operators + +- [ ] $concat - Concatenates any number of strings. +- [ ] $dateFromString - Converts a date/time string to a date object. +- [ ] $dateToString - Returns the date as a formatted string. +- [ ] $indexOfBytes - Searches a string for an occurrence of a substring and returns the UTF-8 byte index of the first occurrence. If the substring is not found, returns -1. +- [ ] $indexOfCP - Searches a string for an occurrence of a substring and returns the UTF-8 code point index of the first occurrence. If the substring is not found, returns -1 +- [ ] $ltrim - Removes whitespace or the specified characters from the beginning of a string. 
+- [ ] $regexFind - Applies a regular expression (regex) to a string and returns information on the first matched substring. +- [ ] $regexFindAll - Applies a regular expression (regex) to a string and returns information on the all matched substrings. +- [ ] $regexMatch - Applies a regular expression (regex) to a string and returns a boolean that indicates if a match is found or not. +- [ ] $replaceOne - Replaces the first instance of a matched string in a given input. +- [ ] $replaceAll - Replaces all instances of a matched string in a given input. +- [ ] $rtrim - Removes whitespace or the specified characters from the end of a string. +- [x] $split - Splits a string into substrings based on a delimiter. Returns an array of substrings. If the delimiter is not found within the string, returns an array containing the original string. +- [ ] $strLenBytes - Returns the number of UTF-8 encoded bytes in a string. +- [ ] $strLenCP - Returns the number of UTF-8 code points in a string. +- [ ] $strcasecmp - Performs case-insensitive string comparison and returns: 0 if two strings are equivalent, 1 if the first string is greater than the second, and -1 if the first string is less than the second. +- [ ] $substr - Deprecated. Use $substrBytes or $substrCP. +- [ ] $substrBytes - Returns the substring of a string. Starts with the character at the specified UTF-8 byte index (zero-based) in the string and continues for the specified number of bytes. +- [ ] $substrCP - Returns the substring of a string. Starts with the character at the specified UTF-8 code point (CP) +index (zero-based) in the string and continues for the number of code points specified. +- [ ] $toLower - Converts a string to lowercase. Accepts a single argument expression. +- [ ] $toString - Converts value to a string. +- [ ] $trim - Removes whitespace or the specified characters from the beginning and end of a string. +- [ ] $toUpper - Converts a string to uppercase. Accepts a single argument expression. 
+ +Text Expression Operator + +- [ ] $meta - Access available per-document metadata related to the aggregation operation. + +Timestamp Expression Operators + +- [ ] $tsIncrement - Returns the incrementing ordinal from a timestamp as a long. +- [ ] $tsSecond - Returns the seconds from a timestamp as a long. + +Trigonometry Expression Operators + +- [x] $sin - Returns the sine of a value that is measured in radians. +- [x] $cos - Returns the cosine of a value that is measured in radians. +- [x] $tan - Returns the tangent of a value that is measured in radians. +- [x] $asin - Returns the inverse sin (arc sine) of a value in radians. +- [x] $acos - Returns the inverse cosine (arc cosine) of a value in radians. +- [x] $atan - Returns the inverse tangent (arc tangent) of a value in radians. +- [ ] $atan2 - Returns the inverse tangent (arc tangent) of y / x in radians, where y and x are the first and second values passed to the expression respectively. +- [x] $asinh - Returns the inverse hyperbolic sine (hyperbolic arc sine) of a value in radians. +- [x] $acosh - Returns the inverse hyperbolic cosine (hyperbolic arc cosine) of a value in radians. +- [x] $atanh - Returns the inverse hyperbolic tangent (hyperbolic arc tangent) of a value in radians. +- [x] $sinh - Returns the hyperbolic sine of a value that is measured in radians. +- [x] $cosh - Returns the hyperbolic cosine of a value that is measured in radians. +- [x] $tanh - Returns the hyperbolic tangent of a value that is measured in radians. +- [ ] $degreesToRadians - Converts a value from degrees to radians. +- [ ] $radiansToDegrees - Converts a value from radians to degrees. + +Type Expression Operators + +- [ ] $convert - Converts a value to a specified type. +- [ ] $isNumber - Returns boolean true if the specified expression resolves to an integer, decimal, double, or long. +- [ ] $toBool - Converts value to a boolean. +- [ ] $toDate - Converts value to a Date. +- [ ] $toDecimal - Converts value to a Decimal128. 
+- [ ] $toDouble - Converts value to a double. +- [ ] $toInt - Converts value to an integer. +- [ ] $toLong - Converts value to a long. +- [ ] $toObjectId - Converts value to an ObjectId. +- [ ] $toString - Converts value to a string. +- [ ] $type - Return the BSON data type of the field. +- [ ] $toUUID - Converts a string to a UUID. + +Accumulators ($group, $bucket, $bucketAuto, $setWindowFields) + +- [ ] $accumulator - Returns the result of a user-defined accumulator function. +- [ ] $addToSet - Returns an array of unique expression values for each group. Order of the array elements is undefined. +- [x] $avg - Returns an average of numerical values. Ignores non-numeric values. +- [ ] $bottom - Returns the bottom element within a group according to the specified sort order. +- [ ] $bottomN - Returns an aggregation of the bottom n fields within a group, according to the specified sort order. +- [x] $count - Returns the number of documents in a group. +- [ ] $first - Returns the result of an expression for the first document in a group. +- [ ] $firstN - Returns an aggregation of the first n elements within a group. Only meaningful when documents are in a defined order. Distinct from the $firstN array operator. +- [ ] $last - Returns the result of an expression for the last document in a group. +- [ ] $lastN - Returns an aggregation of the last n elements within a group. Only meaningful when documents are in a defined order. Distinct from the $lastN array operator. +- [x] $max - Returns the highest expression value for each group. +- [ ] $maxN - Returns an aggregation of the n maximum valued elements in a group. Distinct from the $maxN array operator. +- [ ] $median - Returns an approximation of the median, the 50th percentile, as a scalar value. +- [ ] $mergeObjects - Returns a document created by combining the input documents for each group. +- [x] $min - Returns the lowest expression value for each group. 
+- [ ] $minN - Returns an aggregation of the n minimum valued elements in a group. Distinct from the $minN array operator. +- [ ] $percentile - Returns an array of scalar values that correspond to specified percentile values. +- [x] $push - Returns an array of expression values for documents in each group. +- [ ] $stdDevPop - Returns the population standard deviation of the input values. +- [ ] $stdDevSamp - Returns the sample standard deviation of the input values. +- [x] $sum - Returns a sum of numerical values. Ignores non-numeric values. +- [ ] $top - Returns the top element within a group according to the specified sort order. +- [ ] $topN - Returns an aggregation of the top n fields within a group, according to the specified sort order. + +Accumulators (in Other Stages) + +- [ ] $avg - Returns an average of the specified expression or list of expressions for each document. Ignores non-numeric values. +- [ ] $first - Returns the result of an expression for the first document in a group. +- [ ] $last - Returns the result of an expression for the last document in a group. +- [ ] $max - Returns the maximum of the specified expression or list of expressions for each document +- [ ] $median - Returns an approximation of the median, the 50th percentile, as a scalar value. +- [ ] $min - Returns the minimum of the specified expression or list of expressions for each document +- [ ] $percentile - Returns an array of scalar values that correspond to specified percentile values. +- [ ] $stdDevPop - Returns the population standard deviation of the input values. +- [ ] $stdDevSamp - Returns the sample standard deviation of the input values. +- [ ] $sum - Returns a sum of numerical values. Ignores non-numeric values. + +Variable Expression Operators + +- [ ] $let - Defines variables for use within the scope of a subexpression and returns the result of the subexpression. Accepts named parameters. 
+ +Window Operators + +- [ ] $addToSet - Returns an array of all unique values that results from applying an expression to each document. +- [ ] $avg - Returns the average for the specified expression. Ignores non-numeric values. +- [ ] $bottom - Returns the bottom element within a group according to the specified sort order. +- [ ] $bottomN - Returns an aggregation of the bottom n fields within a group, according to the specified sort order. +- [ ] $count - Returns the number of documents in the group or window. +- [ ] $covariancePop - Returns the population covariance of two numeric expressions. +- [ ] $covarianceSamp - Returns the sample covariance of two numeric expressions. +- [ ] $denseRank - Returns the document position (known as the rank) relative to other documents in the $setWindowFields stage partition. There are no gaps in the ranks. Ties receive the same rank. +- [ ] $derivative - Returns the average rate of change within the specified window. +- [ ] $documentNumber - Returns the position of a document (known as the document number) in the $setWindowFields stage partition. Ties result in different adjacent document numbers. +- [ ] $expMovingAvg - Returns the exponential moving average for the numeric expression. +- [ ] $first - Returns the result of an expression for the first document in a group or window. +- [ ] $integral - Returns the approximation of the area under a curve. +- [ ] $last - Returns the result of an expression for the last document in a group or window. +- [ ] $linearFill - Fills null and missing fields in a window using linear interpolation +- [ ] $locf - Last observation carried forward. Sets values for null and missing fields in a window to the last non-null value for the field. +- [ ] $max - Returns the maximum value that results from applying an expression to each document. +- [ ] $min - Returns the minimum value that results from applying an expression to each document. 
+- [ ] $minN - Returns an aggregation of the n minimum valued elements in a group. Distinct from the $minN array operator. +- [ ] $push - Returns an array of values that result from applying an expression to each document. +- [ ] $rank - Returns the document position (known as the rank) relative to other documents in the $setWindowFields stage partition. +- [ ] $shift - Returns the value from an expression applied to a document in a specified position relative to the current document in the $setWindowFields stage partition. +- [ ] $stdDevPop - Returns the population standard deviation that results from applying a numeric expression to each document. +- [ ] $stdDevSamp - Returns the sample standard deviation that results from applying a numeric expression to each document. +- [ ] $sum - Returns the sum that results from applying a numeric expression to each document. +- [ ] $top - Returns the top element within a group according to the specified sort order. +- [ ] $topN - Returns an aggregation of the top n fields within a group, according to the specified sort order. 
+ diff --git a/crates/cli/src/native_query/aggregation_expression.rs b/crates/cli/src/native_query/aggregation_expression.rs index 7e7fa6ea..8d9190c8 100644 --- a/crates/cli/src/native_query/aggregation_expression.rs +++ b/crates/cli/src/native_query/aggregation_expression.rs @@ -11,46 +11,98 @@ use super::error::{Error, Result}; use super::reference_shorthand::{parse_reference_shorthand, Reference}; use super::type_constraint::{ObjectTypeConstraint, TypeConstraint, Variance}; +use TypeConstraint as C; + pub fn infer_type_from_aggregation_expression( context: &mut PipelineTypeContext<'_>, desired_object_type_name: &str, - bson: Bson, + type_hint: Option<&TypeConstraint>, + expression: Bson, ) -> Result { - let t = match bson { - Bson::Double(_) => TypeConstraint::Scalar(BsonScalarType::Double), - Bson::String(string) => infer_type_from_reference_shorthand(context, &string)?, - Bson::Array(_) => todo!("array type"), - Bson::Document(doc) => { - infer_type_from_aggregation_expression_document(context, desired_object_type_name, doc)? 
- } - Bson::Boolean(_) => TypeConstraint::Scalar(BsonScalarType::Bool), - Bson::Null | Bson::Undefined => { - let type_variable = context.new_type_variable(Variance::Covariant, []); - TypeConstraint::Nullable(Box::new(TypeConstraint::Variable(type_variable))) - } - Bson::RegularExpression(_) => TypeConstraint::Scalar(BsonScalarType::Regex), - Bson::JavaScriptCode(_) => TypeConstraint::Scalar(BsonScalarType::Javascript), - Bson::JavaScriptCodeWithScope(_) => { - TypeConstraint::Scalar(BsonScalarType::JavascriptWithScope) - } - Bson::Int32(_) => TypeConstraint::Scalar(BsonScalarType::Int), - Bson::Int64(_) => TypeConstraint::Scalar(BsonScalarType::Long), - Bson::Timestamp(_) => TypeConstraint::Scalar(BsonScalarType::Timestamp), - Bson::Binary(_) => TypeConstraint::Scalar(BsonScalarType::BinData), - Bson::ObjectId(_) => TypeConstraint::Scalar(BsonScalarType::ObjectId), - Bson::DateTime(_) => TypeConstraint::Scalar(BsonScalarType::Date), - Bson::Symbol(_) => TypeConstraint::Scalar(BsonScalarType::Symbol), - Bson::Decimal128(_) => TypeConstraint::Scalar(BsonScalarType::Decimal), - Bson::MaxKey => TypeConstraint::Scalar(BsonScalarType::MaxKey), - Bson::MinKey => TypeConstraint::Scalar(BsonScalarType::MinKey), - Bson::DbPointer(_) => TypeConstraint::Scalar(BsonScalarType::DbPointer), + let t = match expression { + Bson::Double(_) => C::Scalar(BsonScalarType::Double), + Bson::String(string) => infer_type_from_reference_shorthand(context, type_hint, &string)?, + Bson::Array(elems) => { + infer_type_from_array(context, desired_object_type_name, type_hint, elems)? 
+ } + Bson::Document(doc) => infer_type_from_aggregation_expression_document( + context, + desired_object_type_name, + type_hint, + doc, + )?, + Bson::Boolean(_) => C::Scalar(BsonScalarType::Bool), + Bson::Null | Bson::Undefined => C::Scalar(BsonScalarType::Null), + Bson::RegularExpression(_) => C::Scalar(BsonScalarType::Regex), + Bson::JavaScriptCode(_) => C::Scalar(BsonScalarType::Javascript), + Bson::JavaScriptCodeWithScope(_) => C::Scalar(BsonScalarType::JavascriptWithScope), + Bson::Int32(_) => C::Scalar(BsonScalarType::Int), + Bson::Int64(_) => C::Scalar(BsonScalarType::Long), + Bson::Timestamp(_) => C::Scalar(BsonScalarType::Timestamp), + Bson::Binary(_) => C::Scalar(BsonScalarType::BinData), + Bson::ObjectId(_) => C::Scalar(BsonScalarType::ObjectId), + Bson::DateTime(_) => C::Scalar(BsonScalarType::Date), + Bson::Symbol(_) => C::Scalar(BsonScalarType::Symbol), + Bson::Decimal128(_) => C::Scalar(BsonScalarType::Decimal), + Bson::MaxKey => C::Scalar(BsonScalarType::MaxKey), + Bson::MinKey => C::Scalar(BsonScalarType::MinKey), + Bson::DbPointer(_) => C::Scalar(BsonScalarType::DbPointer), }; Ok(t) } +pub fn infer_types_from_aggregation_expression_tuple( + context: &mut PipelineTypeContext<'_>, + desired_object_type_name: &str, + type_hint_for_elements: Option<&TypeConstraint>, + bson: Bson, +) -> Result> { + let tuple = match bson { + Bson::Array(exprs) => exprs + .into_iter() + .map(|expr| { + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + type_hint_for_elements, + expr, + ) + }) + .collect::>>()?, + expr => Err(Error::Other(format!("expected array, but got {expr}")))?, + }; + Ok(tuple) +} + +fn infer_type_from_array( + context: &mut PipelineTypeContext<'_>, + desired_object_type_name: &str, + type_hint_for_entire_array: Option<&TypeConstraint>, + elements: Vec, +) -> Result { + let elem_type_hint = type_hint_for_entire_array.map(|hint| match hint { + C::ArrayOf(t) => *t.clone(), + t => C::ElementOf(Box::new(t.clone())), + }); 
+ Ok(C::Union( + elements + .into_iter() + .map(|elem| { + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + elem_type_hint.as_ref(), + elem, + ) + }) + .collect::>()?, + )) +} + fn infer_type_from_aggregation_expression_document( context: &mut PipelineTypeContext<'_>, desired_object_type_name: &str, + type_hint_for_entire_object: Option<&TypeConstraint>, mut document: Document, ) -> Result { let mut expression_operators = document @@ -66,6 +118,7 @@ fn infer_type_from_aggregation_expression_document( infer_type_from_operator_expression( context, desired_object_type_name, + type_hint_for_entire_object, &operator, operands, ) @@ -74,21 +127,185 @@ fn infer_type_from_aggregation_expression_document( } } +// TODO: propagate expected type based on operator used fn infer_type_from_operator_expression( - _context: &mut PipelineTypeContext<'_>, - _desired_object_type_name: &str, + context: &mut PipelineTypeContext<'_>, + desired_object_type_name: &str, + type_hint: Option<&TypeConstraint>, operator: &str, - operands: Bson, + operand: Bson, ) -> Result { - let t = match (operator, operands) { - ("$split", _) => { - TypeConstraint::ArrayOf(Box::new(TypeConstraint::Scalar(BsonScalarType::String))) + // NOTE: It is important to run inference on `operand` in every match arm even if we don't read + // the result because we need to check for uses of parameters. + let t = match operator { + // technically $abs returns the same *numeric* type as its input, and fails on other types + "$abs" => infer_type_from_aggregation_expression( + context, + desired_object_type_name, + type_hint.or(Some(&C::numeric())), + operand, + )?, + "$sin" | "$cos" | "$tan" | "$asin" | "$acos" | "$atan" | "$asinh" | "$acosh" | "$atanh" + | "$sinh" | "$cosh" | "$tanh" => { + type_for_trig_operator(infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::numeric()), + operand, + )?) 
+ } + "$add" | "$divide" | "$multiply" | "$subtract" => homogeneous_binary_operator_operand_type( + context, + desired_object_type_name, + Some(C::numeric()), + operator, + operand, + )?, + "$and" | "$or" => { + infer_types_from_aggregation_expression_tuple( + context, + desired_object_type_name, + None, + operand, + )?; + C::Scalar(BsonScalarType::Bool) + } + "$not" => { + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::Scalar(BsonScalarType::Bool)), + operand, + )?; + C::Scalar(BsonScalarType::Bool) + } + "$eq" | "$ne" => { + homogeneous_binary_operator_operand_type( + context, + desired_object_type_name, + None, + operator, + operand, + )?; + C::Scalar(BsonScalarType::Bool) + } + "$gt" | "$gte" | "$lt" | "$lte" => { + homogeneous_binary_operator_operand_type( + context, + desired_object_type_name, + Some(C::comparable()), + operator, + operand, + )?; + C::Scalar(BsonScalarType::Bool) + } + "$allElementsTrue" => { + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::ArrayOf(Box::new(C::Scalar(BsonScalarType::Bool)))), + operand, + )?; + C::Scalar(BsonScalarType::Bool) + } + "$anyElementTrue" => { + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::ArrayOf(Box::new(C::Scalar(BsonScalarType::Bool)))), + operand, + )?; + C::Scalar(BsonScalarType::Bool) } - (op, _) => Err(Error::UnknownAggregationOperator(op.to_string()))?, + "$arrayElemAt" => { + let (array_ref, idx) = two_parameter_operand(operator, operand)?; + let array_type = infer_type_from_aggregation_expression( + context, + &format!("{desired_object_type_name}_arrayElemAt_array"), + type_hint.map(|t| C::ArrayOf(Box::new(t.clone()))).as_ref(), + array_ref, + )?; + infer_type_from_aggregation_expression( + context, + &format!("{desired_object_type_name}_arrayElemAt_idx"), + Some(&C::Scalar(BsonScalarType::Int)), + idx, + )?; + type_hint + .cloned() + .unwrap_or_else(|| 
C::ElementOf(Box::new(array_type))) + .make_nullable() + } + "$split" => { + infer_types_from_aggregation_expression_tuple( + context, + desired_object_type_name, + Some(&C::Scalar(BsonScalarType::String)), + operand, + )?; + C::ArrayOf(Box::new(C::Scalar(BsonScalarType::String))) + } + op => Err(Error::UnknownAggregationOperator(op.to_string()))?, }; Ok(t) } +fn two_parameter_operand(operator: &str, operand: Bson) -> Result<(Bson, Bson)> { + match operand { + Bson::Array(operands) => { + if operands.len() != 2 { + return Err(Error::Other(format!( + "argument to {operator} must be a two-element array" + ))); + } + let mut operands = operands.into_iter(); + let a = operands.next().unwrap(); + let b = operands.next().unwrap(); + Ok((a, b)) + } + other_bson => Err(Error::ExpectedArrayExpressionArgument { + actual_argument: other_bson, + })?, + } +} + +fn homogeneous_binary_operator_operand_type( + context: &mut PipelineTypeContext<'_>, + desired_object_type_name: &str, + operand_type_hint: Option, + operator: &str, + operand: Bson, +) -> Result { + let (a, b) = two_parameter_operand(operator, operand)?; + let variable = context.new_type_variable(Variance::Invariant, operand_type_hint); + let type_a = infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::Variable(variable)), + a, + )?; + let type_b = infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&C::Variable(variable)), + b, + )?; + for t in [type_a, type_b] { + // Avoid cycles of type variable references + if !context.constraint_references_variable(&t, variable) { + context.set_type_variable_constraint(variable, t); + } + } + Ok(C::Variable(variable)) +} + +pub fn type_for_trig_operator(operand_type: TypeConstraint) -> TypeConstraint { + operand_type.map_nullable(|t| match t { + t @ C::Scalar(BsonScalarType::Decimal) => t, + _ => C::Scalar(BsonScalarType::Double), + }) +} + /// This is a document that is not evaluated as a plain value, not as 
an aggregation expression. fn infer_type_from_document( context: &mut PipelineTypeContext<'_>, @@ -100,18 +317,23 @@ fn infer_type_from_document( .into_iter() .map(|(field_name, bson)| { let field_object_type_name = format!("{desired_object_type_name}_{field_name}"); - let object_field_type = - infer_type_from_aggregation_expression(context, &field_object_type_name, bson)?; + let object_field_type = infer_type_from_aggregation_expression( + context, + &field_object_type_name, + None, + bson, + )?; Ok((field_name.into(), object_field_type)) }) .collect::>>()?; let object_type = ObjectTypeConstraint { fields }; context.insert_object_type(object_type_name.clone(), object_type); - Ok(TypeConstraint::Object(object_type_name)) + Ok(C::Object(object_type_name)) } pub fn infer_type_from_reference_shorthand( context: &mut PipelineTypeContext<'_>, + type_hint: Option<&TypeConstraint>, input: &str, ) -> Result { let reference = parse_reference_shorthand(input)?; @@ -121,17 +343,16 @@ pub fn infer_type_from_reference_shorthand( type_annotation: _, } => { // TODO: read type annotation ENG-1249 - // TODO: set constraint based on expected type here like we do in match_stage.rs NDC-1251 - context.register_parameter(name.into(), []) + context.register_parameter(name.into(), type_hint.into_iter().cloned()) } - Reference::PipelineVariable { .. } => todo!(), + Reference::PipelineVariable { .. 
} => todo!("pipeline variable"), Reference::InputDocumentField { name, nested_path } => { let doc_type = context.get_input_document_type()?; let path = NonEmpty { head: name, tail: nested_path, }; - TypeConstraint::FieldOf { + C::FieldOf { target_type: Box::new(doc_type.clone()), path, } @@ -140,13 +361,57 @@ pub fn infer_type_from_reference_shorthand( native_query_variables, } => { for variable in native_query_variables { - context.register_parameter( - variable.into(), - [TypeConstraint::Scalar(BsonScalarType::String)], - ); + context.register_parameter(variable.into(), [C::Scalar(BsonScalarType::String)]); } - TypeConstraint::Scalar(BsonScalarType::String) + C::Scalar(BsonScalarType::String) } }; Ok(t) } + +#[cfg(test)] +mod tests { + use googletest::prelude::*; + use mongodb::bson::bson; + use mongodb_support::BsonScalarType; + use test_helpers::configuration::mflix_config; + + use crate::native_query::{ + pipeline_type_context::PipelineTypeContext, + type_constraint::{TypeConstraint, TypeVariable, Variance}, + }; + + use super::infer_type_from_operator_expression; + + use TypeConstraint as C; + + #[googletest::test] + fn infers_constrants_on_equality() -> Result<()> { + let config = mflix_config(); + let mut context = PipelineTypeContext::new(&config, None); + + let (var0, var1) = ( + TypeVariable::new(0, Variance::Invariant), + TypeVariable::new(1, Variance::Contravariant), + ); + + infer_type_from_operator_expression( + &mut context, + "test", + None, + "$eq", + bson!(["{{ parameter }}", 1]), + )?; + + expect_eq!( + context.type_variables(), + &[ + (var0, [C::Scalar(BsonScalarType::Int)].into()), + (var1, [C::Variable(var0)].into()) + ] + .into() + ); + + Ok(()) + } +} diff --git a/crates/cli/src/native_query/error.rs b/crates/cli/src/native_query/error.rs index 40c26217..5398993a 100644 --- a/crates/cli/src/native_query/error.rs +++ b/crates/cli/src/native_query/error.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, HashMap, HashSet}; +use 
std::collections::{BTreeMap, BTreeSet, HashMap}; use configuration::schema::Type; use mongodb::bson::{self, Bson, Document}; @@ -25,6 +25,9 @@ pub enum Error { #[error("Expected an array type, but got: {actual_type:?}")] ExpectedArray { actual_type: Type }, + #[error("Expected an array, but got: {actual_argument}")] + ExpectedArrayExpressionArgument { actual_argument: Bson }, + #[error("Expected an object type, but got: {actual_type:?}")] ExpectedObject { actual_type: Type }, @@ -68,20 +71,20 @@ pub enum Error { could_not_infer_return_type: bool, // These fields are included here for internal debugging - type_variables: HashMap>, + type_variables: HashMap>, object_type_constraints: BTreeMap, }, #[error("Error parsing a string in the aggregation pipeline: {0}")] UnableToParseReferenceShorthand(String), - #[error("Unknown match document operator: {0}")] + #[error("Type inference is not currently implemented for the query document operator, {0}. Please file a bug report, and declare types for your native query by hand for the time being.")] UnknownMatchDocumentOperator(String), - #[error("Unknown aggregation operator: {0}")] + #[error("Type inference is not currently implemented for the aggregation expression operator, {0}. Please file a bug report, and declare types for your native query by hand for the time being.")] UnknownAggregationOperator(String), - #[error("Type inference is not currently implemented for stage {stage_index} in the aggregation pipeline. Please file a bug report, and declare types for your native query by hand.\n\n{stage}")] + #[error("Type inference is not currently implemented for {stage}, stage number {} in your aggregation pipeline. 
Please file a bug report, and declare types for your native query by hand for the time being.", stage_index + 1)] UnknownAggregationStage { stage_index: usize, stage: bson::Document, @@ -92,6 +95,12 @@ pub enum Error { #[error("Unknown object type, \"{0}\"")] UnknownObjectType(String), + + #[error("{0}")] + Other(String), + + #[error("Errors processing pipeline:\n\n{}", multiple_errors(.0))] + Multiple(Vec), } fn unable_to_infer_types_message( @@ -116,3 +125,11 @@ fn unable_to_infer_types_message( } message } + +fn multiple_errors(errors: &[Error]) -> String { + let mut output = String::new(); + for error in errors { + output += &format!("- {}\n", error); + } + output +} diff --git a/crates/cli/src/native_query/helpers.rs b/crates/cli/src/native_query/helpers.rs index 3a2d10c0..d39ff44e 100644 --- a/crates/cli/src/native_query/helpers.rs +++ b/crates/cli/src/native_query/helpers.rs @@ -1,7 +1,8 @@ use std::{borrow::Cow, collections::BTreeMap}; use configuration::Configuration; -use ndc_models::{CollectionInfo, CollectionName, ObjectTypeName}; +use ndc_models::{CollectionInfo, CollectionName, FieldName, ObjectTypeName}; +use nonempty::NonEmpty; use regex::Regex; use super::error::{Error, Result}; @@ -56,3 +57,38 @@ pub fn parse_counter_suffix(name: &str) -> (Cow<'_, str>, u32) { }; (Cow::Owned(prefix.to_string()), count) } + +pub fn get_object_field_type<'a>( + object_types: &'a BTreeMap, + object_type_name: &ObjectTypeName, + object_type: &'a ndc_models::ObjectType, + path: NonEmpty, +) -> Result<&'a ndc_models::Type> { + let field_name = path.head; + let rest = NonEmpty::from_vec(path.tail); + + let field = object_type + .fields + .get(&field_name) + .ok_or_else(|| Error::ObjectMissingField { + object_type: object_type_name.clone(), + field_name: field_name.clone(), + })?; + + match rest { + None => Ok(&field.r#type), + Some(rest) => match &field.r#type { + ndc_models::Type::Named { name } => { + let type_name: ObjectTypeName = name.clone().into(); + let 
inner_object_type = object_types + .get(&type_name) + .ok_or_else(|| Error::UnknownObjectType(type_name.to_string()))?; + get_object_field_type(object_types, &type_name, inner_object_type, rest) + } + _ => Err(Error::ObjectMissingField { + object_type: object_type_name.clone(), + field_name: field_name.clone(), + }), + }, + } +} diff --git a/crates/cli/src/native_query/mod.rs b/crates/cli/src/native_query/mod.rs index 0616c6a2..2ddac4c5 100644 --- a/crates/cli/src/native_query/mod.rs +++ b/crates/cli/src/native_query/mod.rs @@ -8,6 +8,9 @@ mod reference_shorthand; mod type_constraint; mod type_solver; +#[cfg(test)] +mod tests; + use std::path::{Path, PathBuf}; use std::process::exit; @@ -176,171 +179,3 @@ pub fn native_query_from_pipeline( description: None, }) } - -#[cfg(test)] -mod tests { - use anyhow::Result; - use configuration::{ - native_query::NativeQueryRepresentation::Collection, - read_directory, - schema::{ObjectField, ObjectType, Type}, - serialized::NativeQuery, - Configuration, - }; - use googletest::prelude::*; - use mongodb::bson::doc; - use mongodb_support::{ - aggregate::{Accumulator, Pipeline, Selection, Stage}, - BsonScalarType, - }; - use ndc_models::ObjectTypeName; - use pretty_assertions::assert_eq; - use test_helpers::configuration::mflix_config; - - use super::native_query_from_pipeline; - - #[tokio::test] - async fn infers_native_query_from_pipeline() -> Result<()> { - let config = read_configuration().await?; - let pipeline = Pipeline::new(vec![Stage::Documents(vec![ - doc! { "foo": 1 }, - doc! 
{ "bar": 2 }, - ])]); - let native_query = native_query_from_pipeline( - &config, - "selected_title", - Some("movies".into()), - pipeline.clone(), - )?; - - let expected_document_type_name: ObjectTypeName = "selected_title_documents".into(); - - let expected_object_types = [( - expected_document_type_name.clone(), - ObjectType { - fields: [ - ( - "foo".into(), - ObjectField { - r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))), - description: None, - }, - ), - ( - "bar".into(), - ObjectField { - r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))), - description: None, - }, - ), - ] - .into(), - description: None, - }, - )] - .into(); - - let expected = NativeQuery { - representation: Collection, - input_collection: Some("movies".into()), - arguments: Default::default(), - result_document_type: expected_document_type_name, - object_types: expected_object_types, - pipeline: pipeline.into(), - description: None, - }; - - assert_eq!(native_query, expected); - Ok(()) - } - - #[tokio::test] - async fn infers_native_query_from_non_trivial_pipeline() -> Result<()> { - let config = read_configuration().await?; - let pipeline = Pipeline::new(vec![ - Stage::ReplaceWith(Selection::new(doc! 
{ - "title": "$title", - "title_words": { "$split": ["$title", " "] } - })), - Stage::Unwind { - path: "$title_words".to_string(), - include_array_index: None, - preserve_null_and_empty_arrays: None, - }, - Stage::Group { - key_expression: "$title_words".into(), - accumulators: [("title_count".into(), Accumulator::Count)].into(), - }, - ]); - let native_query = native_query_from_pipeline( - &config, - "title_word_frequency", - Some("movies".into()), - pipeline.clone(), - )?; - - assert_eq!(native_query.input_collection, Some("movies".into())); - assert!(native_query - .result_document_type - .to_string() - .starts_with("title_word_frequency")); - assert_eq!( - native_query - .object_types - .get(&native_query.result_document_type), - Some(&ObjectType { - fields: [ - ( - "_id".into(), - ObjectField { - r#type: Type::Scalar(BsonScalarType::String), - description: None, - }, - ), - ( - "title_count".into(), - ObjectField { - r#type: Type::Scalar(BsonScalarType::Int), - description: None, - }, - ), - ] - .into(), - description: None, - }) - ); - Ok(()) - } - - #[googletest::test] - fn infers_native_query_from_pipeline_with_unannotated_parameter() -> googletest::Result<()> { - let config = mflix_config(); - - let pipeline = Pipeline::new(vec![Stage::Match(doc! 
{ - "title": { "$eq": "{{ title }}" }, - })]); - - let native_query = native_query_from_pipeline( - &config, - "movies_by_title", - Some("movies".into()), - pipeline, - )?; - - expect_that!( - native_query.arguments, - unordered_elements_are![( - displays_as(eq("title")), - field!( - ObjectField.r#type, - eq(&Type::Scalar(BsonScalarType::String)) - ) - )] - ); - Ok(()) - } - - async fn read_configuration() -> Result { - read_directory("../../fixtures/hasura/sample_mflix/connector").await - } -} diff --git a/crates/cli/src/native_query/pipeline/match_stage.rs b/crates/cli/src/native_query/pipeline/match_stage.rs index 8246ad4b..18165fdf 100644 --- a/crates/cli/src/native_query/pipeline/match_stage.rs +++ b/crates/cli/src/native_query/pipeline/match_stage.rs @@ -1,4 +1,5 @@ use mongodb::bson::{Bson, Document}; +use mongodb_support::BsonScalarType; use nonempty::nonempty; use crate::native_query::{ @@ -16,7 +17,13 @@ pub fn check_match_doc_for_parameters( ) -> Result<()> { let input_document_type = context.get_input_document_type()?; if let Some(expression) = match_doc.remove("$expr") { - infer_type_from_aggregation_expression(context, desired_object_type_name, expression)?; + let type_hint = TypeConstraint::Scalar(BsonScalarType::Bool); + infer_type_from_aggregation_expression( + context, + desired_object_type_name, + Some(&type_hint), + expression, + )?; Ok(()) } else { check_match_doc_for_parameters_helper( diff --git a/crates/cli/src/native_query/pipeline/mod.rs b/crates/cli/src/native_query/pipeline/mod.rs index 144289b7..fad8853b 100644 --- a/crates/cli/src/native_query/pipeline/mod.rs +++ b/crates/cli/src/native_query/pipeline/mod.rs @@ -13,6 +13,7 @@ use ndc_models::{CollectionName, FieldName, ObjectTypeName}; use super::{ aggregation_expression::{ self, infer_type_from_aggregation_expression, infer_type_from_reference_shorthand, + type_for_trig_operator, }, error::{Error, Result}, helpers::find_collection_object_type, @@ -75,6 +76,7 @@ fn 
infer_stage_output_type( infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_documents"), + None, doc.into(), ) }) @@ -114,6 +116,7 @@ fn infer_stage_output_type( aggregation_expression::infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_replaceWith"), + None, selection.clone().into(), )?, ) @@ -152,6 +155,7 @@ fn infer_type_from_group_stage( let group_key_expression_type = infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_id"), + None, key_expression.clone(), )?; @@ -164,17 +168,20 @@ fn infer_type_from_group_stage( Accumulator::Min(expr) => infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_min"), + None, expr.clone(), )?, Accumulator::Max(expr) => infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_min"), + None, expr.clone(), )?, Accumulator::Push(expr) => { let t = infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_push"), + None, expr.clone(), )?; TypeConstraint::ArrayOf(Box::new(t)) @@ -183,28 +190,17 @@ fn infer_type_from_group_stage( let t = infer_type_from_aggregation_expression( context, &format!("{desired_object_type_name}_avg"), + Some(&TypeConstraint::numeric()), expr.clone(), )?; - match t { - TypeConstraint::ExtendedJSON => t, - TypeConstraint::Scalar(scalar_type) if scalar_type.is_numeric() => t, - _ => TypeConstraint::Nullable(Box::new(TypeConstraint::Scalar( - BsonScalarType::Int, - ))), - } - } - Accumulator::Sum(expr) => { - let t = infer_type_from_aggregation_expression( - context, - &format!("{desired_object_type_name}_push"), - expr.clone(), - )?; - match t { - TypeConstraint::ExtendedJSON => t, - TypeConstraint::Scalar(scalar_type) if scalar_type.is_numeric() => t, - _ => TypeConstraint::Scalar(BsonScalarType::Int), - } + type_for_trig_operator(t).make_nullable() } + Accumulator::Sum(expr) => infer_type_from_aggregation_expression( + 
context, + &format!("{desired_object_type_name}_push"), + Some(&TypeConstraint::numeric()), + expr.clone(), + )?, }; Ok::<_, Error>((key.clone().into(), accumulator_type)) }); @@ -229,7 +225,7 @@ fn infer_type_from_unwind_stage( let Reference::InputDocumentField { name, nested_path } = field_to_unwind else { return Err(Error::ExpectedStringPath(path.into())); }; - let field_type = infer_type_from_reference_shorthand(context, path)?; + let field_type = infer_type_from_reference_shorthand(context, None, path)?; let mut unwind_stage_object_type = ObjectTypeConstraint { fields: Default::default(), diff --git a/crates/cli/src/native_query/pipeline_type_context.rs b/crates/cli/src/native_query/pipeline_type_context.rs index 3f8e3ae0..56fe56a3 100644 --- a/crates/cli/src/native_query/pipeline_type_context.rs +++ b/crates/cli/src/native_query/pipeline_type_context.rs @@ -2,7 +2,7 @@ use std::{ borrow::Cow, - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap}, }; use configuration::{ @@ -43,7 +43,7 @@ pub struct PipelineTypeContext<'a> { /// to a type here, or in [self.configuration.object_types] object_types: BTreeMap, - type_variables: HashMap>, + type_variables: HashMap>, next_type_variable: u32, warnings: Vec, @@ -71,6 +71,11 @@ impl PipelineTypeContext<'_> { context } + #[cfg(test)] + pub fn type_variables(&self) -> &HashMap> { + &self.type_variables + } + pub fn into_types(self) -> Result { let result_document_type_variable = self.input_doc_type.ok_or(Error::IncompletePipeline)?; let required_type_variables = self @@ -80,6 +85,15 @@ impl PipelineTypeContext<'_> { .chain([result_document_type_variable]) .collect_vec(); + #[cfg(test)] + { + println!("variable mappings:"); + for (parameter, variable) in self.parameter_types.iter() { + println!(" {variable}: {parameter}"); + } + println!(" {result_document_type_variable}: result type\n"); + } + let mut object_type_constraints = self.object_types; let (variable_types, 
added_object_types) = unify( self.configuration, @@ -177,6 +191,60 @@ impl PipelineTypeContext<'_> { entry.insert(constraint); } + pub fn constraint_references_variable( + &self, + constraint: &TypeConstraint, + variable: TypeVariable, + ) -> bool { + let object_constraint_references_variable = |name: &ObjectTypeName| -> bool { + if let Some(object_type) = self.object_types.get(name) { + object_type.fields.iter().any(|(_, field_type)| { + self.constraint_references_variable(field_type, variable) + }) + } else { + false + } + }; + + match constraint { + TypeConstraint::ExtendedJSON => false, + TypeConstraint::Scalar(_) => false, + TypeConstraint::Object(name) => object_constraint_references_variable(name), + TypeConstraint::ArrayOf(t) => self.constraint_references_variable(t, variable), + TypeConstraint::Predicate { object_type_name } => { + object_constraint_references_variable(object_type_name) + } + TypeConstraint::Union(ts) => ts + .iter() + .any(|t| self.constraint_references_variable(t, variable)), + TypeConstraint::OneOf(ts) => ts + .iter() + .any(|t| self.constraint_references_variable(t, variable)), + TypeConstraint::Variable(v2) if *v2 == variable => true, + TypeConstraint::Variable(v2) => { + let constraints = self.type_variables.get(v2); + constraints + .iter() + .flat_map(|m| *m) + .any(|t| self.constraint_references_variable(t, variable)) + } + TypeConstraint::ElementOf(t) => self.constraint_references_variable(t, variable), + TypeConstraint::FieldOf { target_type, .. } => { + self.constraint_references_variable(target_type, variable) + } + TypeConstraint::WithFieldOverrides { + target_type, + fields, + .. 
+ } => { + self.constraint_references_variable(target_type, variable) + || fields + .iter() + .any(|(_, t)| self.constraint_references_variable(t, variable)) + } + } + } + pub fn insert_object_type(&mut self, name: ObjectTypeName, object_type: ObjectTypeConstraint) { self.object_types.insert(name, object_type); } diff --git a/crates/cli/src/native_query/tests.rs b/crates/cli/src/native_query/tests.rs new file mode 100644 index 00000000..64540811 --- /dev/null +++ b/crates/cli/src/native_query/tests.rs @@ -0,0 +1,274 @@ +use std::collections::BTreeMap; + +use anyhow::Result; +use configuration::{ + native_query::NativeQueryRepresentation::Collection, + read_directory, + schema::{ObjectField, ObjectType, Type}, + serialized::NativeQuery, + Configuration, +}; +use googletest::prelude::*; +use mongodb::bson::doc; +use mongodb_support::{ + aggregate::{Accumulator, Pipeline, Selection, Stage}, + BsonScalarType, +}; +use ndc_models::ObjectTypeName; +use pretty_assertions::assert_eq; +use test_helpers::configuration::mflix_config; + +use super::native_query_from_pipeline; + +#[tokio::test] +async fn infers_native_query_from_pipeline() -> Result<()> { + let config = read_configuration().await?; + let pipeline = Pipeline::new(vec![Stage::Documents(vec![ + doc! { "foo": 1 }, + doc! 
{ "bar": 2 }, + ])]); + let native_query = native_query_from_pipeline( + &config, + "selected_title", + Some("movies".into()), + pipeline.clone(), + )?; + + let expected_document_type_name: ObjectTypeName = "selected_title_documents".into(); + + let expected_object_types = [( + expected_document_type_name.clone(), + ObjectType { + fields: [ + ( + "foo".into(), + ObjectField { + r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))), + description: None, + }, + ), + ( + "bar".into(), + ObjectField { + r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))), + description: None, + }, + ), + ] + .into(), + description: None, + }, + )] + .into(); + + let expected = NativeQuery { + representation: Collection, + input_collection: Some("movies".into()), + arguments: Default::default(), + result_document_type: expected_document_type_name, + object_types: expected_object_types, + pipeline: pipeline.into(), + description: None, + }; + + assert_eq!(native_query, expected); + Ok(()) +} + +#[tokio::test] +async fn infers_native_query_from_non_trivial_pipeline() -> Result<()> { + let config = read_configuration().await?; + let pipeline = Pipeline::new(vec![ + Stage::ReplaceWith(Selection::new(doc! 
{ + "title": "$title", + "title_words": { "$split": ["$title", " "] } + })), + Stage::Unwind { + path: "$title_words".to_string(), + include_array_index: None, + preserve_null_and_empty_arrays: None, + }, + Stage::Group { + key_expression: "$title_words".into(), + accumulators: [("title_count".into(), Accumulator::Count)].into(), + }, + ]); + let native_query = native_query_from_pipeline( + &config, + "title_word_frequency", + Some("movies".into()), + pipeline.clone(), + )?; + + assert_eq!(native_query.input_collection, Some("movies".into())); + assert!(native_query + .result_document_type + .to_string() + .starts_with("title_word_frequency")); + assert_eq!( + native_query + .object_types + .get(&native_query.result_document_type), + Some(&ObjectType { + fields: [ + ( + "_id".into(), + ObjectField { + r#type: Type::Scalar(BsonScalarType::String), + description: None, + }, + ), + ( + "title_count".into(), + ObjectField { + r#type: Type::Scalar(BsonScalarType::Int), + description: None, + }, + ), + ] + .into(), + description: None, + }) + ); + Ok(()) +} + +#[googletest::test] +fn infers_native_query_from_pipeline_with_unannotated_parameter() -> googletest::Result<()> { + let config = mflix_config(); + + let pipeline = Pipeline::new(vec![Stage::Match(doc! { + "title": { "$eq": "{{ title }}" }, + })]); + + let native_query = + native_query_from_pipeline(&config, "movies_by_title", Some("movies".into()), pipeline)?; + + expect_that!( + native_query.arguments, + unordered_elements_are![( + displays_as(eq("title")), + field!( + ObjectField.r#type, + eq(&Type::Scalar(BsonScalarType::String)) + ) + )] + ); + Ok(()) +} + +#[googletest::test] +fn infers_parameter_type_from_binary_comparison() -> googletest::Result<()> { + let config = mflix_config(); + + let pipeline = Pipeline::new(vec![Stage::Match(doc! 
{ + "$expr": { "$eq": ["{{ title }}", "$title"] } + })]); + + let native_query = + native_query_from_pipeline(&config, "movies_by_title", Some("movies".into()), pipeline)?; + + expect_that!( + native_query.arguments, + unordered_elements_are![( + displays_as(eq("title")), + field!( + ObjectField.r#type, + eq(&Type::Scalar(BsonScalarType::String)) + ) + )] + ); + Ok(()) +} + +#[googletest::test] +fn supports_various_aggregation_operators() -> googletest::Result<()> { + let config = mflix_config(); + + let pipeline = Pipeline::new(vec![ + Stage::Match(doc! { + "$expr": { + "$and": [ + { "$eq": ["{{ title }}", "$title"] }, + { "$or": [null, 1] }, + { "$not": "{{ bool_param }}" }, + { "$gt": ["$imdb.votes", "{{ votes }}"] }, + ] + } + }), + Stage::ReplaceWith(Selection::new(doc! { + "abs": { "$abs": "$year" }, + "add": { "$add": ["$tomatoes.viewer.rating", "{{ rating_inc }}"] }, + "divide": { "$divide": ["$tomatoes.viewer.rating", "{{ rating_div }}"] }, + "multiply": { "$multiply": ["$tomatoes.viewer.rating", "{{ rating_mult }}"] }, + "subtract": { "$subtract": ["$tomatoes.viewer.rating", "{{ rating_sub }}"] }, + "arrayElemAt": { "$arrayElemAt": ["$genres", "{{ idx }}"] }, + "title_words": { "$split": ["$title", " "] } + })), + ]); + + let native_query = + native_query_from_pipeline(&config, "operators_test", Some("movies".into()), pipeline)?; + + expect_eq!( + native_query.arguments, + object_fields([ + ("title", Type::Scalar(BsonScalarType::String)), + ("bool_param", Type::Scalar(BsonScalarType::Bool)), + ("votes", Type::Scalar(BsonScalarType::Int)), + ("rating_inc", Type::Scalar(BsonScalarType::Double)), + ("rating_div", Type::Scalar(BsonScalarType::Double)), + ("rating_mult", Type::Scalar(BsonScalarType::Double)), + ("rating_sub", Type::Scalar(BsonScalarType::Double)), + ("idx", Type::Scalar(BsonScalarType::Int)), + ]) + ); + + let result_type = native_query.result_document_type; + expect_eq!( + native_query.object_types[&result_type], + ObjectType { + fields: 
object_fields([ + ("abs", Type::Scalar(BsonScalarType::Int)), + ("add", Type::Scalar(BsonScalarType::Double)), + ("divide", Type::Scalar(BsonScalarType::Double)), + ("multiply", Type::Scalar(BsonScalarType::Double)), + ("subtract", Type::Scalar(BsonScalarType::Double)), + ( + "arrayElemAt", + Type::Nullable(Box::new(Type::Scalar(BsonScalarType::String))) + ), + ( + "title_words", + Type::ArrayOf(Box::new(Type::Scalar(BsonScalarType::String))) + ), + ]), + description: None, + } + ); + + Ok(()) +} + +fn object_fields(types: impl IntoIterator) -> BTreeMap +where + S: Into, + K: Ord, +{ + types + .into_iter() + .map(|(name, r#type)| { + ( + name.into(), + ObjectField { + r#type, + description: None, + }, + ) + }) + .collect() +} + +async fn read_configuration() -> Result { + read_directory("../../fixtures/hasura/sample_mflix/connector").await +} diff --git a/crates/cli/src/native_query/type_constraint.rs b/crates/cli/src/native_query/type_constraint.rs index d4ab667c..67c04156 100644 --- a/crates/cli/src/native_query/type_constraint.rs +++ b/crates/cli/src/native_query/type_constraint.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use configuration::MongoScalarType; use mongodb_support::BsonScalarType; @@ -6,7 +6,7 @@ use ndc_models::{FieldName, ObjectTypeName}; use nonempty::NonEmpty; use ref_cast::RefCast as _; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct TypeVariable { id: u32, pub variance: Variance, @@ -16,28 +16,57 @@ impl TypeVariable { pub fn new(id: u32, variance: Variance) -> Self { TypeVariable { id, variance } } + + pub fn is_covariant(self) -> bool { + matches!(self.variance, Variance::Covariant) + } + + pub fn is_contravariant(self) -> bool { + matches!(self.variance, Variance::Contravariant) + } +} + +impl std::fmt::Display for TypeVariable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
write!(f, "${}", self.id) + } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Variance { Covariant, Contravariant, + Invariant, } /// A TypeConstraint is almost identical to a [configuration::schema::Type], except that /// a TypeConstraint may reference type variables. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum TypeConstraint { // Normal type stuff - except that composite types might include variables in their structure. ExtendedJSON, Scalar(BsonScalarType), Object(ObjectTypeName), ArrayOf(Box), - Nullable(Box), Predicate { object_type_name: ObjectTypeName, }, + // Complex types + + Union(BTreeSet), + + /// Unlike Union we expect the solved concrete type for a variable with a OneOf constraint may + /// be one of the types in the set, but we don't know yet which one. This is useful for MongoDB + /// operators that expect an input of any numeric type. We use OneOf because we don't know + /// which numeric type to infer until we see more usage evidence of the same type variable. + /// + /// In other words with Union we have specific evidence that a variable occurs in contexts of + /// multiple concrete types, while with OneOf we **don't** have specific evidence that the + /// variable takes multiple types, but there are multiple possibilities of the type or types + /// that it does take. + OneOf(BTreeSet), + /// Indicates a type that is the same as the type of the given variable. Variable(TypeVariable), @@ -58,20 +87,30 @@ pub enum TypeConstraint { target_type: Box, fields: BTreeMap, }, - // TODO: Add Non-nullable constraint? 
} impl TypeConstraint { /// Order constraints by complexity to help with type unification pub fn complexity(&self) -> usize { match self { - TypeConstraint::Variable(_) => 0, + TypeConstraint::Variable(_) => 2, TypeConstraint::ExtendedJSON => 0, TypeConstraint::Scalar(_) => 0, TypeConstraint::Object(_) => 1, TypeConstraint::Predicate { .. } => 1, TypeConstraint::ArrayOf(constraint) => 1 + constraint.complexity(), - TypeConstraint::Nullable(constraint) => 1 + constraint.complexity(), + TypeConstraint::Union(constraints) => { + 1 + constraints + .iter() + .map(TypeConstraint::complexity) + .sum::() + } + TypeConstraint::OneOf(constraints) => { + 1 + constraints + .iter() + .map(TypeConstraint::complexity) + .sum::() + } TypeConstraint::ElementOf(constraint) => 2 + constraint.complexity(), TypeConstraint::FieldOf { target_type, path } => { 2 + target_type.complexity() + path.len() @@ -93,11 +132,84 @@ impl TypeConstraint { pub fn make_nullable(self) -> Self { match self { TypeConstraint::ExtendedJSON => TypeConstraint::ExtendedJSON, - TypeConstraint::Nullable(t) => TypeConstraint::Nullable(t), - TypeConstraint::Scalar(BsonScalarType::Null) => { - TypeConstraint::Scalar(BsonScalarType::Null) + t @ TypeConstraint::Scalar(BsonScalarType::Null) => t, + t => TypeConstraint::union(t, TypeConstraint::Scalar(BsonScalarType::Null)), + } + } + + pub fn null() -> Self { + TypeConstraint::Scalar(BsonScalarType::Null) + } + + pub fn is_nullable(&self) -> bool { + match self { + TypeConstraint::Union(types) => types + .iter() + .any(|t| matches!(t, TypeConstraint::Scalar(BsonScalarType::Null))), + _ => false, + } + } + + pub fn map_nullable(self, callback: F) -> TypeConstraint + where + F: FnOnce(TypeConstraint) -> TypeConstraint, + { + match self { + Self::Union(types) => { + let non_null_types: BTreeSet<_> = + types.into_iter().filter(|t| t != &Self::null()).collect(); + let single_non_null_type = if non_null_types.len() == 1 { + non_null_types.into_iter().next().unwrap() + } 
else { + Self::Union(non_null_types) + }; + let mapped = callback(single_non_null_type); + Self::union(mapped, Self::null()) } - t => TypeConstraint::Nullable(Box::new(t)), + t => callback(t), + } + } + + fn scalar_one_of_by_predicate(f: impl Fn(BsonScalarType) -> bool) -> TypeConstraint { + let matching_types = enum_iterator::all::() + .filter(|t| f(*t)) + .map(TypeConstraint::Scalar) + .collect(); + TypeConstraint::OneOf(matching_types) + } + + pub fn comparable() -> TypeConstraint { + Self::scalar_one_of_by_predicate(BsonScalarType::is_comparable) + } + + pub fn numeric() -> TypeConstraint { + Self::scalar_one_of_by_predicate(BsonScalarType::is_numeric) + } + + pub fn is_numeric(&self) -> bool { + match self { + TypeConstraint::Scalar(scalar_type) => BsonScalarType::is_numeric(*scalar_type), + TypeConstraint::OneOf(types) => types.iter().all(|t| t.is_numeric()), + TypeConstraint::Union(types) => types.iter().all(|t| t.is_numeric()), + _ => false, + } + } + + pub fn union(a: TypeConstraint, b: TypeConstraint) -> Self { + match (a, b) { + (TypeConstraint::Union(mut types_a), TypeConstraint::Union(mut types_b)) => { + types_a.append(&mut types_b); + TypeConstraint::Union(types_a) + } + (TypeConstraint::Union(mut types), b) => { + types.insert(b); + TypeConstraint::Union(types) + } + (a, TypeConstraint::Union(mut types)) => { + types.insert(a); + TypeConstraint::Union(types) + } + (a, b) => TypeConstraint::Union([a, b].into()), } } } @@ -114,7 +226,7 @@ impl From for TypeConstraint { } } ndc_models::Type::Nullable { underlying_type } => { - TypeConstraint::Nullable(Box::new(Self::from(*underlying_type))) + Self::from(*underlying_type).make_nullable() } ndc_models::Type::Array { element_type } => { TypeConstraint::ArrayOf(Box::new(Self::from(*element_type))) @@ -126,14 +238,28 @@ impl From for TypeConstraint { } } -// /// Order constraints by complexity to help with type unification -// impl PartialOrd for TypeConstraint { -// fn partial_cmp(&self, other: &Self) -> 
Option { -// let a = self.complexity(); -// let b = other.complexity(); -// a.partial_cmp(&b) -// } -// } +impl From for TypeConstraint { + fn from(t: configuration::schema::Type) -> Self { + match t { + configuration::schema::Type::ExtendedJSON => TypeConstraint::ExtendedJSON, + configuration::schema::Type::Scalar(s) => TypeConstraint::Scalar(s), + configuration::schema::Type::Object(name) => TypeConstraint::Object(name.into()), + configuration::schema::Type::ArrayOf(t) => { + TypeConstraint::ArrayOf(Box::new(TypeConstraint::from(*t))) + } + configuration::schema::Type::Nullable(t) => TypeConstraint::from(*t).make_nullable(), + configuration::schema::Type::Predicate { object_type_name } => { + TypeConstraint::Predicate { object_type_name } + } + } + } +} + +impl From<&configuration::schema::Type> for TypeConstraint { + fn from(t: &configuration::schema::Type) -> Self { + t.clone().into() + } +} #[derive(Debug, Clone, PartialEq, Eq)] pub struct ObjectTypeConstraint { diff --git a/crates/cli/src/native_query/type_solver/constraint_to_type.rs b/crates/cli/src/native_query/type_solver/constraint_to_type.rs index a6676384..b38370e9 100644 --- a/crates/cli/src/native_query/type_solver/constraint_to_type.rs +++ b/crates/cli/src/native_query/type_solver/constraint_to_type.rs @@ -4,6 +4,7 @@ use configuration::{ schema::{ObjectField, ObjectType, Type}, Configuration, }; +use itertools::Itertools as _; use ndc_models::{FieldName, ObjectTypeName}; use crate::native_query::{ @@ -51,14 +52,6 @@ pub fn constraint_to_type( .map(|_| Type::Predicate { object_type_name: object_type_name.clone(), }), - C::Nullable(c) => constraint_to_type( - configuration, - solutions, - added_object_types, - object_type_constraints, - c, - )? 
- .map(|t| Type::Nullable(Box::new(t))), C::Variable(variable) => solutions.get(variable).cloned(), C::ElementOf(c) => constraint_to_type( configuration, @@ -88,6 +81,37 @@ pub fn constraint_to_type( .transpose() }) .transpose()?, + + t @ C::Union(constraints) if t.is_nullable() => { + let non_null_constraints = constraints + .iter() + .filter(|t| *t != &C::null()) + .collect_vec(); + let underlying_constraint = if non_null_constraints.len() == 1 { + non_null_constraints.into_iter().next().unwrap() + } else { + &C::Union(non_null_constraints.into_iter().cloned().collect()) + }; + constraint_to_type( + configuration, + solutions, + added_object_types, + object_type_constraints, + underlying_constraint, + )? + .map(|t| Type::Nullable(Box::new(t))) + } + + C::Union(_) => Some(Type::ExtendedJSON), + + t @ C::OneOf(_) if t.is_numeric() => { + // We know it's a number, but we don't know exactly which numeric type. Double should + // be good enough for anybody, right? + Some(Type::Scalar(mongodb_support::BsonScalarType::Double)) + } + + C::OneOf(_) => Some(Type::ExtendedJSON), + C::WithFieldOverrides { augmented_object_type_name, target_type, diff --git a/crates/cli/src/native_query/type_solver/mod.rs b/crates/cli/src/native_query/type_solver/mod.rs index c4d149af..74897ff0 100644 --- a/crates/cli/src/native_query/type_solver/mod.rs +++ b/crates/cli/src/native_query/type_solver/mod.rs @@ -1,8 +1,7 @@ mod constraint_to_type; mod simplify; -mod substitute; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use configuration::{ schema::{ObjectType, Type}, @@ -11,7 +10,6 @@ use configuration::{ use itertools::Itertools; use ndc_models::ObjectTypeName; use simplify::simplify_constraints; -use substitute::substitute; use super::{ error::{Error, Result}, @@ -24,13 +22,14 @@ pub fn unify( configuration: &Configuration, required_type_variables: &[TypeVariable], object_type_constraints: &mut BTreeMap, - mut type_variables: 
HashMap>, + type_variables: HashMap>, ) -> Result<( HashMap, BTreeMap, )> { let mut added_object_types = BTreeMap::new(); let mut solutions = HashMap::new(); + let mut substitutions = HashMap::new(); fn is_solved(solutions: &HashMap, variable: TypeVariable) -> bool { solutions.contains_key(&variable) } @@ -38,33 +37,32 @@ pub fn unify( #[cfg(test)] println!("begin unify:\n type_variables: {type_variables:?}\n object_type_constraints: {object_type_constraints:?}\n"); - // TODO: This could be simplified. Instead of mutating constraints using `simplify_constraints` - // we might be able to roll all constraints into one and pass that to `constraint_to_type` in - // one step, but leave the original constraints unchanged if any part of that fails. That could - // make it simpler to keep track of source locations for when we want to report type mismatch - // errors between constraints. loop { let prev_type_variables = type_variables.clone(); let prev_solutions = solutions.clone(); + let prev_substitutions = substitutions.clone(); // TODO: check for mismatches, e.g. 
constraint list contains scalar & array ENG-1252 - for (variable, constraints) in type_variables.iter_mut() { + for (variable, constraints) in type_variables.iter() { + if is_solved(&solutions, *variable) { + continue; + } + let simplified = simplify_constraints( configuration, + &substitutions, object_type_constraints, - variable.variance, + Some(*variable), constraints.iter().cloned(), - ); - *constraints = simplified; - } - - #[cfg(test)] - println!("simplify:\n type_variables: {type_variables:?}\n object_type_constraints: {object_type_constraints:?}\n"); - - for (variable, constraints) in &type_variables { - if !is_solved(&solutions, *variable) && constraints.len() == 1 { - let constraint = constraints.iter().next().unwrap(); + ) + .map_err(Error::Multiple)?; + #[cfg(test)] + if simplified != *constraints { + println!("simplified {variable}: {constraints:?} -> {simplified:?}"); + } + if simplified.len() == 1 { + let constraint = simplified.iter().next().unwrap(); if let Some(solved_type) = constraint_to_type( configuration, &solutions, @@ -72,25 +70,24 @@ pub fn unify( object_type_constraints, constraint, )? 
{ - solutions.insert(*variable, solved_type); + #[cfg(test)] + println!("solved {variable}: {solved_type:?}"); + solutions.insert(*variable, solved_type.clone()); + substitutions.insert(*variable, [solved_type.into()].into()); } } } #[cfg(test)] - println!("check solutions:\n solutions: {solutions:?}\n added_object_types: {added_object_types:?}\n"); + println!("added_object_types: {added_object_types:?}\n"); let variables = type_variables_by_complexity(&type_variables); - - for variable in &variables { - if let Some(variable_constraints) = type_variables.get(variable).cloned() { - substitute(&mut type_variables, *variable, &variable_constraints); - } + if let Some(v) = variables.iter().find(|v| !substitutions.contains_key(*v)) { + // TODO: We should do some recursion to substitute variable references within + // substituted constraints to existing substitutions. + substitutions.insert(*v, type_variables[v].clone()); } - #[cfg(test)] - println!("substitute: {type_variables:?}\n"); - if required_type_variables .iter() .copied() @@ -99,7 +96,10 @@ pub fn unify( return Ok((solutions, added_object_types)); } - if type_variables == prev_type_variables && solutions == prev_solutions { + if type_variables == prev_type_variables + && solutions == prev_solutions + && substitutions == prev_substitutions + { return Err(Error::FailedToUnify { unsolved_variables: variables .into_iter() @@ -112,7 +112,7 @@ pub fn unify( /// List type variables ordered according to increasing complexity of their constraints. 
fn type_variables_by_complexity( - type_variables: &HashMap>, + type_variables: &HashMap>, ) -> Vec { type_variables .iter() diff --git a/crates/cli/src/native_query/type_solver/simplify.rs b/crates/cli/src/native_query/type_solver/simplify.rs index a040b6ed..d41d8e0d 100644 --- a/crates/cli/src/native_query/type_solver/simplify.rs +++ b/crates/cli/src/native_query/type_solver/simplify.rs @@ -1,20 +1,18 @@ -#![allow(warnings)] +use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::collections::{BTreeMap, HashSet}; - -use configuration::schema::{ObjectType, Type}; use configuration::Configuration; -use itertools::Itertools; +use itertools::Itertools as _; use mongodb_support::align::try_align; use mongodb_support::BsonScalarType; use ndc_models::{FieldName, ObjectTypeName}; +use nonempty::NonEmpty; use crate::introspection::type_unification::is_supertype; +use crate::native_query::helpers::get_object_field_type; use crate::native_query::type_constraint::Variance; use crate::native_query::{ error::Error, - pipeline_type_context::PipelineTypeContext, type_constraint::{ObjectTypeConstraint, TypeConstraint, TypeVariable}, }; @@ -22,105 +20,178 @@ use TypeConstraint as C; type Simplified = std::result::Result; +struct SimplifyContext<'a> { + configuration: &'a Configuration, + substitutions: &'a HashMap>, + object_type_constraints: &'a mut BTreeMap, + errors: &'a mut Vec, +} + // Attempts to reduce the number of type constraints from the input by combining redundant -// constraints, and by merging constraints into more specific ones where possible. This is -// guaranteed to produce a list that is equal or smaller in length compared to the input. +// constraints, merging constraints into more specific ones where possible, and applying +// accumulated variable substitutions. 
pub fn simplify_constraints( configuration: &Configuration, + substitutions: &HashMap>, object_type_constraints: &mut BTreeMap, - variance: Variance, + variable: Option, constraints: impl IntoIterator, -) -> HashSet { +) -> Result, Vec> { + let mut errors = vec![]; + let mut context = SimplifyContext { + configuration, + substitutions, + object_type_constraints, + errors: &mut errors, + }; + let constraints = simplify_constraints_internal(&mut context, variable, constraints); + if errors.is_empty() { + Ok(constraints) + } else { + Err(errors) + } +} + +fn simplify_constraints_internal( + context: &mut SimplifyContext, + variable: Option, + constraints: impl IntoIterator, +) -> BTreeSet { + let constraints: BTreeSet<_> = constraints + .into_iter() + .flat_map(|constraint| simplify_single_constraint(context, variable, constraint)) + .collect(); + constraints .into_iter() .coalesce(|constraint_a, constraint_b| { - simplify_constraint_pair( - configuration, - object_type_constraints, - variance, - constraint_a, - constraint_b, - ) + simplify_constraint_pair(context, variable, constraint_a, constraint_b) }) .collect() } +fn simplify_single_constraint( + context: &mut SimplifyContext, + variable: Option, + constraint: TypeConstraint, +) -> Vec { + match constraint { + C::Variable(v) if Some(v) == variable => vec![], + + C::Variable(v) => match context.substitutions.get(&v) { + Some(constraints) => constraints.iter().cloned().collect(), + None => vec![C::Variable(v)], + }, + + C::FieldOf { target_type, path } => { + let object_type = simplify_single_constraint(context, variable, *target_type.clone()); + if object_type.len() == 1 { + let object_type = object_type.into_iter().next().unwrap(); + match expand_field_of(context, object_type, path.clone()) { + Ok(Some(t)) => return t, + Ok(None) => (), + Err(e) => context.errors.push(e), + } + } + vec![C::FieldOf { target_type, path }] + } + + C::Union(constraints) => { + let simplified_constraints = + 
simplify_constraints_internal(context, variable, constraints); + vec![C::Union(simplified_constraints)] + } + + C::OneOf(constraints) => { + let simplified_constraints = + simplify_constraints_internal(context, variable, constraints); + vec![C::OneOf(simplified_constraints)] + } + + _ => vec![constraint], + } +} + fn simplify_constraint_pair( - configuration: &Configuration, - object_type_constraints: &mut BTreeMap, - variance: Variance, + context: &mut SimplifyContext, + variable: Option, a: TypeConstraint, b: TypeConstraint, ) -> Simplified { + let variance = variable.map(|v| v.variance).unwrap_or(Variance::Invariant); match (a, b) { - (C::ExtendedJSON, _) | (_, C::ExtendedJSON) => Ok(C::ExtendedJSON), // TODO: Do we want this in contravariant case? + (a, b) if a == b => Ok(a), + + (C::Variable(a), C::Variable(b)) if a == b => Ok(C::Variable(a)), + + (C::ExtendedJSON, _) | (_, C::ExtendedJSON) if variance == Variance::Covariant => { + Ok(C::ExtendedJSON) + } + (C::ExtendedJSON, b) if variance == Variance::Contravariant => Ok(b), + (a, C::ExtendedJSON) if variance == Variance::Contravariant => Ok(a), + (C::Scalar(a), C::Scalar(b)) => solve_scalar(variance, a, b), - // TODO: We need to make sure we aren't putting multiple layers of Nullable on constraints - // - if a and b have mismatched levels of Nullable they won't unify - (C::Nullable(a), C::Nullable(b)) => { - simplify_constraint_pair(configuration, object_type_constraints, variance, *a, *b) - .map(|constraint| C::Nullable(Box::new(constraint))) + (C::Union(mut a), C::Union(mut b)) if variance == Variance::Covariant => { + a.append(&mut b); + let union = simplify_constraints_internal(context, variable, a); + Ok(C::Union(union)) } - (C::Nullable(a), b) if variance == Variance::Covariant => { - simplify_constraint_pair(configuration, object_type_constraints, variance, *a, b) - .map(|constraint| C::Nullable(Box::new(constraint))) + + (C::Union(a), C::Union(b)) if variance == Variance::Contravariant => { + let 
intersection: BTreeSet<_> = a.intersection(&b).cloned().collect(); + if intersection.is_empty() { + Err((C::Union(a), C::Union(b))) + } else if intersection.len() == 1 { + Ok(intersection.into_iter().next().unwrap()) + } else { + Ok(C::Union(intersection)) + } } - (a, b @ C::Nullable(_)) => { - simplify_constraint_pair(configuration, object_type_constraints, variance, b, a) + + (C::Union(mut a), b) if variance == Variance::Covariant => { + a.insert(b); + let union = simplify_constraints_internal(context, variable, a); + Ok(C::Union(union)) } + (b, a @ C::Union(_)) => simplify_constraint_pair(context, variable, b, a), - (C::Variable(a), C::Variable(b)) if a == b => Ok(C::Variable(a)), + (C::OneOf(mut a), C::OneOf(mut b)) => { + a.append(&mut b); + Ok(C::OneOf(a)) + } - // (C::Scalar(_), C::Variable(_)) => todo!(), - // (C::Scalar(_), C::ElementOf(_)) => todo!(), - (C::Scalar(_), C::FieldOf { target_type, path }) => todo!(), - ( - C::Scalar(_), - C::WithFieldOverrides { - target_type, - fields, - .. 
- }, - ) => todo!(), - // (C::Object(_), C::Scalar(_)) => todo!(), + (C::OneOf(constraints), b) => { + let matches: BTreeSet<_> = constraints + .clone() + .into_iter() + .filter_map( + |c| match simplify_constraint_pair(context, variable, c, b.clone()) { + Ok(c) => Some(c), + Err(_) => None, + }, + ) + .collect(); + + if matches.len() == 1 { + Ok(matches.into_iter().next().unwrap()) + } else if matches.is_empty() { + // TODO: record type mismatch + Err((C::OneOf(constraints), b)) + } else { + Ok(C::OneOf(matches)) + } + } + (a, b @ C::OneOf(_)) => simplify_constraint_pair(context, variable, b, a), + + (C::Object(a), C::Object(b)) if a == b => Ok(C::Object(a)), (C::Object(a), C::Object(b)) => { - merge_object_type_constraints(configuration, object_type_constraints, variance, a, b) + match merge_object_type_constraints(context, variable, &a, &b) { + Some(merged_name) => Ok(C::Object(merged_name)), + None => Err((C::Object(a), C::Object(b))), + } } - // (C::Object(_), C::ArrayOf(_)) => todo!(), - // (C::Object(_), C::Nullable(_)) => todo!(), - // (C::Object(_), C::Predicate { object_type_name }) => todo!(), - // (C::Object(_), C::Variable(_)) => todo!(), - (C::Object(_), C::ElementOf(_)) => todo!(), - (C::Object(_), C::FieldOf { target_type, path }) => todo!(), - ( - C::Object(_), - C::WithFieldOverrides { - target_type, - fields, - .. - }, - ) => todo!(), - // (C::ArrayOf(_), C::Scalar(_)) => todo!(), - // (C::ArrayOf(_), C::Object(_)) => todo!(), - // (C::ArrayOf(_), C::ArrayOf(_)) => todo!(), - // (C::ArrayOf(_), C::Nullable(_)) => todo!(), - // (C::ArrayOf(_), C::Predicate { object_type_name }) => todo!(), - // (C::ArrayOf(_), C::Variable(_)) => todo!(), - // (C::ArrayOf(_), C::ElementOf(_)) => todo!(), - (C::ArrayOf(_), C::FieldOf { target_type, path }) => todo!(), - ( - C::ArrayOf(_), - C::WithFieldOverrides { - target_type, - fields, - .. 
- }, - ) => todo!(), - (C::Predicate { object_type_name }, C::Scalar(_)) => todo!(), - (C::Predicate { object_type_name }, C::Object(_)) => todo!(), - (C::Predicate { object_type_name }, C::ArrayOf(_)) => todo!(), - (C::Predicate { object_type_name }, C::Nullable(_)) => todo!(), + ( C::Predicate { object_type_name: a, @@ -128,237 +199,159 @@ fn simplify_constraint_pair( C::Predicate { object_type_name: b, }, - ) => todo!(), - (C::Predicate { object_type_name }, C::Variable(_)) => todo!(), - (C::Predicate { object_type_name }, C::ElementOf(_)) => todo!(), - (C::Predicate { object_type_name }, C::FieldOf { target_type, path }) => todo!(), - ( - C::Predicate { object_type_name }, - C::WithFieldOverrides { - target_type, - fields, - .. - }, - ) => todo!(), - (C::Variable(_), C::Scalar(_)) => todo!(), - (C::Variable(_), C::Object(_)) => todo!(), - (C::Variable(_), C::ArrayOf(_)) => todo!(), - (C::Variable(_), C::Nullable(_)) => todo!(), - (C::Variable(_), C::Predicate { object_type_name }) => todo!(), - (C::Variable(_), C::Variable(_)) => todo!(), - (C::Variable(_), C::ElementOf(_)) => todo!(), - (C::Variable(_), C::FieldOf { target_type, path }) => todo!(), - ( - C::Variable(_), - C::WithFieldOverrides { - target_type, - fields, - .. - }, - ) => todo!(), - (C::ElementOf(_), C::Scalar(_)) => todo!(), - (C::ElementOf(_), C::Object(_)) => todo!(), - (C::ElementOf(_), C::ArrayOf(_)) => todo!(), - (C::ElementOf(_), C::Nullable(_)) => todo!(), - (C::ElementOf(_), C::Predicate { object_type_name }) => todo!(), - (C::ElementOf(_), C::Variable(_)) => todo!(), - (C::ElementOf(_), C::ElementOf(_)) => todo!(), - (C::ElementOf(_), C::FieldOf { target_type, path }) => todo!(), - ( - C::ElementOf(_), - C::WithFieldOverrides { - target_type, - fields, - .. 
- }, - ) => todo!(), - (C::FieldOf { target_type, path }, C::Scalar(_)) => todo!(), - (C::FieldOf { target_type, path }, C::Object(_)) => todo!(), - (C::FieldOf { target_type, path }, C::ArrayOf(_)) => todo!(), - (C::FieldOf { target_type, path }, C::Nullable(_)) => todo!(), - (C::FieldOf { target_type, path }, C::Predicate { object_type_name }) => todo!(), - (C::FieldOf { target_type, path }, C::Variable(_)) => todo!(), - (C::FieldOf { target_type, path }, C::ElementOf(_)) => todo!(), + ) if a == b => Ok(C::Predicate { + object_type_name: a, + }), ( - C::FieldOf { - target_type: target_type_a, - path: path_a, + C::Predicate { + object_type_name: a, }, - C::FieldOf { - target_type: target_type_b, - path: path_b, + C::Predicate { + object_type_name: b, }, - ) => todo!(), + ) if a == b => match merge_object_type_constraints(context, variable, &a, &b) { + Some(merged_name) => Ok(C::Predicate { + object_type_name: merged_name, + }), + None => Err(( + C::Predicate { + object_type_name: a, + }, + C::Predicate { + object_type_name: b, + }, + )), + }, + + // TODO: We probably want a separate step that swaps ElementOf and FieldOf constraints with + // constraint of the targeted structure. We might do a similar thing with + // WithFieldOverrides. + + // (C::ElementOf(a), b) => { + // if let TypeConstraint::ArrayOf(elem_type) = *a { + // simplify_constraint_pair( + // configuration, + // object_type_constraints, + // variance, + // *elem_type, + // b, + // ) + // } else { + // Err((C::ElementOf(a), b)) + // } + // } + // + // (C::FieldOf { target_type, path }, b) => { + // if let TypeConstraint::Object(type_name) = *target_type { + // let object_type = object_type_constraints + // } else { + // Err((C::FieldOf { target_type, path }, b)) + // } + // } + // ( - // C::FieldOf { target_type, path }, + // C::Object(_), // C::WithFieldOverrides { // target_type, // fields, // .. // }, // ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. 
- }, - C::Scalar(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::Object(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::ArrayOf(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::Nullable(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::Predicate { object_type_name }, - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::Variable(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type, - fields, - .. - }, - C::ElementOf(_), - ) => todo!(), - ( - C::WithFieldOverrides { - target_type: target_type_a, - fields, - .. - }, - C::FieldOf { - target_type: target_type_b, - path, - }, - ) => todo!(), - ( - C::WithFieldOverrides { - target_type: target_type_a, - fields: fields_a, - .. - }, - C::WithFieldOverrides { - target_type: target_type_b, - fields: fields_b, - .. - }, - ) => todo!(), - _ => todo!("other simplify branch"), + (C::ArrayOf(a), C::ArrayOf(b)) => { + match simplify_constraint_pair(context, variable, *a, *b) { + Ok(ab) => Ok(C::ArrayOf(Box::new(ab))), + Err((a, b)) => Err((C::ArrayOf(Box::new(a)), C::ArrayOf(Box::new(b)))), + } + } + + (a, b) => Err((a, b)), } } +/// Reconciles two scalar type constraints depending on variance of the context. In a covariant +/// context the type of a type variable is determined to be the supertype of the two (if the types +/// overlap). In a covariant context the variable type is the subtype of the two instead. 
fn solve_scalar( variance: Variance, a: BsonScalarType, b: BsonScalarType, ) -> Simplified { - if variance == Variance::Contravariant { - return solve_scalar(Variance::Covariant, b, a); - } - - if a == b || is_supertype(&a, &b) { - Ok(C::Scalar(a)) - } else if is_supertype(&b, &a) { - Ok(C::Scalar(b)) - } else { - Err((C::Scalar(a), C::Scalar(b))) + match variance { + Variance::Covariant => { + if a == b || is_supertype(&a, &b) { + Ok(C::Scalar(a)) + } else if is_supertype(&b, &a) { + Ok(C::Scalar(b)) + } else { + Err((C::Scalar(a), C::Scalar(b))) + } + } + Variance::Contravariant => { + if a == b || is_supertype(&a, &b) { + Ok(C::Scalar(b)) + } else if is_supertype(&b, &a) { + Ok(C::Scalar(a)) + } else { + Err((C::Scalar(a), C::Scalar(b))) + } + } + Variance::Invariant => { + if a == b { + Ok(C::Scalar(a)) + } else { + Err((C::Scalar(a), C::Scalar(b))) + } + } } } fn merge_object_type_constraints( - configuration: &Configuration, - object_type_constraints: &mut BTreeMap, - variance: Variance, - name_a: ObjectTypeName, - name_b: ObjectTypeName, -) -> Simplified { + context: &mut SimplifyContext, + variable: Option, + name_a: &ObjectTypeName, + name_b: &ObjectTypeName, +) -> Option { // Pick from the two input names according to sort order to get a deterministic outcome. 
- let preferred_name = if name_a <= name_b { &name_a } else { &name_b }; - let merged_name = unique_type_name(configuration, object_type_constraints, preferred_name); + let preferred_name = if name_a <= name_b { name_a } else { name_b }; + let merged_name = unique_type_name( + context.configuration, + context.object_type_constraints, + preferred_name, + ); - let a = look_up_object_type_constraint(configuration, object_type_constraints, &name_a); - let b = look_up_object_type_constraint(configuration, object_type_constraints, &name_b); + let a = look_up_object_type_constraint(context, name_a); + let b = look_up_object_type_constraint(context, name_b); let merged_fields_result = try_align( a.fields.clone().into_iter().collect(), b.fields.clone().into_iter().collect(), always_ok(TypeConstraint::make_nullable), always_ok(TypeConstraint::make_nullable), - |field_a, field_b| { - unify_object_field( - configuration, - object_type_constraints, - variance, - field_a, - field_b, - ) - }, + |field_a, field_b| unify_object_field(context, variable, field_a, field_b), ); let fields = match merged_fields_result { Ok(merged_fields) => merged_fields.into_iter().collect(), Err(_) => { - return Err(( - TypeConstraint::Object(name_a), - TypeConstraint::Object(name_b), - )) + return None; } }; let merged_object_type = ObjectTypeConstraint { fields }; - object_type_constraints.insert(merged_name.clone(), merged_object_type); + context + .object_type_constraints + .insert(merged_name.clone(), merged_object_type); - Ok(TypeConstraint::Object(merged_name)) + Some(merged_name) } fn unify_object_field( - configuration: &Configuration, - object_type_constraints: &mut BTreeMap, - variance: Variance, + context: &mut SimplifyContext, + variable: Option, field_type_a: TypeConstraint, field_type_b: TypeConstraint, ) -> Result { - simplify_constraint_pair( - configuration, - object_type_constraints, - variance, - field_type_a, - field_type_b, - ) - .map_err(|_| ()) + 
simplify_constraint_pair(context, variable, field_type_a, field_type_b).map_err(|_| ()) } fn always_ok(mut f: F) -> impl FnMut(A) -> Result @@ -369,13 +362,12 @@ where } fn look_up_object_type_constraint( - configuration: &Configuration, - object_type_constraints: &BTreeMap, + context: &SimplifyContext, name: &ObjectTypeName, ) -> ObjectTypeConstraint { - if let Some(object_type) = configuration.object_types.get(name) { + if let Some(object_type) = context.configuration.object_types.get(name) { object_type.clone().into() - } else if let Some(object_type) = object_type_constraints.get(name) { + } else if let Some(object_type) = context.object_type_constraints.get(name) { object_type.clone() } else { unreachable!("look_up_object_type_constraint") @@ -397,3 +389,161 @@ fn unique_type_name( } type_name } + +fn expand_field_of( + context: &mut SimplifyContext, + object_type: TypeConstraint, + path: NonEmpty, +) -> Result>, Error> { + let field_type = match object_type { + C::ExtendedJSON => Some(vec![C::ExtendedJSON]), + C::Object(type_name) => get_object_constraint_field_type(context, &type_name, path)?, + C::Union(constraints) => { + let variants: BTreeSet = constraints + .into_iter() + .map(|t| { + let maybe_expanded = expand_field_of(context, t.clone(), path.clone())?; + + // TODO: if variant has more than one element that should be interpreted as an + // intersection, which we haven't implemented yet + Ok(match maybe_expanded { + Some(variant) if variant.len() <= 1 => variant, + _ => vec![t], + }) + }) + .flatten_ok() + .collect::>()?; + Some(vec![(C::Union(variants))]) + } + C::OneOf(constraints) => { + // The difference between the Union and OneOf cases is that in OneOf we want to prune + // variants that don't expand, while in Union we want to preserve unexpanded variants. 
+ let expanded_variants: BTreeSet = constraints + .into_iter() + .map(|t| { + let maybe_expanded = expand_field_of(context, t, path.clone())?; + + // TODO: if variant has more than one element that should be interpreted as an + // intersection, which we haven't implemented yet + Ok(match maybe_expanded { + Some(variant) if variant.len() <= 1 => variant, + _ => vec![], + }) + }) + .flatten_ok() + .collect::>()?; + if expanded_variants.len() == 1 { + Some(vec![expanded_variants.into_iter().next().unwrap()]) + } else if !expanded_variants.is_empty() { + Some(vec![C::Union(expanded_variants)]) + } else { + Err(Error::Other(format!( + "no variant matched object field path {path:?}" + )))? + } + } + _ => None, + }; + Ok(field_type) +} + +fn get_object_constraint_field_type( + context: &mut SimplifyContext, + object_type_name: &ObjectTypeName, + path: NonEmpty, +) -> Result>, Error> { + if let Some(object_type) = context.configuration.object_types.get(object_type_name) { + let t = get_object_field_type( + &context.configuration.object_types, + object_type_name, + object_type, + path, + )?; + return Ok(Some(vec![t.clone().into()])); + } + + let Some(object_type_constraint) = context.object_type_constraints.get(object_type_name) else { + return Err(Error::UnknownObjectType(object_type_name.to_string())); + }; + + let field_name = path.head; + let rest = NonEmpty::from_vec(path.tail); + + let field_type = object_type_constraint + .fields + .get(&field_name) + .ok_or_else(|| Error::ObjectMissingField { + object_type: object_type_name.clone(), + field_name: field_name.clone(), + })? 
+ .clone(); + + let field_type = simplify_single_constraint(context, None, field_type); + + match rest { + None => Ok(Some(field_type)), + Some(rest) if field_type.len() == 1 => match field_type.into_iter().next().unwrap() { + C::Object(type_name) => get_object_constraint_field_type(context, &type_name, rest), + _ => Err(Error::ObjectMissingField { + object_type: object_type_name.clone(), + field_name: field_name.clone(), + }), + }, + _ if field_type.is_empty() => Err(Error::Other( + "could not resolve object field to a type".to_string(), + )), + _ => Ok(None), // field_type len > 1 + } +} + +#[cfg(test)] +mod tests { + use googletest::prelude::*; + use mongodb_support::BsonScalarType; + + use crate::native_query::type_constraint::{TypeConstraint, Variance}; + + #[googletest::test] + fn multiple_identical_scalar_constraints_resolve_one_constraint() { + expect_eq!( + super::solve_scalar( + Variance::Covariant, + BsonScalarType::String, + BsonScalarType::String, + ), + Ok(TypeConstraint::Scalar(BsonScalarType::String)) + ); + expect_eq!( + super::solve_scalar( + Variance::Contravariant, + BsonScalarType::String, + BsonScalarType::String, + ), + Ok(TypeConstraint::Scalar(BsonScalarType::String)) + ); + } + + #[googletest::test] + fn multiple_scalar_constraints_resolve_to_supertype_in_covariant_context() { + expect_eq!( + super::solve_scalar( + Variance::Covariant, + BsonScalarType::Int, + BsonScalarType::Double, + ), + Ok(TypeConstraint::Scalar(BsonScalarType::Double)) + ); + } + + #[googletest::test] + fn multiple_scalar_constraints_resolve_to_subtype_in_contravariant_context() { + expect_eq!( + super::solve_scalar( + Variance::Contravariant, + BsonScalarType::Int, + BsonScalarType::Double, + ), + Ok(TypeConstraint::Scalar(BsonScalarType::Int)) + ); + } +} diff --git a/crates/cli/src/native_query/type_solver/substitute.rs b/crates/cli/src/native_query/type_solver/substitute.rs deleted file mode 100644 index e87e9ecb..00000000 --- 
a/crates/cli/src/native_query/type_solver/substitute.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use itertools::Either; - -use crate::native_query::type_constraint::{TypeConstraint, TypeVariable}; - -/// Given a type variable that has been reduced to a single type constraint, replace occurrences if -/// the variable in -pub fn substitute( - type_variables: &mut HashMap>, - variable: TypeVariable, - variable_constraints: &HashSet, -) { - for (v, target_constraints) in type_variables.iter_mut() { - if *v == variable { - continue; - } - - // Replace top-level variable references with the list of constraints assigned to the - // variable being substituted. - let mut substituted_constraints: HashSet = target_constraints - .iter() - .cloned() - .flat_map(|target_constraint| match target_constraint { - TypeConstraint::Variable(v) if v == variable => { - Either::Left(variable_constraints.iter().cloned()) - } - t => Either::Right(std::iter::once(t)), - }) - .collect(); - - // Recursively replace variable references inside each constraint. A [TypeConstraint] can - // reference at most one other constraint, so we can only do this if the variable being - // substituted has been reduced to a single constraint. 
- if variable_constraints.len() == 1 { - let variable_constraint = variable_constraints.iter().next().unwrap(); - substituted_constraints = substituted_constraints - .into_iter() - .map(|target_constraint| { - substitute_in_constraint(variable, variable_constraint, target_constraint) - }) - .collect(); - } - - *target_constraints = substituted_constraints; - } - // substitution_made -} - -fn substitute_in_constraint( - variable: TypeVariable, - variable_constraint: &TypeConstraint, - target_constraint: TypeConstraint, -) -> TypeConstraint { - match target_constraint { - t @ TypeConstraint::Variable(v) => { - if v == variable { - variable_constraint.clone() - } else { - t - } - } - t @ TypeConstraint::ExtendedJSON => t, - t @ TypeConstraint::Scalar(_) => t, - t @ TypeConstraint::Object(_) => t, - TypeConstraint::ArrayOf(t) => TypeConstraint::ArrayOf(Box::new(substitute_in_constraint( - variable, - variable_constraint, - *t, - ))), - TypeConstraint::Nullable(t) => TypeConstraint::Nullable(Box::new( - substitute_in_constraint(variable, variable_constraint, *t), - )), - t @ TypeConstraint::Predicate { .. 
} => t, - TypeConstraint::ElementOf(t) => TypeConstraint::ElementOf(Box::new( - substitute_in_constraint(variable, variable_constraint, *t), - )), - TypeConstraint::FieldOf { target_type, path } => TypeConstraint::FieldOf { - target_type: Box::new(substitute_in_constraint( - variable, - variable_constraint, - *target_type, - )), - path, - }, - TypeConstraint::WithFieldOverrides { - augmented_object_type_name, - target_type, - fields, - } => TypeConstraint::WithFieldOverrides { - augmented_object_type_name, - target_type: Box::new(substitute_in_constraint( - variable, - variable_constraint, - *target_type, - )), - fields, - }, - } -} diff --git a/crates/mongodb-agent-common/proptest-regressions/query/serialization/tests.txt b/crates/mongodb-agent-common/proptest-regressions/query/serialization/tests.txt index db207898..e85c3bad 100644 --- a/crates/mongodb-agent-common/proptest-regressions/query/serialization/tests.txt +++ b/crates/mongodb-agent-common/proptest-regressions/query/serialization/tests.txt @@ -10,3 +10,4 @@ cc 7d760e540b56fedac7dd58e5bdb5bb9613b9b0bc6a88acfab3fc9c2de8bf026d # shrinks to cc 21360610045c5a616b371fb8d5492eb0c22065d62e54d9c8a8761872e2e192f3 # shrinks to bson = Array([Document({}), Document({" ": Null})]) cc 8842e7f78af24e19847be5d8ee3d47c547ef6c1bb54801d360a131f41a87f4fa cc 2a192b415e5669716701331fe4141383a12ceda9acc9f32e4284cbc2ed6f2d8a # shrinks to bson = Document({"A": Document({"ยก": JavaScriptCodeWithScope { code: "", scope: Document({"\0": Int32(-1)}) }})}), mode = Relaxed +cc 4c37daee6ab1e1bcc75b4089786253f29271d116a1785180560ca431d2b4a651 # shrinks to bson = Document({"0": Document({"A": Array([Int32(0), Decimal128(...)])})}) diff --git a/crates/mongodb-support/src/bson_type.rs b/crates/mongodb-support/src/bson_type.rs index dd1e63ef..2289e534 100644 --- a/crates/mongodb-support/src/bson_type.rs +++ b/crates/mongodb-support/src/bson_type.rs @@ -80,7 +80,20 @@ impl<'de> Deserialize<'de> for BsonType { } } -#[derive(Copy, Clone, Debug, 
PartialEq, Eq, Hash, Sequence, Serialize, Deserialize, JsonSchema)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + Sequence, + Serialize, + Deserialize, + JsonSchema, +)] #[serde(try_from = "BsonType", rename_all = "camelCase")] pub enum BsonScalarType { // numeric diff --git a/crates/test-helpers/src/configuration.rs b/crates/test-helpers/src/configuration.rs index d125fc6a..fb15fe9b 100644 --- a/crates/test-helpers/src/configuration.rs +++ b/crates/test-helpers/src/configuration.rs @@ -1,5 +1,5 @@ use configuration::Configuration; -use ndc_test_helpers::{collection, named_type, object_type}; +use ndc_test_helpers::{array_of, collection, named_type, object_type}; /// Configuration for a MongoDB database that resembles MongoDB's sample_mflix test data set. pub fn mflix_config() -> Configuration { @@ -23,8 +23,35 @@ pub fn mflix_config() -> Configuration { object_type([ ("_id", named_type("ObjectId")), ("credits", named_type("credits")), + ("genres", array_of(named_type("String"))), + ("imdb", named_type("Imdb")), ("title", named_type("String")), ("year", named_type("Int")), + ("tomatoes", named_type("Tomatoes")), + ]), + ), + ( + "Imdb".into(), + object_type([ + ("rating", named_type("Double")), + ("votes", named_type("Int")), + ("id", named_type("Int")), + ]), + ), + ( + "Tomatoes".into(), + object_type([ + ("critic", named_type("TomatoesCriticViewer")), + ("viewer", named_type("TomatoesCriticViewer")), + ("lastUpdated", named_type("Date")), + ]), + ), + ( + "TomatoesCriticViewer".into(), + object_type([ + ("rating", named_type("Double")), + ("numReviews", named_type("Int")), + ("meter", named_type("Int")), ]), ), ]