Skip to content

Commit

Permalink
if: preserve narrowing across branches
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm committed Jan 28, 2025
1 parent 47b1deb commit 90293d0
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 6 deletions.
87 changes: 87 additions & 0 deletions spec/lang/statement/if_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,91 @@ describe("if", function()
]], {
{ msg = "types are not comparable for equality: boolean and integer" }
}))

it("performs type narrowing/widening on all branches", util.check([[
local interface Another
where self.another
another: string
end
local interface Type
where self.t
t: string
end
local record SpecificType is Type
where self.t == "specific"
end
local function needs_type(t: Type): Type
print(t.t)
end
local it: Type | Another
local a: Another = { another = "yes" }
if it is Type then
if it is SpecificType then
it = a -- this does not impact the other branch
else
it = needs_type(it) -- preserves narrowing here
end
end
it = a -- widen back to the union
]]))

it("performs type narrowing/widening on all branches (with constrained generic)", util.check([[
local interface Another
where self.another
another: string
end
local interface Type
where self.t
t: string
end
local record SpecificType is Type
where self.t == "specific"
end
-- function with constrained generic
local function needs_type<T is Type>(t: T): T
print(t.t)
end
local it: Type | Another
local a: Another = { another = "yes" }
if it is Type then
if it is SpecificType then
it = a -- this does not impact the other branch
else
it = needs_type(it) -- preserves narrowing here
end
end
it = a -- widen back to the union
]]))

it("knows when to discard narrowing", util.check([[
local type Key = string | number | boolean
local function f(k: Key, n: number)
if not k then
k = n
if not k then
k = true
end
end
end
]]))

end)
30 changes: 27 additions & 3 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2201,6 +2201,7 @@ local Node = { ExpectedContext = {} }






local show_type
Expand Down Expand Up @@ -8250,6 +8251,25 @@ do
return widened
end

function TypeChecker:collect_if_widens(widens)
local st = self.st
local scope = st[#st]
if scope.widens then
widens = widens or {}
for k, _ in pairs(scope.widens) do
widens[k] = true
end
scope.widens = nil
end
return widens
end

function TypeChecker:widen_all(widens)
for name, _ in pairs(widens) do
self:widen_back_var(name)
end
end

function TypeChecker:begin_scope(node)
table.insert(self.st, { vars = {} })

Expand Down Expand Up @@ -8296,9 +8316,7 @@ do
table.remove(st)

if scope.widens then
for name, _ in pairs(scope.widens) do
self:widen_back_var(name)
end
self:widen_all(scope.widens)
end

if self.collector and node then
Expand Down Expand Up @@ -11980,6 +11998,10 @@ self:expand_type(node, values, elements) })
},
["if"] = {
after = function(self, node, _children)
if node.if_widens then
self:widen_all(node.if_widens)
end

local all_return = true
for _, b in ipairs(node.if_blocks) do
if not b.block_returns then
Expand Down Expand Up @@ -12011,6 +12033,8 @@ self:expand_type(node, values, elements) })
end
end,
after = function(self, node, _children)
node.if_parent.if_widens = self:collect_if_widens(node.if_parent.if_widens)

self:end_scope(node)

if #node.body > 0 and node.body[#node.body].block_returns then
Expand Down
30 changes: 27 additions & 3 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,7 @@ local record Node
if_parent: Node
if_block_n: integer
if_blocks: {Node}
if_widens: {string: boolean}
block_returns: boolean

-- fornum
Expand Down Expand Up @@ -8250,6 +8251,25 @@ do
return widened
end

function TypeChecker:collect_if_widens(widens: {string:boolean}): {string:boolean}
local st = self.st
local scope = st[#st]
if scope.widens then
widens = widens or {}
for k, _ in pairs(scope.widens) do
widens[k] = true
end
scope.widens = nil
end
return widens
end

function TypeChecker:widen_all(widens: {string:boolean})
for name, _ in pairs(widens) do
self:widen_back_var(name)
end
end

function TypeChecker:begin_scope(node?: Node)
table.insert(self.st, { vars = {} })

Expand Down Expand Up @@ -8296,9 +8316,7 @@ do
table.remove(st)

if scope.widens then
for name, _ in pairs(scope.widens) do
self:widen_back_var(name)
end
self:widen_all(scope.widens)
end

if self.collector and node then
Expand Down Expand Up @@ -11980,6 +11998,10 @@ do
},
["if"] = {
after = function(self: TypeChecker, node: Node, _children: {Type}): Type
if node.if_widens then
self:widen_all(node.if_widens)
end

local all_return = true
for _, b in ipairs(node.if_blocks) do
if not b.block_returns then
Expand Down Expand Up @@ -12011,6 +12033,8 @@ do
end
end,
after = function(self: TypeChecker, node: Node, _children: {Type}): Type
node.if_parent.if_widens = self:collect_if_widens(node.if_parent.if_widens)

self:end_scope(node)

if #node.body > 0 and node.body[#node.body].block_returns then
Expand Down

0 comments on commit 90293d0

Please sign in to comment.