diff --git a/activerecord/lib/active_record/querying.rb b/activerecord/lib/active_record/querying.rb index 27d0603f3a9d9..67517c3e045c5 100644 --- a/activerecord/lib/active_record/querying.rb +++ b/activerecord/lib/active_record/querying.rb @@ -17,7 +17,7 @@ module Querying :and, :or, :annotate, :optimizer_hints, :extending, :having, :create_with, :distinct, :references, :none, :unscope, :merge, :except, :only, :count, :average, :minimum, :maximum, :sum, :calculate, - :pluck, :pick, :ids, :async_ids, :strict_loading, :excluding, :without, :with, + :pluck, :pick, :ids, :async_ids, :strict_loading, :excluding, :without, :with, :with_recursive, :async_count, :async_average, :async_minimum, :async_maximum, :async_sum, :async_pluck, :async_pick, ].freeze # :nodoc: delegate(*QUERYING_METHODS, to: :all) diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index 229be6c173492..8e52f673de8a8 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -60,7 +60,7 @@ def exec_explain(&block) :reverse_order, :distinct, :create_with, :skip_query_cache] CLAUSE_METHODS = [:where, :having, :from] - INVALID_METHODS_FOR_DELETE_ALL = [:distinct, :with] + INVALID_METHODS_FOR_DELETE_ALL = [:distinct, :with, :with_recursive] VALUE_METHODS = MULTI_VALUE_METHODS + SINGLE_VALUE_METHODS + CLAUSE_METHODS diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 046dad0653ae2..d20011cefae7c 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -435,6 +435,17 @@ def _select!(*fields) # :nodoc: # # ) # # SELECT * FROM posts # + # You can also pass an array of sub-queries to be joined in a +UNION ALL+. + # + # Post.with(posts_with_tags_or_comments: [Post.where("tags_count > ?", 0), Post.where("comments_count > ?", 0)]) + # # => ActiveRecord::Relation + # # WITH posts_with_tags_or_comments AS ( + # # (SELECT * FROM posts WHERE (tags_count > 0)) + # # UNION ALL + # # (SELECT * FROM posts WHERE (comments_count > 0)) + # # ) + # # SELECT * FROM posts + # # Once you define Common Table Expression you can use custom +FROM+ value or +JOIN+ to reference it. # # Post.with(posts_with_tags: Post.where("tags_count > ?", 0)).from("posts_with_tags AS posts") @@ -475,7 +486,12 @@ def _select!(*fields) # :nodoc: def with(*args) raise ArgumentError, "ActiveRecord::Relation#with does not accept a block" if block_given? check_if_method_has_arguments!(__callee__, args) - spawn.with!(*args) + + if args.empty? + WithChain.new(spawn) + else + spawn.with!(*args) + end end # Like #with, but modifies relation in place. @@ -484,6 +500,30 @@ def with!(*args) # :nodoc: self end + # Add a recursive Common Table Expression (CTE) that you can then reference within another SELECT statement. + # + # Post.with_recursive(post_and_replies: [Post.where(id: 42), Post.joins('JOIN post_and_replies ON posts.in_reply_to_id = post_and_replies.id')]) + # # => ActiveRecord::Relation + # # WITH post_and_replies AS ( + # # (SELECT * FROM posts WHERE id = 42) + # # UNION ALL + # # (SELECT * FROM posts JOIN posts_and_replies ON posts.in_reply_to_id = posts_and_replies.id) + # # ) + # # SELECT * FROM posts + # + # See `#with` for more information. + def with_recursive(*args) + check_if_method_has_arguments!(__callee__, args) + spawn.with_recursive!(*args) + end + + # Like #with_recursive but modifies the relation in place. + def with_recursive!(*args) # :nodoc: + self.with_values += args + @with_is_recursive = true + self + end + # Allows you to change a previously set select statement. # # Post.select(:title, :body) @@ -1846,20 +1886,23 @@ def build_with(arel) build_with_value_from_hash(with_value) end - arel.with(with_statements) + @with_is_recursive ? arel.with(:recursive, with_statements) : arel.with(with_statements) end def build_with_value_from_hash(hash) hash.map do |name, value| - expression = - case value - when Arel::Nodes::SqlLiteral then Arel::Nodes::Grouping.new(value) - when ActiveRecord::Relation then value.arel - when Arel::SelectManager then value - else - raise ArgumentError, "Unsupported argument type: `#{value}` #{value.class}" - end - Arel::Nodes::TableAlias.new(expression, name) + Arel::Nodes::TableAlias.new(build_with_expression_from_value(value), name) + end + end + + def build_with_expression_from_value(value) + case value + when Arel::Nodes::SqlLiteral then Arel::Nodes::Grouping.new(value) + when ActiveRecord::Relation then value.arel + when Arel::SelectManager then value + when Array then value.map { |q| build_with_expression_from_value(q) }.reduce { |result, value| result.union(:all, value) } + else + raise ArgumentError, "Unsupported argument type: `#{value}` #{value.class}" end end diff --git a/activerecord/test/cases/relation/with_test.rb b/activerecord/test/cases/relation/with_test.rb index a34dd5dbc0f08..d67f76c8d2697 100644 --- a/activerecord/test/cases/relation/with_test.rb +++ b/activerecord/test/cases/relation/with_test.rb @@ -3,11 +3,11 @@ require "cases/helper" require "models/comment" require "models/post" +require "models/company" module ActiveRecord class WithTest < ActiveRecord::TestCase - fixtures :comments - fixtures :posts + fixtures :comments, :posts, :companies POSTS_WITH_TAGS = [1, 2, 7, 8, 9, 10, 11].freeze POSTS_WITH_COMMENTS = [1, 2, 4, 5, 7].freeze @@ -57,9 +57,35 @@ def test_with_when_called_from_active_record_scope end def test_with_when_invalid_params_are_passed - assert_raise(ArgumentError) { Post.with } assert_raise(ArgumentError) { Post.with(posts_with_tags: nil).load } - assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0")]).load } + assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0"), 5]).load } + end + + def test_with_when_passing_arrays + relation = Post + .with(posts_with_tags_or_comments: [ + Post.where("tags_count > 0"), + Post.where("legacy_comments_count > 0") + ]) + .from("posts_with_tags_or_comments AS posts") + + assert_equal (POSTS_WITH_TAGS + POSTS_WITH_COMMENTS).sort, relation.order(:id).pluck(:id) + end + + def test_with_recursive + top_companies = Company.where(firm_id: nil).to_a + child_companies = Company.where(firm_id: top_companies).to_a + top_companies_and_children = (top_companies.map(&:id) + child_companies.map(&:id)).sort + + relation = Company.with_recursive( + top_companies_and_children: [ + Company.where(firm_id: nil), + Company.joins('JOIN top_companies_and_children ON companies.firm_id = top_companies_and_children.id'), + ] + ).from("top_companies_and_children AS companies") + + assert_equal top_companies_and_children, relation.order(:id).pluck(:id) + assert_match "WITH RECURSIVE", relation.to_sql end def test_with_joins