Skip to content

Commit dabed5d

Browse files
authored
fix: handle comprehension shadowing when parsing required references for variables (#7446)
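This change fixes a bug in how our caching mechanism parses the references required by each variable. When a name used inside a list comprehension was shadowed by an enclosing function scope (for example, a function parameter), the shadowing was not detected, and the name was incorrectly recorded as a required reference. The bug surfaced as a failure in threaded caching.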
1 parent c33759e commit dabed5d

File tree

4 files changed

+150
-2
lines changed

4 files changed

+150
-2
lines changed

marimo/_ast/visitor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,12 +1006,20 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
10061006

10071007
# Handle refs on the block scope level, or capture cell level
10081008
# references.
1009+
# Only add to ref_stack if the variable is not defined in any
1010+
# non-module ancestor block. This prevents function parameters from
1011+
# being incorrectly marked as refs when used in nested scopes like
1012+
# list comprehensions, while still capturing module-level references.
10091013
if (
10101014
isinstance(node.ctx, ast.Load)
10111015
and self._is_defined(node.id)
10121016
and node.id not in self.ref_stack[-1]
10131017
and (
1014-
node.id not in self.block_stack[-1].defs
1018+
# Check blocks[1:] - skip module block so module-level vars
1019+
# are still tracked as refs, but function params aren't
1020+
not any(
1021+
node.id in block.defs for block in self.block_stack[1:]
1022+
)
10151023
or len(self.block_stack) == 1
10161024
)
10171025
):

marimo/_save/save.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
424424
try:
425425
if attempt.hit:
426426
attempt.restore(scope)
427-
return attempt.meta["return"]
427+
return attempt.meta.get("return")
428428

429429
start_time = time.time()
430430
response = self.__wrapped__(*args, **kwargs)

tests/_ast/test_visitor.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,96 @@ def test_nested_comprehension_generator_with_named_expr() -> None:
305305
assert v.variable_data == {"x": [VariableData(kind="variable")]}
306306

307307

308+
def test_function_param_in_comprehension_not_required_ref() -> None:
309+
"""Function parameters used in list comprehensions should not be required_refs.
310+
311+
Regression test: The parameter `extension` was incorrectly added to required_refs
312+
when used as the iterator in a list comprehension inside the function.
313+
See: test_shadowed_ui_variable_threadpool in tests/_save/test_cache.py
314+
"""
315+
code = cleandoc(
316+
"""
317+
def helper(extension):
318+
return [e for e in extension or []]
319+
"""
320+
)
321+
v = visitor.ScopedVisitor()
322+
mod = ast.parse(code)
323+
v.visit(mod)
324+
325+
assert v.defs == {"helper"}
326+
assert v.refs == set() # No external refs!
327+
# extension is a PARAMETER, not an external dependency
328+
# Compare to test_globals_in_functions: foo(a...) where a is not in required_refs
329+
assert v.variable_data == {
330+
"helper": [VariableData(kind="function", required_refs=set())]
331+
}
332+
333+
334+
def test_nested_function_param_in_comprehension_not_required_ref() -> None:
335+
"""Ensure that additional nesting works."""
336+
code = cleandoc(
337+
"""
338+
def helper():
339+
extension = []
340+
def foo():
341+
def bar():
342+
return [e for e in extension or []]
343+
return bar
344+
return foo
345+
"""
346+
)
347+
v = visitor.ScopedVisitor()
348+
mod = ast.parse(code)
349+
v.visit(mod)
350+
351+
assert v.defs == {"helper"}
352+
assert v.refs == set() # No external refs!
353+
assert v.variable_data == {
354+
"helper": [VariableData(kind="function", required_refs=set())]
355+
}
356+
357+
358+
def test_param_in_comprehension_has_required_ref() -> None:
359+
"""Sanity check ref still is picked up"""
360+
code = cleandoc(
361+
"""
362+
def helper():
363+
return [e for e in extension or []]
364+
"""
365+
)
366+
v = visitor.ScopedVisitor()
367+
mod = ast.parse(code)
368+
v.visit(mod)
369+
370+
assert v.defs == {"helper"}
371+
assert v.refs == {"extension"}
372+
assert v.variable_data == {
373+
"helper": [VariableData(kind="function", required_refs={"extension"})]
374+
}
375+
376+
377+
def test_shadowed_param_in_comprehension_not_required_ref() -> None:
378+
"""Check that a shadowed variable doesn't capture ref in module scope."""
379+
code = cleandoc(
380+
"""
381+
extension = []
382+
def helper(extension):
383+
return [e for e in extension or []]
384+
"""
385+
)
386+
v = visitor.ScopedVisitor()
387+
mod = ast.parse(code)
388+
v.visit(mod)
389+
390+
assert v.defs == {"helper", "extension"}
391+
assert v.refs == set()
392+
assert v.variable_data == {
393+
"helper": [VariableData(kind="function", required_refs=set())],
394+
"extension": [VariableData(kind="variable", required_refs=set())],
395+
}
396+
397+
308398
def test_walrus_leaks_to_global_in_comprehension() -> None:
309399
code = "\n".join(
310400
[

tests/_save/test_cache.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,6 +2500,56 @@ def g():
25002500
assert g() == 2
25012501
return (g, arr)
25022502

2503+
@staticmethod
2504+
def test_shadowed_ui_variable_threadpool(app) -> None:
2505+
"""Test shadow error with UI-derived variable and ThreadPoolExecutor.
2506+
2507+
Bug requires:
2508+
1. UI element providing a cell-scoped variable (e.g. `extension`)
2509+
2. Helper function with same-named parameter using nested scope (list comp)
2510+
3. Cached function called via ThreadPoolExecutor.submit()
2511+
2512+
Causes KeyError at hash.py because scope has '*extension' (ARG_PREFIX)
2513+
but lookup uses 'extension'.
2514+
"""
2515+
with app.setup:
2516+
from concurrent.futures import ThreadPoolExecutor
2517+
2518+
import marimo as mo
2519+
2520+
@app.cell
2521+
def _():
2522+
ui_input = mo.ui.text(value="hello")
2523+
return (ui_input,)
2524+
2525+
@app.cell
2526+
def _():
2527+
def helper(extension: list[str] | None) -> int:
2528+
# Nested scope using extension triggers the bug
2529+
# for e in extension: ... works fine.
2530+
return len([e for e in extension or []])
2531+
2532+
@mo.cache
2533+
def inner(extension: list[str] | None) -> int:
2534+
assert len([e for e in extension or []]) == 5
2535+
return helper(extension)
2536+
2537+
return (inner,)
2538+
2539+
@app.cell
2540+
def _(inner, ui_input):
2541+
extension = ui_input.value
2542+
2543+
results = []
2544+
# has to be in a thread submission
2545+
# the following works fine
2546+
assert inner(extension) == 5
2547+
2548+
with ThreadPoolExecutor(max_workers=2) as executor:
2549+
future = executor.submit(inner, extension)
2550+
assert future.result() == 5
2551+
return
2552+
25032553

25042554
class TestPersistentCache:
25052555
async def test_pickle_context(

0 commit comments

Comments
 (0)