diff --git a/lua/telescope/sorters.lua b/lua/telescope/sorters.lua index 6817c97..c1b4b0f 100644 --- a/lua/telescope/sorters.lua +++ b/lua/telescope/sorters.lua @@ -511,44 +511,52 @@ sorters.get_levenshtein_sorter = function() } end -local substr_highlighter = function(_, prompt, display) - local highlights = {} - display = display:lower() +local substr_highlighter = function(make_display) + return function(_, prompt, display) + local highlights = {} + display = make_display(prompt, display) - local search_terms = util.max_split(prompt, "%s") - local hl_start, hl_end + local search_terms = util.max_split(prompt, "%s") + local hl_start, hl_end - for _, word in pairs(search_terms) do - hl_start, hl_end = display:find(word, 1, true) - if hl_start then - table.insert(highlights, { start = hl_start, finish = hl_end }) + for _, word in pairs(search_terms) do + hl_start, hl_end = display:find(word, 1, true) + if hl_start then + table.insert(highlights, { start = hl_start, finish = hl_end }) + end end - end - return highlights + return highlights + end end sorters.get_substr_matcher = function() + local make_display = vim.o.smartcase + and function(prompt, display) + local has_upper_case = not not prompt:match "%u" + return has_upper_case and display or display:lower() + end + or function(_, display) + return display:lower() + end + return Sorter:new { - highlighter = substr_highlighter, + highlighter = substr_highlighter(make_display), scoring_function = function(_, prompt, _, entry) if #prompt == 0 then return 1 end - local display = entry.ordinal:lower() + local display = make_display(prompt, entry.ordinal) local search_terms = util.max_split(prompt, "%s") - local matched = 0 - local total_search_terms = 0 for _, word in pairs(search_terms) do - total_search_terms = total_search_terms + 1 - if display:find(word, 1, true) then - matched = matched + 1 + if not display:find(word, 1, true) then + return -1 end end - return matched == total_search_terms and entry.index or -1 + return entry.index end, } end diff --git a/lua/tests/automated/sorters_spec.lua b/lua/tests/automated/sorters_spec.lua new file mode 100644 index 0000000..f3453a2 --- /dev/null +++ b/lua/tests/automated/sorters_spec.lua @@ -0,0 +1,84 @@ +local sorters = require "telescope.sorters" + +describe("get_substr_matcher", function() + local function with_smartcase(smartcase, case) + local original = vim.o.smartcase + vim.o.smartcase = smartcase + + describe("scoring_function", function() + it(case.msg, function() + local matcher = sorters.get_substr_matcher() + assert.are.same(case.expected_score, matcher.scoring_function(_, case.prompt, _, case.entry)) + end) + end) + + describe("highlighter", function() + it("returns valid highlights", function() + local matcher = sorters.get_substr_matcher() + local highlights = matcher.highlighter(_, case.prompt, case.entry.ordinal) + table.sort(highlights, function(a, b) + return a.start < b.start + end) + assert.are.same(case.expected_highlights, highlights) + end) + end) + + vim.o.smartcase = original + end + + describe("when smartcase=OFF", function() + for _, case in ipairs { + { + msg = "doesn't match", + prompt = "abc def", + entry = { index = 3, ordinal = "abc d" }, + expected_score = -1, + expected_highlights = { { start = 1, finish = 3 } }, + }, + { + msg = "matches with lower case letters only", + prompt = "abc def", + entry = { index = 3, ordinal = "abc def ghi" }, + expected_score = 3, + expected_highlights = { { start = 1, finish = 3 }, { start = 5, finish = 7 } }, + }, + { + msg = "doesn't match with upper case letters", + prompt = "ABC def", + entry = { index = 3, ordinal = "ABC def ghi" }, + expected_score = -1, + expected_highlights = { { start = 5, finish = 7 } }, + }, + } do + with_smartcase(false, case) + end + end) + + describe("when smartcase=OFF", function() + for _, case in ipairs { + { + msg = "doesn't match", + prompt = "abc def", + entry = { index = 3, ordinal = "abc d" }, + expected_score = -1, + expected_highlights = { { start = 1, finish = 3 } }, + }, + { + msg = "matches with lower case letters only", + prompt = "abc def", + entry = { index = 3, ordinal = "abc def ghi" }, + expected_score = 3, + expected_highlights = { { start = 1, finish = 3 }, { start = 5, finish = 7 } }, + }, + { + msg = "matches with upper case letters", + prompt = "ABC def", + entry = { index = 3, ordinal = "ABC def ghi" }, + expected_score = 3, + expected_highlights = { { start = 1, finish = 3 }, { start = 5, finish = 7 } }, + }, + } do + with_smartcase(true, case) + end + end) +end)