From eb418b1ec330184a73eaeb15a7251385890b0f55 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 19 Aug 2022 13:44:12 -0500 Subject: [PATCH 01/14] Add an update_sources method to all Values --- mypyc/ir/ops.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 04c50d1e2841..991ba54a2ff7 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -235,6 +235,10 @@ def can_raise(self) -> bool: def sources(self) -> list[Value]: """All the values the op may read.""" + @abstractmethod + def set_sources(self, new: list[Value]) -> None: + """Rewrite the soruces of an op""" + def stolen(self) -> list[Value]: """Return arguments that have a reference count stolen by this op""" return [] @@ -271,6 +275,9 @@ def __init__(self, dest: Register, src: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: return [self.src] @@ -301,6 +308,9 @@ def __init__(self, dest: Register, src: list[Value], line: int = -1) -> None: def sources(self) -> list[Value]: return self.src.copy() + def set_sources(self, new: list[Value]) -> None: + self.src = new[:] + def stolen(self) -> list[Value]: return [] @@ -342,6 +352,9 @@ def __repr__(self) -> str: def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_goto(self) @@ -402,6 +415,9 @@ def set_target(self, i: int, new: BasicBlock) -> None: def sources(self) -> list[Value]: return [self.value] + def set_sources(self, new: list[Value]) -> None: + (self.value,) = new + def invert(self) -> None: self.negated = not self.negated @@ -421,6 +437,9 @@ def __init__(self, value: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.value] + def set_sources(self, new: list[Value]) -> None: + (self.value,) = new + def stolen(self) -> list[Value]: return [self.value] @@ -452,6 +471,9 @@ def __init__(self, line: int = -1) -> None: def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_unreachable(self) @@ -494,6 +516,9 @@ def __init__(self, src: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_inc_ref(self) @@ -519,6 +544,9 @@ def __repr__(self) -> str: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_dec_ref(self) @@ -544,6 +572,9 @@ def __init__(self, fn: FuncDecl, args: Sequence[Value], line: int) -> None: def sources(self) -> list[Value]: return list(self.args.copy()) + def set_sources(self, new: list[Value]) -> None: + self.args = new[:] + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_call(self) @@ -572,6 +603,9 @@ def __init__(self, obj: Value, method: str, args: list[Value], line: int = -1) - def sources(self) -> list[Value]: return self.args.copy() + [self.obj] + def set_sources(self, new: list[Value]) -> None: + *self.args, self.obj = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_method_call(self) @@ -599,6 +633,9 @@ def __init__( def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_error_value(self) @@ -631,6 +668,9 @@ def __init__(self, value: LiteralValue, rtype: RType) -> None: def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_literal(self) @@ -655,6 +695,9 @@ def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> def sources(self) -> list[Value]: return [self.obj] + def set_sources(self, new: list[Value]) -> None: + (self.obj,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_attr(self) @@ -687,6 +730,9 @@ def mark_as_initializer(self) -> None: def sources(self) -> list[Value]: return [self.obj, self.src] + def set_sources(self, new: list[Value]) -> None: + self.obj, self.src = new + def stolen(self) -> list[Value]: return [self.src] @@ -737,6 +783,9 @@ def __init__( def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_static(self) @@ -766,6 +815,9 @@ def __init__( def sources(self) -> list[Value]: return [self.value] + def set_sources(self, new: list[Value]) -> None: + (self.value,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_init_static(self) @@ -795,6 +847,9 @@ def sources(self) -> list[Value]: def stolen(self) -> list[Value]: return self.items.copy() + def set_sources(self, new: list[Value]) -> None: + self.items = new[:] + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_tuple_set(self) @@ -816,6 +871,9 @@ def __init__(self, src: Value, index: int, line: int = -1, *, borrow: bool = Fal def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_tuple_get(self) @@ -839,6 +897,9 @@ def __init__(self, src: Value, typ: RType, line: int, *, borrow: bool = False) - def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: if self.is_borrowed: return [] @@ -872,6 +933,9 @@ def __init__(self, src: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: return [self.src] @@ -898,6 +962,9 @@ def __init__(self, src: Value, typ: RType, line: int) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_unbox(self) @@ -930,6 +997,9 @@ def __init__(self, class_name: str, value: str | Value | None, line: int) -> Non def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + assert not new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_raise_standard_error(self) @@ -968,7 +1038,10 @@ def __init__( self.var_arg_idx = var_arg_idx def sources(self) -> list[Value]: - return self.args + return self.args[:] + + def set_sources(self, new: list[Value]) -> None: + self.args = new[:] def stolen(self) -> list[Value]: if isinstance(self.steals, list): @@ -1001,6 +1074,9 @@ def __init__(self, src: Value, dst_type: RType, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: return [] @@ -1032,6 +1108,9 @@ def __init__(self, src: Value, dst_type: RType, signed: bool, line: int = -1) -> def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: return [] @@ -1059,6 +1138,9 @@ def __init__(self, type: RType, identifier: str, line: int = -1, ann: object = N def sources(self) -> list[Value]: return [] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_global(self) @@ -1115,6 +1197,9 @@ def __init__(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) def sources(self) -> list[Value]: return [self.lhs, self.rhs] + def set_sources(self, new: list[Value]) -> None: + self.lhs, self.rhs = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_int_op(self) @@ -1178,6 +1263,9 @@ def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.lhs, self.rhs] + def set_sources(self, new: list[Value]) -> None: + self.lhs, self.rhs = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_comparison_op(self) @@ -1211,6 +1299,9 @@ def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.lhs, self.rhs] + def set_sources(self, new: list[Value]) -> None: + (self.lhs, self.rhs) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_float_op(self) @@ -1233,6 +1324,9 @@ def __init__(self, src: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_float_neg(self) @@ -1261,6 +1355,9 @@ def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.lhs, self.rhs] + def set_sources(self, new: list[Value]) -> None: + (self.lhs, self.rhs) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_float_comparison_op(self) @@ -1292,6 +1389,9 @@ def __init__(self, type: RType, src: Value, line: int = -1) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_mem(self) @@ -1317,6 +1417,9 @@ def __init__(self, type: RType, dest: Value, src: Value, line: int = -1) -> None def sources(self) -> list[Value]: return [self.src, self.dest] + def set_sources(self, new: list[Value]) -> None: + self.src, self.dest = new + def stolen(self) -> list[Value]: return [self.src] @@ -1343,6 +1446,9 @@ def __init__(self, src: Value, src_type: RType, field: str, line: int = -1) -> N def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_element_ptr(self) @@ -1371,6 +1477,11 @@ def sources(self) -> list[Value]: else: return [] + def set_sources(self, new: list[Value]) -> None: + if new: + assert isinstance(new[0], Register) + self.src = new[0] + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_address(self) @@ -1415,6 +1526,9 @@ def stolen(self) -> list[Value]: return self.src.copy() return [] + def set_sources(self, new: list[Value]) -> None: + self.src = new[:] + def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_keep_alive(self) @@ -1454,6 +1568,9 @@ def __init__(self, src: Value) -> None: def sources(self) -> list[Value]: return [self.src] + def set_sources(self, new: list[Value]) -> None: + (self.src,) = new + def stolen(self) -> list[Value]: return [] From a63b61057525b876869c04edb9836639b36020c7 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 19 Aug 2022 16:00:50 -0500 Subject: [PATCH 02/14] Remove some nonsense code generated by uninit --- mypyc/transform/uninit.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mypyc/transform/uninit.py b/mypyc/transform/uninit.py index 6bf71ac4a8bc..45b403588f8e 100644 --- a/mypyc/transform/uninit.py +++ b/mypyc/transform/uninit.py @@ -69,14 +69,19 @@ def split_blocks_at_uninits( and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR) and not isinstance(op, LoadAddress) ): - new_block, error_block = BasicBlock(), BasicBlock() - new_block.error_handler = error_block.error_handler = cur_block.error_handler - new_blocks += [error_block, new_block] - if src not in init_registers_set: init_registers.append(src) init_registers_set.add(src) + # XXX: if src.name is empty, it should be a + # temp... and it should be OK?? + if not src.name: + continue + + new_block, error_block = BasicBlock(), BasicBlock() + new_block.error_handler = error_block.error_handler = cur_block.error_handler + new_blocks += [error_block, new_block] + if not src.type.error_overlap: cur_block.ops.append( Branch( From 7efb8dbc2dab63226fe1449328bccfaabf659fc8 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 19 Aug 2022 14:11:51 -0500 Subject: [PATCH 03/14] WIP: start on spilling --- mypyc/analysis/dataflow.py | 20 ++++-- mypyc/codegen/emitmodule.py | 9 +++ mypyc/ir/class_ir.py | 7 ++ mypyc/ir/ops.py | 9 ++- mypyc/irbuild/function.py | 1 + mypyc/irbuild/generator.py | 2 + mypyc/irbuild/statement.py | 2 +- mypyc/test-data/run-generators.test | 18 +++++ mypyc/transform/spill.py | 108 ++++++++++++++++++++++++++++ 9 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 mypyc/transform/spill.py diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index cade0c823962..a5ad460ba56c 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -17,6 +17,7 @@ Cast, ComparisonOp, ControlOp, + DecRef, Extend, Float, FloatComparisonOp, @@ -25,6 +26,7 @@ GetAttr, GetElementPtr, Goto, + IncRef, InitStatic, Integer, IntOp, @@ -79,12 +81,11 @@ def __str__(self) -> str: return "\n".join(lines) -def get_cfg(blocks: list[BasicBlock]) -> CFG: +def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG: """Calculate basic block control-flow graph. - The result is a dictionary like this: - - basic block index -> (successors blocks, predecesssor blocks) + If use_yields is set, then we treat returns inserted by yields as gotos + instead of exits. """ succ_map = {} pred_map: dict[BasicBlock, list[BasicBlock]] = {} @@ -94,7 +95,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG: isinstance(op, ControlOp) for op in block.ops[:-1] ), "Control-flow ops must be at the end of blocks" - succ = list(block.terminator.targets()) + if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target: + succ = [block.terminator.yield_target] + else: + succ = list(block.terminator.targets()) if not succ: exits.add(block) @@ -494,6 +498,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return non_trivial_sources(op), set() + def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]: + return set(), set() + + def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]: + return set(), set() + def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]: """Calculate live registers at each CFG location. diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index caf2058ea7c4..a1211fff68bf 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -58,6 +58,7 @@ from mypyc.options import CompilerOptions from mypyc.transform.exceptions import insert_exception_handling from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.spill import insert_spills from mypyc.transform.uninit import insert_uninit_checks # All of the modules being compiled are divided into "groups". A group @@ -225,6 +226,10 @@ def compile_scc_to_ir( if errors.num_errors > 0: return modules + # XXX: HOW WILL WE DEAL WITH REFCOUNTING ON THE SPILLAGE + # DO WE DO IT... LAST? MAYBE MAYBE MAYBE YES + # ONLY DO UNINIT.... YEAH OK + # Insert uninit checks. for module in modules.values(): for fn in module.functions: @@ -237,6 +242,10 @@ def compile_scc_to_ir( for module in modules.values(): for fn in module.functions: insert_ref_count_opcodes(fn) + for module in modules.values(): + for cls in module.classes: + if cls.env_user_function: + insert_spills(cls.env_user_function, cls) return modules diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 61f0fc36e1b3..11f308dc8b66 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -194,6 +194,9 @@ def __init__( # value of an attribute is the same as the error value. self.bitmap_attrs: list[str] = [] + # If this is a generator environment class, what is the actual method for it + self.env_user_function: FuncIR | None = None + def __repr__(self) -> str: return ( "ClassIR(" @@ -391,6 +394,7 @@ def serialize(self) -> JsonDict: "_always_initialized_attrs": sorted(self._always_initialized_attrs), "_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs), "init_self_leak": self.init_self_leak, + "env_user_function": self.env_user_function.id if self.env_user_function else None, } @classmethod @@ -442,6 +446,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR: ir._always_initialized_attrs = set(data["_always_initialized_attrs"]) ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"]) ir.init_self_leak = data["init_self_leak"] + ir.env_user_function = ( + ctx.functions[data["env_user_function"]] if data["env_user_function"] else None + ) return ir diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 991ba54a2ff7..6900b7ebc692 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -430,9 +430,16 @@ class Return(ControlOp): error_kind = ERR_NEVER - def __init__(self, value: Value, line: int = -1) -> None: + def __init__( + self, value: Value, line: int = -1, *, yield_target: BasicBlock | None = None + ) -> None: super().__init__(line) self.value = value + # If this return is created by a yield, keep track of the next + # basic block. This doesn't affect the code we generate but + # can feed into analysis that need to understand the + # *original* CFG. + self.yield_target = yield_target def sources(self) -> list[Value]: return [self.value] diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index ebf7fa9a54de..e03b534cc1d7 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -269,6 +269,7 @@ def c() -> None: # Re-enter the FuncItem and visit the body of the function this time. builder.enter(fn_info) setup_env_for_generator_class(builder) + load_outer_envs(builder, builder.fn_info.generator_class) top_level = builder.top_level_fn_info() if ( diff --git a/mypyc/irbuild/generator.py b/mypyc/irbuild/generator.py index 92f9abff467c..bc61c4493d55 100644 --- a/mypyc/irbuild/generator.py +++ b/mypyc/irbuild/generator.py @@ -181,6 +181,8 @@ def add_helper_to_generator_class( ) fn_info.generator_class.ir.methods["__mypyc_generator_helper__"] = helper_fn_ir builder.functions.append(helper_fn_ir) + fn_info.env_class.env_user_function = helper_fn_ir + return helper_fn_decl diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index d7e01456139d..98b81219fde9 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -894,7 +894,7 @@ def emit_yield(builder: IRBuilder, val: Value, line: int) -> Value: next_label = len(cls.continuation_blocks) cls.continuation_blocks.append(next_block) builder.assign(cls.next_label_target, Integer(next_label), line) - builder.add(Return(retval)) + builder.add(Return(retval, yield_target=next_block)) builder.activate_block(next_block) add_raise_exception_blocks_to_generator_class(builder, line) diff --git a/mypyc/test-data/run-generators.test b/mypyc/test-data/run-generators.test index bcf9da1846ae..84c7cd90b6a2 100644 --- a/mypyc/test-data/run-generators.test +++ b/mypyc/test-data/run-generators.test @@ -679,3 +679,21 @@ def test_basic() -> None: with context: assert context.x == 1 assert context.x == 0 + + +[case testYieldSpill] +from typing import Generator + +def f() -> int: + return 1 + +def yield_spill() -> Generator[str, int, int]: + return f() + (yield "foo") + +[file driver.py] +from native import yield_spill +from testutil import run_generator + +yields, val = run_generator(yield_spill(), [2]) +assert yields == ('foo',) +assert val == 3, val diff --git a/mypyc/transform/spill.py b/mypyc/transform/spill.py new file mode 100644 index 000000000000..5835a196a34b --- /dev/null +++ b/mypyc/transform/spill.py @@ -0,0 +1,108 @@ +"""Insert spills for values that are live across yields.""" + +from __future__ import annotations + +from mypyc.analysis.dataflow import AnalysisResult, analyze_live_regs, get_cfg +from mypyc.common import TEMP_ATTR_NAME +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncIR +from mypyc.ir.ops import ( + BasicBlock, + Branch, + DecRef, + GetAttr, + IncRef, + LoadErrorValue, + Register, + SetAttr, + Value, +) + + +def insert_spills(ir: FuncIR, env: ClassIR) -> None: + cfg = get_cfg(ir.blocks, use_yields=True) + live = analyze_live_regs(ir.blocks, cfg) + entry_live = live.before[ir.blocks[0], 0] + + # from mypyc.ir.pprint import format_func + + # print('\n'.join(format_func(ir))) + + entry_live = {op for op in entry_live if not (isinstance(op, Register) and op.is_arg)} + # XXX: Actually for now, no Registers at all -- we keep the manual spills + entry_live = {op for op in entry_live if not isinstance(op, Register)} + + ir.blocks = spill_regs(ir.blocks, env, entry_live, live) + # print("\n".join(format_func(ir))) + # print("\n\n\n=========") + + +def spill_regs( + blocks: list[BasicBlock], env: ClassIR, to_spill: set[Value], live: AnalysisResult[Value] +) -> list[BasicBlock]: + for op in blocks[0].ops: + if isinstance(op, GetAttr) and op.attr == "__mypyc_env__": + env_reg = op + break + else: + raise AssertionError("could not find __mypyc_env__") + + spill_locs = {} + for i, val in enumerate(to_spill): + name = f"{TEMP_ATTR_NAME}2_{i}" + env.attributes[name] = val.type + spill_locs[val] = name + + for block in blocks: + ops = block.ops + block.ops = [] + + for i, op in enumerate(ops): + to_decref = [] + + if isinstance(op, IncRef) and op.src in spill_locs: + raise AssertionError("not sure what to do with an incref of a spill...") + if isinstance(op, DecRef) and op.src in spill_locs: + # When we decref a spilled value, we turn that into + # NULLing out the attribute, but only if the spilled + # value is not live *when we include yields in the + # CFG*. (The original decrefs are computed without that.) + # + # We also skip a decref is the env register is not + # live. That should only happen when an exception is + # being raised, so everything should be handled there. + if op.src not in live.after[block, i] and env_reg in live.after[block, i]: + # Skip the DecRef but null out the spilled location + null = LoadErrorValue(op.src.type) + block.ops.extend([null, SetAttr(env_reg, spill_locs[op.src], null, op.line)]) + continue + + if ( + any(src in spill_locs for src in op.sources()) + # N.B: IS_ERROR should be before a spill happens + # XXX: but could we have a regular branch? + and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR) + ): + new_sources: list[Value] = [] + for src in op.sources(): + if src in spill_locs: + read = GetAttr(env_reg, spill_locs[src], op.line) + block.ops.append(read) + new_sources.append(read) + if src.type.is_refcounted: + to_decref.append(read) + else: + new_sources.append(src) + + op.set_sources(new_sources) + + block.ops.append(op) + + for dec in to_decref: + block.ops.append(DecRef(dec)) + + if op in spill_locs: + # XXX: could we set uninit? + block.ops.append(SetAttr(env_reg, spill_locs[op], op, op.line)) + + return blocks From 66084cc2c410713057794381cff193a2923caa43 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Sat, 21 Oct 2023 14:00:07 -0700 Subject: [PATCH 04/14] cleanups --- mypyc/codegen/emitmodule.py | 4 ---- mypyc/transform/spill.py | 8 +------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index a1211fff68bf..a3d198038bd3 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -226,10 +226,6 @@ def compile_scc_to_ir( if errors.num_errors > 0: return modules - # XXX: HOW WILL WE DEAL WITH REFCOUNTING ON THE SPILLAGE - # DO WE DO IT... LAST? MAYBE MAYBE MAYBE YES - # ONLY DO UNINIT.... YEAH OK - # Insert uninit checks. for module in modules.values(): for fn in module.functions: diff --git a/mypyc/transform/spill.py b/mypyc/transform/spill.py index 5835a196a34b..331f1d3c1536 100644 --- a/mypyc/transform/spill.py +++ b/mypyc/transform/spill.py @@ -24,17 +24,11 @@ def insert_spills(ir: FuncIR, env: ClassIR) -> None: live = analyze_live_regs(ir.blocks, cfg) entry_live = live.before[ir.blocks[0], 0] - # from mypyc.ir.pprint import format_func - - # print('\n'.join(format_func(ir))) - entry_live = {op for op in entry_live if not (isinstance(op, Register) and op.is_arg)} - # XXX: Actually for now, no Registers at all -- we keep the manual spills + # TODO: Actually for now, no Registers at all -- we keep the manual spills entry_live = {op for op in entry_live if not isinstance(op, Register)} ir.blocks = spill_regs(ir.blocks, env, entry_live, live) - # print("\n".join(format_func(ir))) - # print("\n\n\n=========") def spill_regs( From bfa72815f3cac205b2b66e5f247cc4424aad7309 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Sat, 21 Oct 2023 15:31:58 -0700 Subject: [PATCH 05/14] make it a no driver test --- mypyc/test-data/run-async.test | 26 ++++++++++++++++++-------- mypyc/test-data/run-generators.test | 13 ++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 85ad172d61df..9348c4e16841 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -1,6 +1,6 @@ # async test cases (compile and run) -[case testAsync] +[case testRunAsync] import asyncio async def h() -> int: @@ -11,19 +11,29 @@ async def g() -> int: return await h() async def f() -> int: - return await g() + return await g() + 2 + +async def f2() -> int: + x = 0 + for i in range(2): + x += i + await f() + await g() + return x + +def test_1() -> None: + result = asyncio.run(f()) + assert result == 3 + +def test_2() -> None: + result = asyncio.run(f2()) + assert result == 9 [file asyncio/__init__.pyi] async def sleep(t: float) -> None: ... +# eh, we could use the real type but it doesn't seem important +def run(x: object) -> object: ... [typing fixtures/typing-full.pyi] -[file driver.py] -from native import f -import asyncio - -result = asyncio.run(f()) -assert result == 1 [case testAsyncWith] from testutil import async_val diff --git a/mypyc/test-data/run-generators.test b/mypyc/test-data/run-generators.test index 84c7cd90b6a2..c37de43bfc78 100644 --- a/mypyc/test-data/run-generators.test +++ b/mypyc/test-data/run-generators.test @@ -683,6 +683,7 @@ def test_basic() -> None: [case testYieldSpill] from typing import Generator +from testutil import run_generator def f() -> int: return 1 @@ -690,10 +691,8 @@ def f() -> int: def yield_spill() -> Generator[str, int, int]: return f() + (yield "foo") -[file driver.py] -from native import yield_spill -from testutil import run_generator - -yields, val = run_generator(yield_spill(), [2]) -assert yields == ('foo',) -assert val == 3, val +def test_basic() -> None: + x = run_generator(yield_spill(), [2]) + yields, val = x + assert yields == ('foo',) + assert val == 3, val From c90b607cfb2e2970b83cc570443e9c8a892d5b9e Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 15 Apr 2025 14:09:52 +0100 Subject: [PATCH 06/14] Insert spills before optimizations --- mypyc/codegen/emitmodule.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 2bb83a86d0ba..e9e60cff5517 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -229,6 +229,12 @@ def compile_scc_to_ir( if errors.num_errors > 0: return modules + env_user_functions = {} + for module in modules.values(): + for cls in module.classes: + if cls.env_user_function: + env_user_functions[cls.env_user_function] = cls + for module in modules.values(): for fn in module.functions: # Insert uninit checks. @@ -237,17 +243,16 @@ def compile_scc_to_ir( insert_exception_handling(fn) # Insert refcount handling. insert_ref_count_opcodes(fn) + + if fn in env_user_functions: + insert_spills(fn, env_user_functions[fn]) + # Switch to lower abstraction level IR. lower_ir(fn, compiler_options) # Perform optimizations. do_copy_propagation(fn, compiler_options) do_flag_elimination(fn, compiler_options) - - for module in modules.values(): - for cls in module.classes: - if cls.env_user_function: - insert_spills(cls.env_user_function, cls) - + return modules From 7b76b9050faa744b2b4a5033077ea0357989306a Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 15 Apr 2025 14:18:40 +0100 Subject: [PATCH 07/14] Add missing set_sources --- mypyc/codegen/emitmodule.py | 2 +- mypyc/ir/ops.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index e9e60cff5517..b8a19ac1d669 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -252,7 +252,7 @@ def compile_scc_to_ir( # Perform optimizations. do_copy_propagation(fn, compiler_options) do_flag_elimination(fn, compiler_options) - + return modules diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index d6dcdb3468f3..598304e5caff 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -692,6 +692,9 @@ def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1 def sources(self) -> list[Value]: return self.args + def set_sources(self, new: list[Value]) -> None: + self.args = new[:] + def stolen(self) -> list[Value]: steals = self.desc.steals if isinstance(steals, list): From 92ab89b41da1003ce5db4e8dc1b5ecccb3c29c51 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 15 Apr 2025 14:30:06 +0100 Subject: [PATCH 08/14] Fix typo --- mypyc/ir/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 598304e5caff..b0cb127830e1 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -238,7 +238,7 @@ def sources(self) -> list[Value]: @abstractmethod def set_sources(self, new: list[Value]) -> None: - """Rewrite the soruces of an op""" + """Rewrite the sources of an op""" def stolen(self) -> list[Value]: """Return arguments that have a reference count stolen by this op""" From 3b64cc4eb4754101acec730b9dc1514ea472e39d Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 15 Apr 2025 16:40:47 +0100 Subject: [PATCH 09/14] Add tests --- mypyc/test-data/run-async.test | 70 +++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 3913924ae7ca..2eeaf5c9da84 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -1,6 +1,6 @@ # async test cases (compile and run) -[case testRunAsync] +[case testRunAsyncBasics] import asyncio async def h() -> int: @@ -34,6 +34,73 @@ def run(x: object) -> object: ... [typing fixtures/typing-full.pyi] +[case testRunAsyncAwaitInVariousPositions] +import asyncio + +async def one() -> int: + return int() + 1 + +async def true() -> bool: + return bool(int() + 1) + +async def branch_await() -> int: + if await true(): + return 3 + return 2 + +async def branch_await_not() -> int: + if not await true(): + return 3 + return 2 + +def test_branch() -> None: + assert asyncio.run(branch_await()) == 3 + assert asyncio.run(branch_await_not()) == 2 + +async def assign_local() -> int: + x = await one() + return x + 1 + +def test_assign_local() -> None: + assert asyncio.run(assign_local()) == 2 + +class C: + def __init__(self, s: str) -> None: + self.s = s + +async def make_c(s: str) -> C: + await one() + return C(s) + +async def get_attr(s: str) -> str: + return (await make_c(s)).s + +def test_get_attr() -> None: + assert asyncio.run(get_attr("foo")) == "foo" + +async def concat(s: str, t: str) -> str: + await one() + return s + t + +async def set_attr1(s: str) -> str: + c = await make_c("xyz") + c.s = await concat(s, "!") + return c.s + +async def set_attr2(s: str) -> None: + (await make_c("xyz")).s = s + +def test_set_attr() -> None: + assert asyncio.run(set_attr1("foo")) == "foo!" + asyncio.run(set_attr2("foo")) # Just check that it compiles and runs + +[file asyncio/__init__.pyi] +async def sleep(t: float) -> None: ... +# eh, we could use the real type but it doesn't seem important +def run(x: object) -> object: ... + +[typing fixtures/typing-full.pyi] + [case testAsyncWith] from testutil import async_val @@ -78,7 +145,6 @@ yields, val = run_generator(async_return()) assert yields == ('foo',) assert val == 'test', val - [case testAsyncFor] from typing import AsyncIterable, List, Set, Dict From 318ae87ae1751b65ee8bb6508f903e3396f8dfa3 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 15 Apr 2025 17:00:40 +0100 Subject: [PATCH 10/14] More tests --- mypyc/test-data/run-async.test | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 2eeaf5c9da84..289355eadf64 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -68,6 +68,9 @@ class C: def __init__(self, s: str) -> None: self.s = s + def concat(self, s: str) -> str: + return self.s + s + async def make_c(s: str) -> C: await one() return C(s) @@ -94,6 +97,33 @@ def test_set_attr() -> None: assert asyncio.run(set_attr1("foo")) == "foo!" asyncio.run(set_attr2("foo")) # Just check that it compiles and runs +def upper(s: str) -> str: + return s.upper() + +async def call1(s: str) -> str: + return upper(await concat(s, "a")) + +async def call2(s: str) -> str: + return await concat(await concat(s, "a"), "b") + +def test_call() -> None: + assert asyncio.run(call1("foo")) == "FOOA" + assert asyncio.run(call2("foo")) == "fooab" + +async def method_call(s: str) -> str: + c = C("<") + return c.concat(await concat(s, ">")) + +def test_method_call() -> None: + assert asyncio.run(method_call("foo")) == "" + +async def construct(s: str) -> str: + c = C(await concat(s, "!")) + return c.s + +def test_construct() -> None: + assert asyncio.run(construct("foo")) == "foo!" + [file asyncio/__init__.pyi] async def sleep(t: float) -> None: ... # eh, we could use the real type but it doesn't seem important From 0db01fe801e654f1e8522aca085578db4c597b8d Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 16 Apr 2025 10:07:26 +0100 Subject: [PATCH 11/14] More testing --- mypyc/test-data/run-async.test | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 289355eadf64..abedd4d10f55 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -35,13 +35,16 @@ def run(x: object) -> object: ... [typing fixtures/typing-full.pyi] [case testRunAsyncAwaitInVariousPositions] +from typing import cast, Any + import asyncio async def one() -> int: + await asyncio.sleep(0.0) return int() + 1 async def true() -> bool: - return bool(int() + 1) + return bool(int() + await one()) async def branch_await() -> int: if await true(): @@ -124,6 +127,33 @@ async def construct(s: str) -> str: def test_construct() -> None: assert asyncio.run(construct("foo")) == "foo!" +async def repr_as_object(s: str) -> object: + return repr(s) + +async def do_cast(s: str) -> str: + return cast(str, await repr_as_object(s)) + +def test_cast() -> None: + assert asyncio.run(do_cast("foo")) == "'foo'" + +async def box() -> list[int]: + return [await one(), await one()] + +def test_box() -> None: + assert asyncio.run(box()) == [1, 1] + +async def int_as_any(n: int) -> Any: + return n * 2 + +async def inc(n: int) -> int: + return n + await one() + +async def unbox(n: int) -> int: + return await inc(await int_as_any(n)) + +def test_unbox() -> None: + assert asyncio.run(unbox(4)) == 9 + [file asyncio/__init__.pyi] async def sleep(t: float) -> None: ... # eh, we could use the real type but it doesn't seem important From c3e87d2acd86d765d098e1ee7f20ed4a392a4bac Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 16 Apr 2025 11:02:28 +0100 Subject: [PATCH 12/14] Fixes --- mypyc/ir/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index b0cb127830e1..eec9c34a965e 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -1247,7 +1247,7 @@ def sources(self) -> list[Value]: return [] def set_sources(self, new: list[Value]) -> None: - (self.src,) = new + assert not new def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_load_global(self) @@ -1588,6 +1588,7 @@ def sources(self) -> list[Value]: def set_sources(self, new: list[Value]) -> None: if new: assert isinstance(new[0], Register) + assert len(new) == 1 self.src = new[0] def accept(self, visitor: OpVisitor[T]) -> T: From 9fb97f0ea745d9203fa6f8a72a8ef37a69c13af9 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 16 Apr 2025 16:43:03 +0100 Subject: [PATCH 13/14] Update tests --- mypyc/test-data/run-async.test | 83 ++++++++++------------------------ 1 file changed, 24 insertions(+), 59 deletions(-) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index abedd4d10f55..b2c2fffb38b5 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -47,12 +47,12 @@ async def true() -> bool: return bool(int() + await one()) async def branch_await() -> int: - if await true(): + if bool(int() + 1) == await true(): return 3 return 2 async def branch_await_not() -> int: - if not await true(): + if bool(int() + 1) == (not await true()): return 3 return 2 @@ -60,12 +60,12 @@ def test_branch() -> None: assert asyncio.run(branch_await()) == 3 assert asyncio.run(branch_await_not()) == 2 -async def assign_local() -> int: - x = await one() +async def assign_multi() -> int: + _, x = int(), await one() return x + 1 -def test_assign_local() -> None: - assert asyncio.run(assign_local()) == 2 +def test_assign_multi() -> None: + assert asyncio.run(assign_multi()) == 2 class C: def __init__(self, s: str) -> None: @@ -78,81 +78,46 @@ async def make_c(s: str) -> C: await one() return C(s) -async def get_attr(s: str) -> str: - return (await make_c(s)).s - -def test_get_attr() -> None: - assert asyncio.run(get_attr("foo")) == "foo" - async def concat(s: str, t: str) -> str: await one() return s + t -async def set_attr1(s: str) -> str: - c = await make_c("xyz") - c.s = await concat(s, "!") - return c.s - -async def set_attr2(s: str) -> None: - (await make_c("xyz")).s = s +async def set_attr(s: str) -> None: + (await make_c("xyz")).s = await concat(s, "!") def test_set_attr() -> None: - assert asyncio.run(set_attr1("foo")) == "foo!" - asyncio.run(set_attr2("foo")) # Just check that it compiles and runs + asyncio.run(set_attr("foo")) # Just check that it compiles and runs -def upper(s: str) -> str: - return s.upper() +def concat2(x: str, y: str) -> str: + return x + y async def call1(s: str) -> str: - return upper(await concat(s, "a")) + return concat2(str(int()), await concat(s, "a")) async def call2(s: str) -> str: - return await concat(await concat(s, "a"), "b") + return await concat(str(int()), await concat(s, "b")) def test_call() -> None: - assert asyncio.run(call1("foo")) == "FOOA" - assert asyncio.run(call2("foo")) == "fooab" + assert asyncio.run(call1("foo")) == "0fooa" + assert asyncio.run(call2("foo")) == "0foob" async def method_call(s: str) -> str: - c = C("<") - return c.concat(await concat(s, ">")) + return C("<").concat(await concat(s, ">")) def test_method_call() -> None: assert asyncio.run(method_call("foo")) == "" +class D: + def __init__(self, a: str, b: str) -> None: + self.a = a + self.b = b + async def construct(s: str) -> str: - c = C(await concat(s, "!")) - return c.s + c = D(await concat(s, "!"), await concat(s, "?")) + return c.a + c.b def test_construct() -> None: - assert asyncio.run(construct("foo")) == "foo!" - -async def repr_as_object(s: str) -> object: - return repr(s) - -async def do_cast(s: str) -> str: - return cast(str, await repr_as_object(s)) - -def test_cast() -> None: - assert asyncio.run(do_cast("foo")) == "'foo'" - -async def box() -> list[int]: - return [await one(), await one()] - -def test_box() -> None: - assert asyncio.run(box()) == [1, 1] - -async def int_as_any(n: int) -> Any: - return n * 2 - -async def inc(n: int) -> int: - return n + await one() - -async def unbox(n: int) -> int: - return await inc(await int_as_any(n)) - -def test_unbox() -> None: - assert asyncio.run(unbox(4)) == 9 + assert asyncio.run(construct("foo")) == "foo!foo?" [file asyncio/__init__.pyi] async def sleep(t: float) -> None: ... From fdc84df498480857ff36cb8ced40029a6adc69eb Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Thu, 17 Apr 2025 11:00:54 +0100 Subject: [PATCH 14/14] Remove failing test case (will investigate it later) --- mypyc/test-data/run-async.test | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index b2c2fffb38b5..89d661900de0 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -74,20 +74,10 @@ class C: def concat(self, s: str) -> str: return self.s + s -async def make_c(s: str) -> C: - await one() - return C(s) - async def concat(s: str, t: str) -> str: await one() return s + t -async def set_attr(s: str) -> None: - (await make_c("xyz")).s = await concat(s, "!") - -def test_set_attr() -> None: - asyncio.run(set_attr("foo")) # Just check that it compiles and runs - def concat2(x: str, y: str) -> str: return x + y @@ -126,7 +116,6 @@ def run(x: object) -> object: ... [typing fixtures/typing-full.pyi] - [case testAsyncWith] from testutil import async_val