p"] = { api.drag_pair_forwards, "Drag element pairs right" },
+ ["f"] = { paredit.api.drag_form_forwards, "Drag form right" },
["
+## Pairwise Dragging
+
+Nvim-paredit has support for dragging elements pairwise. If an element being dragged is within a form that contains
+pairs of elements (such as a clojure `map`) then the element will be dragged along with it's pair.
+
+For example:
+
+```clojure
+{:a 1
+ |:b 2}
+;; Drag backwards
+{|:b 2
+ :a 1}
+```
+
+This is enabled by default and can be disabled by setting `dragging.auto_drag_pairs = false`.
+
+Pairwise dragging works using treesitter queries to identify element pairs within some localized node. This means you
+can very easily extend the paredit pairwise implementation by simply adding new treesitter queries to your nvim
+configuration.
+
+You might want to extend if:
+
+1) You are a language extension author and want to add pairwise dragging support to your extension.
+2) You want to add support for some syntax not supported by nvim-paredit.
+
+This is especially useful if you have your own clojure macros that you want to enable pairwise dragging on.
+
+All you need to do to extend is to add a new file called `queries//paredit/pairwise.scm` in your nvim config
+directory. Make sure to include the `;; extends` directive to the file or you will overwrite any pre-existing queries
+defined by nvim-paredit or other language extensions.
+
+As an example if you want to add support for the following clojure macro:
+
+```clojure
+(defmacro my-custom-bindings [bindings & body]
+ ...)
+
+(my-custom-bindings [a 1
+ b 2]
+ (println a b))
+```
+
+You can add the following TS query
+
+```scm
+;; extends
+
+(list_lit
+ (sym_lit) @fn-name
+ (vec_lit
+ (_) @pair)
+ (#eq? @fn-name "my-custom-bindings"))
+```
+
## Language Support
As this is built using Treesitter it requires that you have the relevant Treesitter grammar installed for your language
@@ -332,6 +397,8 @@ paredit.api.slurp_forwards()
- **`barf_backwards`**
- **`drag_element_forwards`**
- **`drag_element_backwards`**
+- **`drag_pair_forwards`**
+- **`drag_pair_backwards`**
- **`drag_form_forwards`**
- **`drag_form_backwards`**
- **`raise_element`**
diff --git a/lua/nvim-paredit/api/dragging.lua b/lua/nvim-paredit/api/dragging.lua
index 39d1ea3..1685edb 100644
--- a/lua/nvim-paredit/api/dragging.lua
+++ b/lua/nvim-paredit/api/dragging.lua
@@ -1,5 +1,8 @@
local traversal = require("nvim-paredit.utils.traversal")
+local common = require("nvim-paredit.utils.common")
+local ts_utils = require("nvim-paredit.utils.ts")
local ts = require("nvim-treesitter.ts_utils")
+local config = require("nvim-paredit.config")
local langs = require("nvim-paredit.lang")
local M = {}
@@ -44,24 +47,84 @@ function M.drag_form_backwards()
ts.swap_nodes(root, sibling, buf, true)
end
-function M.drag_element_forwards()
- local lang = langs.get_language_api()
- local current_node = lang.get_node_root(ts.get_node_at_cursor())
+local function find_current_pair(pairs, current_node)
+ for i, pair in ipairs(pairs) do
+ for _, node in ipairs(pair) do
+ if node:equal(current_node) then
+ return i, pair
+ end
+ end
+ end
+end
- local sibling = current_node:next_named_sibling()
- if not sibling then
+local function drag_node_in_pair(current_node, nodes, opts)
+ local direction = 1
+ if opts.reversed then
+ direction = -1
+ end
+
+ local pairs = common.chunk_table(nodes, 2)
+ local chunk_index, pair = find_current_pair(pairs, current_node)
+
+ local corresponding_pair = pairs[chunk_index + direction]
+ if not corresponding_pair then
return
end
local buf = vim.api.nvim_get_current_buf()
- ts.swap_nodes(current_node, sibling, buf, true)
+ if pair[2] and corresponding_pair[2] then
+ ts.swap_nodes(pair[2], corresponding_pair[2], buf, true)
+ end
+ if pair[1] and corresponding_pair[1] then
+ ts.swap_nodes(pair[1], corresponding_pair[1], buf, true)
+ end
+end
+
+local function drag_pair(opts)
+ local lang = langs.get_language_api()
+ local current_node = lang.get_node_root(ts.get_node_at_cursor())
+ if not current_node then
+ return
+ end
+
+ local pairwise_nodes = ts_utils.find_pairwise_nodes(
+ current_node,
+ vim.tbl_deep_extend("force", opts, {
+ lang = lang,
+ })
+ )
+ if not pairwise_nodes then
+ local parent = current_node:parent()
+ if not parent then
+ return
+ end
+
+ pairwise_nodes = traversal.get_children_ignoring_comments(parent, {
+ lang = lang,
+ })
+ end
+
+ drag_node_in_pair(current_node, pairwise_nodes, opts)
end
-function M.drag_element_backwards()
+local function drag_element(opts)
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())
- local sibling = current_node:prev_named_sibling()
+ if opts.dragging.auto_drag_pairs then
+ local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, { lang = lang })
+ if pairwise_nodes then
+ return drag_node_in_pair(current_node, pairwise_nodes, opts)
+ end
+ end
+
+ local sibling
+ if opts.reversed then
+ sibling = current_node:prev_named_sibling()
+ else
+ sibling = current_node:next_named_sibling()
+ end
+
if not sibling then
return
end
@@ -70,4 +133,44 @@ function M.drag_element_backwards()
ts.swap_nodes(current_node, sibling, buf, true)
end
+function M.drag_element_forwards(opts)
+ local drag_opts = vim.tbl_deep_extend(
+ "force",
+ {
+ dragging = config.config.dragging or {},
+ },
+ opts or {},
+ {
+ reversed = false,
+ }
+ )
+ drag_element(drag_opts)
+end
+
+function M.drag_element_backwards(opts)
+ local drag_opts = vim.tbl_deep_extend(
+ "force",
+ {
+ dragging = config.config.dragging or {},
+ },
+ opts or {},
+ {
+ reversed = true,
+ }
+ )
+ drag_element(drag_opts)
+end
+
+function M.drag_pair_forwards()
+ drag_pair({
+ reversed = false,
+ })
+end
+
+function M.drag_pair_backwards()
+ drag_pair({
+ reversed = true,
+ })
+end
+
return M
diff --git a/lua/nvim-paredit/api/init.lua b/lua/nvim-paredit/api/init.lua
index df573fa..efac43e 100644
--- a/lua/nvim-paredit/api/init.lua
+++ b/lua/nvim-paredit/api/init.lua
@@ -15,6 +15,10 @@ local M = {
drag_element_forwards = dragging.drag_element_forwards,
drag_element_backwards = dragging.drag_element_backwards,
+
+ drag_pair_forwards = dragging.drag_pair_forwards,
+ drag_pair_backwards = dragging.drag_pair_backwards,
+
drag_form_forwards = dragging.drag_form_forwards,
drag_form_backwards = dragging.drag_form_backwards,
diff --git a/lua/nvim-paredit/api/selections.lua b/lua/nvim-paredit/api/selections.lua
index 1531801..3f9f058 100644
--- a/lua/nvim-paredit/api/selections.lua
+++ b/lua/nvim-paredit/api/selections.lua
@@ -41,7 +41,7 @@ function M.get_range_around_form()
end
function M.get_range_around_top_level_form()
- return get_range_around_form_impl(traversal.get_top_level_node_below_document)
+ return get_range_around_form_impl(traversal.find_local_root)
end
local function select_around_form_impl(range)
@@ -93,7 +93,7 @@ function M.get_range_in_form()
end
function M.get_range_in_top_level_form()
- return get_range_in_form_impl(traversal.get_top_level_node_below_document)
+ return get_range_in_form_impl(traversal.find_local_root)
end
local function select_in_form_impl(range)
diff --git a/lua/nvim-paredit/defaults.lua b/lua/nvim-paredit/defaults.lua
index 03eb727..cf8caa1 100644
--- a/lua/nvim-paredit/defaults.lua
+++ b/lua/nvim-paredit/defaults.lua
@@ -4,7 +4,7 @@ local unwrap = require("nvim-paredit.api.unwrap")
local M = {}
M.default_keys = {
- ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", },
+ ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" },
[">)"] = { api.slurp_forwards, "Slurp forwards" },
[">("] = { api.barf_backwards, "Barf backwards" },
@@ -15,6 +15,9 @@ M.default_keys = {
[">e"] = { api.drag_element_forwards, "Drag element right" },
["p"] = { api.drag_pair_forwards, "Drag element pairs right" },
+ ["f"] = { api.drag_form_forwards, "Drag form right" },
["
function M.get_language_api()
for l in string.gmatch(vim.bo.filetype, "[^.]+") do
if langs[l] ~= nil then
return langs[l]
end
end
- return nil
+ error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR)
end
function M.add_language_extension(filetype, api)
diff --git a/lua/nvim-paredit/utils/common.lua b/lua/nvim-paredit/utils/common.lua
index 93b7c22..68c9825 100644
--- a/lua/nvim-paredit/utils/common.lua
+++ b/lua/nvim-paredit/utils/common.lua
@@ -9,6 +9,20 @@ function M.included_in_table(table, item)
return false
end
+function M.chunk_table(tbl, chunk_size)
+ local result = {}
+ for i = 1, #tbl, chunk_size do
+ local chunk = {}
+ for j = 0, chunk_size - 1 do
+ if tbl[i + j] then
+ table.insert(chunk, tbl[i + j])
+ end
+ end
+ table.insert(result, chunk)
+ end
+ return result
+end
+
-- Compares the two given { col, row } position tuples and returns -1/0/1 depending
-- on whether `a` is less than, equal to or greater than `b`
--
diff --git a/lua/nvim-paredit/utils/traversal.lua b/lua/nvim-paredit/utils/traversal.lua
index b8857ed..72d89d7 100644
--- a/lua/nvim-paredit/utils/traversal.lua
+++ b/lua/nvim-paredit/utils/traversal.lua
@@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts)
end
end
+function M.get_children_ignoring_comments(node, opts)
+ local children = {}
+
+ local index = 0
+ local child = node:named_child(index)
+ while child do
+ if not child:extra() and not opts.lang.node_is_comment(child) then
+ table.insert(children, child)
+ end
+ index = index + 1
+ child = node:named_child(index)
+ end
+
+ return children
+end
+
local function get_child_ignoring_comments(node, index, opts)
if index < 0 or index >= node:named_child_count() then
return
@@ -139,34 +155,18 @@ function M.find_root_element_relative_to(root, child)
return M.find_root_element_relative_to(root, parent)
end
-function M.get_top_level_node_below_document(node)
- -- Document
- -- - Branch A
- -- -- Node X
- -- --- Sub-node 1
- -- - Branch B
- -- -- Node Y
- -- --- Sub-node 2
- -- --- Sub-node 3
-
- -- If we call this function on "Sub-node 2" we expect "Branch B" to be
- -- returned, the top level one below the document itself. We know which
- -- node is the document because it lacks a parent, just like Batman.
-
- local parent = node:parent()
-
- -- Does the node have a parent? If so, we might be at the right level.
- -- If not, we should just return the node right away, we're already too high.
- if parent then
- -- If the parent _also_ has a parent then we still need to go higher, recur.
- if parent:parent() then
- return M.get_top_level_node_below_document(parent)
+-- Find the root node of the tree `node` is a member of, excluding the root
+-- 'source' document.
+function M.find_local_root(node)
+ local current = node
+ while true do
+ local next = current:parent()
+ if not next or next:type() == "source" then
+ break
end
+ current = next
end
-
- -- As soon as we don't have a grandparent or parent, return the node
- -- we're on because it means we're one step below the top level document node.
- return node
+ return current
end
return M
diff --git a/lua/nvim-paredit/utils/ts.lua b/lua/nvim-paredit/utils/ts.lua
new file mode 100644
index 0000000..606e4ad
--- /dev/null
+++ b/lua/nvim-paredit/utils/ts.lua
@@ -0,0 +1,42 @@
+local traversal = require("nvim-paredit.utils.traversal")
+
+local M = {}
+
+-- Use a 'paredit/pairwise' treesitter query to find all nodes within a local
+-- branch that are labeled as @pair.
+--
+-- If any of these labeled nodes match the given target node then return all
+-- matched nodes.
+function M.find_pairwise_nodes(target_node, opts)
+ local root_node = traversal.find_local_root(target_node)
+
+ local bufnr = vim.api.nvim_get_current_buf()
+ local lang = vim.treesitter.language.get_lang(vim.bo.filetype)
+
+ local query = vim.treesitter.query.get(lang, "paredit/pairwise")
+ if not query then
+ return
+ end
+
+ local captures = query:iter_captures(root_node, bufnr)
+ local pairwise_nodes = {}
+ local found = false
+ for id, node in captures do
+ if query.captures[id] == "pair" then
+ if not node:extra() and not opts.lang.node_is_comment(node) then
+ table.insert(pairwise_nodes, node)
+ if node:equal(target_node) then
+ found = true
+ end
+ end
+ end
+ end
+
+ if not found then
+ return
+ end
+
+ return pairwise_nodes
+end
+
+return M
diff --git a/queries/clojure/paredit/pairwise.scm b/queries/clojure/paredit/pairwise.scm
new file mode 100644
index 0000000..78a71a2
--- /dev/null
+++ b/queries/clojure/paredit/pairwise.scm
@@ -0,0 +1,26 @@
+(list_lit
+ (sym_lit) @fn-name
+ (vec_lit
+ (_) @pair)
+ (#any-of? @fn-name "let" "loop" "binding" "with-open" "with-redefs"))
+
+(map_lit
+ (_) @pair)
+
+(list_lit
+ (sym_lit) @fn-name
+ (_)
+ (_) @pair
+ (#eq? @fn-name "case"))
+
+(list_lit
+ (sym_lit) @fn-name
+ (_) @pair
+ (#eq? @fn-name "cond"))
+
+(list_lit
+ (sym_lit) @fn-name
+ (_)
+ (_)
+ (_) @pair
+ (#eq? @fn-name "condp"))
diff --git a/tests/nvim-paredit/pair_drag_spec.lua b/tests/nvim-paredit/pair_drag_spec.lua
new file mode 100644
index 0000000..139f7df
--- /dev/null
+++ b/tests/nvim-paredit/pair_drag_spec.lua
@@ -0,0 +1,86 @@
+local paredit = require("nvim-paredit.api")
+
+local prepare_buffer = require("tests.nvim-paredit.utils").prepare_buffer
+local expect_all = require("tests.nvim-paredit.utils").expect_all
+local expect = require("tests.nvim-paredit.utils").expect
+
+describe("paired-element-auto-dragging", function()
+ vim.api.nvim_buf_set_option(0, "filetype", "clojure")
+ it("should drag map pairs forward", function()
+ prepare_buffer({
+ content = "{:a 1 :b 2}",
+ cursor = { 1, 1 },
+ })
+
+ paredit.drag_element_forwards({
+ dragging = {
+ auto_drag_pairs = true,
+ },
+ })
+ expect({
+ content = "{:b 2 :a 1}",
+ cursor = { 1, 6 },
+ })
+ end)
+
+ it("should drag map pairs backwards", function()
+ prepare_buffer({
+ content = "{:a 1 :b 2}",
+ cursor = { 1, 9 },
+ })
+
+ paredit.drag_element_backwards({
+ dragging = {
+ auto_drag_pairs = true,
+ },
+ })
+ expect({
+ content = "{:b 2 :a 1}",
+ cursor = { 1, 1 },
+ })
+ end)
+
+ it("should detect various types", function()
+ expect_all(function()
+ paredit.drag_element_forwards({ dragging = { auto_drag_pairs = true } })
+ end, {
+ {
+ "let binding",
+ before_content = "(let [a b c d])",
+ before_cursor = { 1, 6 },
+ after_content = "(let [c d a b])",
+ after_cursor = { 1, 10 },
+ },
+ {
+ "loop binding",
+ before_content = "(loop [a b c d])",
+ before_cursor = { 1, 7 },
+ after_content = "(loop [c d a b])",
+ after_cursor = { 1, 11 },
+ },
+ {
+ "case",
+ before_content = "(case a :a 1 :b 2)",
+ before_cursor = { 1, 8 },
+ after_content = "(case a :b 2 :a 1)",
+ after_cursor = { 1, 13 },
+ },
+ })
+ end)
+end)
+
+describe("paired-element-dragging", function()
+ vim.api.nvim_buf_set_option(0, "filetype", "clojure")
+ it("should drag vector elements forwards", function()
+ prepare_buffer({
+ content = "'[a b c d]",
+ cursor = { 1, 2 },
+ })
+
+ paredit.drag_pair_forwards()
+ expect({
+ content = "'[c d a b]",
+ cursor = { 1, 6 },
+ })
+ end)
+end)