Skip to content

[mypyc] Switch to table-driven imports for smaller IR #14917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import pprint
import sys
import textwrap
from typing import Callable
from typing_extensions import Final

Expand Down Expand Up @@ -191,10 +193,31 @@ def reg(self, reg: Value) -> str:
def attr(self, name: str) -> str:
    """Return `name` with ATTR_PREFIX prepended."""
    prefixed = ATTR_PREFIX + name
    return prefixed

def emit_line(self, line: str = "") -> None:
def object_annotation(self, obj: object, line: str) -> str:
    """Render an object's pretty-printed representation as a C comment.

    The comment is meant to be appended to `line`. When the formatted text
    spans several lines, every line after the first is indented so that the
    continuation sits underneath the first line's comment.

    An empty string is returned if the formatted text contains a sequence
    that is not allowed inside the comment ("/*", "*/" or a NUL character).
    """
    # Columns already taken up by the current indentation plus the code itself.
    used_columns = self._indent + len(line)
    text = pprint.pformat(obj, compact=True, width=max(90 - used_columns, 20))

    for forbidden in ("/*", "*/", "\0"):
        if forbidden in text:
            return ""

    if "\n" not in text:
        return f" /* {text} */"

    head, tail = text.split("\n", maxsplit=1)
    # Shift the remaining lines right by the used columns plus three.
    padded_tail = textwrap.indent(tail, (used_columns + 3) * " ")
    return f" /* {head}\n{padded_tail} */"

def emit_line(self, line: str = "", *, ann: object = None) -> None:
if line.startswith("}"):
self.dedent()
self.fragments.append(self._indent * " " + line + "\n")
comment = self.object_annotation(ann, line) if ann is not None else ""
self.fragments.append(self._indent * " " + line + comment + "\n")
if line.endswith("{"):
self.indent()

Expand Down Expand Up @@ -1119,3 +1142,35 @@ def _emit_traceback(
self.emit_line(line)
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')


def c_array_initializer(components: list[str], *, indented: bool = False) -> str:
    """Build the brace-enclosed initializer for a C array variable.

    Every component must be a C expression that is valid inside an
    initializer. Given ["1", "2"] the result is "{1, 2}", usable as:

        int a[] = {1, 2};

    A long initializer is broken across several lines. If `indented` is
    true, each of those lines and the closing brace are prefixed with four
    extra spaces.
    """
    pad = " " * 4 if indented else ""
    rows: list[str] = []
    pending: list[str] = []
    width = 0
    for item in components:
        if pending and width + 2 + len(pad) + len(item) >= 70:
            # The row would get too long: flush it and start a new one.
            rows.append(pad + ", ".join(pending))
            pending = [item]
            width = len(item)
            continue
        pending.append(item)
        width += len(item) + 2

    last_row = ", ".join(pending)
    if not rows:
        # Everything fits on a single line.
        return "{" + last_row + "}"
    # Multi-line form.
    rows.append(pad + last_row)
    return "{\n " + ",\n ".join(rows) + "\n" + pad + "}"
43 changes: 18 additions & 25 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Final

from mypyc.analysis.blockfreq import frequently_executed_blocks
from mypyc.codegen.emit import DEBUG_ERRORS, Emitter, TracebackAndGotoHandler
from mypyc.codegen.emit import DEBUG_ERRORS, Emitter, TracebackAndGotoHandler, c_array_initializer
from mypyc.common import MODULE_PREFIX, NATIVE_PREFIX, REG_PREFIX, STATIC_PREFIX, TYPE_PREFIX
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FuncDecl, FuncIR, all_values
Expand Down Expand Up @@ -262,12 +262,12 @@ def visit_assign_multi(self, op: AssignMulti) -> None:
# RArray values can only be assigned to once, so we can always
# declare them on initialization.
self.emit_line(
"%s%s[%d] = {%s};"
"%s%s[%d] = %s;"
% (
self.emitter.ctype_spaced(typ.item_type),
dest,
len(op.src),
", ".join(self.reg(s) for s in op.src),
c_array_initializer([self.reg(s) for s in op.src], indented=True),
)
)

Expand All @@ -282,15 +282,12 @@ def visit_load_error_value(self, op: LoadErrorValue) -> None:

def visit_load_literal(self, op: LoadLiteral) -> None:
index = self.literals.literal_index(op.value)
s = repr(op.value)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
else:
ann = ""
if not is_int_rprimitive(op.type):
self.emit_line("%s = CPyStatics[%d];%s" % (self.reg(op), index, ann))
self.emit_line("%s = CPyStatics[%d];" % (self.reg(op), index), ann=op.value)
else:
self.emit_line("%s = (CPyTagged)CPyStatics[%d] | 1;%s" % (self.reg(op), index, ann))
self.emit_line(
"%s = (CPyTagged)CPyStatics[%d] | 1;" % (self.reg(op), index), ann=op.value
)

def get_attr_expr(self, obj: str, op: GetAttr | SetAttr, decl_cl: ClassIR) -> str:
"""Generate attribute accessor for normal (non-property) access.
Expand Down Expand Up @@ -468,12 +465,7 @@ def visit_load_static(self, op: LoadStatic) -> None:
name = self.emitter.static_name(op.identifier, op.module_name, prefix)
if op.namespace == NAMESPACE_TYPE:
name = "(PyObject *)%s" % name
ann = ""
if op.ann:
s = repr(op.ann)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
self.emit_line(f"{dest} = {name};{ann}")
self.emit_line(f"{dest} = {name};", ann=op.ann)

def visit_init_static(self, op: InitStatic) -> None:
value = self.reg(op.value)
Expand Down Expand Up @@ -636,12 +628,7 @@ def visit_extend(self, op: Extend) -> None:

def visit_load_global(self, op: LoadGlobal) -> None:
dest = self.reg(op)
ann = ""
if op.ann:
s = repr(op.ann)
if not any(x in s for x in ("/*", "*/", "\0")):
ann = " /* %s */" % s
self.emit_line(f"{dest} = {op.identifier};{ann}")
self.emit_line(f"{dest} = {op.identifier};", ann=op.ann)

def visit_int_op(self, op: IntOp) -> None:
dest = self.reg(op)
Expand Down Expand Up @@ -727,7 +714,13 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None:
def visit_load_address(self, op: LoadAddress) -> None:
typ = op.type
dest = self.reg(op)
src = self.reg(op.src) if isinstance(op.src, Register) else op.src
if isinstance(op.src, Register):
src = self.reg(op.src)
elif isinstance(op.src, LoadStatic):
prefix = self.PREFIX_MAP[op.src.namespace]
src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix)
else:
src = op.src
self.emit_line(f"{dest} = ({typ._ctype})&{src};")

def visit_keep_alive(self, op: KeepAlive) -> None:
Expand Down Expand Up @@ -776,8 +769,8 @@ def c_error_value(self, rtype: RType) -> str:
def c_undefined_value(self, rtype: RType) -> str:
    """Delegate to the wrapped emitter's c_undefined_value for `rtype`."""
    delegate = self.emitter
    return delegate.c_undefined_value(rtype)

def emit_line(self, line: str) -> None:
self.emitter.emit_line(line)
def emit_line(self, line: str, *, ann: object = None) -> None:
self.emitter.emit_line(line, ann=ann)

def emit_lines(self, *lines: str) -> None:
    """Forward the given lines, unchanged and in order, to the wrapped emitter."""
    target = self.emitter
    target.emit_lines(*lines)
Expand Down
56 changes: 13 additions & 43 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mypy.plugin import Plugin, ReportConfigContext
from mypy.util import hash_digest
from mypyc.codegen.cstring import c_string_initializer
from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration
from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration, c_array_initializer
from mypyc.codegen.emitclass import generate_class, generate_class_type_decl
from mypyc.codegen.emitfunc import generate_native_function, native_function_header
from mypyc.codegen.emitwrapper import (
Expand Down Expand Up @@ -296,11 +296,11 @@ def compile_ir_to_c(
# compiled into a separate extension module.
ctext: dict[str | None, list[tuple[str, str]]] = {}
for group_sources, group_name in groups:
group_modules = [
(source.module, modules[source.module])
group_modules = {
source.module: modules[source.module]
for source in group_sources
if source.module in modules
]
}
if not group_modules:
ctext[group_name] = []
continue
Expand Down Expand Up @@ -465,7 +465,7 @@ def group_dir(group_name: str) -> str:
class GroupGenerator:
def __init__(
self,
modules: list[tuple[str, ModuleIR]],
modules: dict[str, ModuleIR],
source_paths: dict[str, str],
group_name: str | None,
group_map: dict[str, str | None],
Expand Down Expand Up @@ -512,7 +512,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
multi_file = self.use_shared_lib and self.multi_file

# Collect all literal refs in IR.
for _, module in self.modules:
for module in self.modules.values():
for fn in module.functions:
collect_literals(fn, self.context.literals)

Expand All @@ -528,7 +528,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:

self.generate_literal_tables()

for module_name, module in self.modules:
for module_name, module in self.modules.items():
if multi_file:
emitter = Emitter(self.context)
emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
Expand Down Expand Up @@ -582,7 +582,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
declarations.emit_line("int CPyGlobalsInit(void);")
declarations.emit_line()

for module_name, module in self.modules:
for module_name, module in self.modules.items():
self.declare_finals(module_name, module.final_names, declarations)
for cl in module.classes:
generate_class_type_decl(cl, emitter, ext_declarations, declarations)
Expand Down Expand Up @@ -790,7 +790,7 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None:
"",
)

for mod, _ in self.modules:
for mod in self.modules:
name = exported_name(mod)
emitter.emit_lines(
f"extern PyObject *CPyInit_{name}(void);",
Expand Down Expand Up @@ -1023,12 +1023,13 @@ def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str
return emitter.static_name(module_name + "_internal", None, prefix=MODULE_PREFIX)

def declare_module(self, module_name: str, emitter: Emitter) -> None:
# We declare two globals for each module:
# We declare two globals for each compiled module:
# one used internally in the implementation of module init to cache results
# and prevent infinite recursion in import cycles, and one used
# by other modules to refer to it.
internal_static_name = self.module_internal_static_name(module_name, emitter)
self.declare_global("CPyModule *", internal_static_name, initializer="NULL")
if module_name in self.modules:
internal_static_name = self.module_internal_static_name(module_name, emitter)
self.declare_global("CPyModule *", internal_static_name, initializer="NULL")
static_name = emitter.static_name(module_name, None, prefix=MODULE_PREFIX)
self.declare_global("CPyModule *", static_name)
self.simple_inits.append((static_name, "Py_None"))
Expand Down Expand Up @@ -1126,37 +1127,6 @@ def collect_literals(fn: FuncIR, literals: Literals) -> None:
literals.record_literal(op.value)


def c_array_initializer(components: list[str]) -> str:
    """Build the brace-enclosed initializer for a C array variable.

    Every component must be a C expression that is valid inside an
    initializer. Given ["1", "2"] the result is "{1, 2}", usable as:

        int a[] = {1, 2};

    A long initializer is broken across several lines.
    """
    finished: list[str] = []
    row: list[str] = []
    row_len = 0
    for expr in components:
        fits = not row or row_len + 2 + len(expr) < 70
        if fits:
            row.append(expr)
            row_len += len(expr) + 2
        else:
            # Close the current row and begin the next one with this component.
            finished.append(", ".join(row))
            row = [expr]
            row_len = len(expr)

    if finished:
        # Multi-line form.
        finished.append(", ".join(row))
        return "{\n " + ",\n ".join(finished) + "\n}"
    # Everything fits on a single line.
    return "{%s}" % ", ".join(row)


def c_string_array_initializer(components: list[bytes]) -> str:
result = []
result.append("{\n")
Expand Down
5 changes: 3 additions & 2 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,13 +1348,14 @@ class LoadAddress(RegisterOp):
Attributes:
type: Type of the loaded address (e.g. ptr/object_ptr)
src: Source value (str for globals like 'PyList_Type',
Register for temporary values or locals)
Register for temporary values or locals, LoadStatic
for statics.)
"""

error_kind = ERR_NEVER
is_borrowed = True

def __init__(self, type: RType, src: str | Register, line: int = -1) -> None:
def __init__(self, type: RType, src: str | Register | LoadStatic, line: int = -1) -> None:
super().__init__(line)
self.type = type
self.src = src
Expand Down
5 changes: 5 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> str:
def visit_load_address(self, op: LoadAddress) -> str:
    """Format a LoadAddress op; the layout depends on the kind of its source.

    A Register source is formatted with %r, a LoadStatic source is shown as
    its (optionally module-qualified) identifier followed by its namespace,
    and any other source is formatted with %s.
    """
    source = op.src
    if isinstance(source, Register):
        return self.format("%r = load_address %r", op, source)
    if not isinstance(source, LoadStatic):
        return self.format("%r = load_address %s", op, source)

    qualified = source.identifier
    module = source.module_name
    if module is not None:
        qualified = f"{module}.{qualified}"
    return self.format("%r = load_address %s :: %s", op, qualified, source.namespace)

Expand Down
26 changes: 3 additions & 23 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
RType,
RUnion,
bitmap_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
dict_rprimitive,
int_rprimitive,
Expand Down Expand Up @@ -126,12 +125,7 @@
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op
from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list
from mypyc.primitives.misc_ops import (
check_unpack_count_op,
get_module_dict_op,
import_extra_args_op,
import_op,
)
from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op
from mypyc.primitives.registry import CFunctionDescription, function_ops

# These int binary operations can borrow their operands safely, since the
Expand Down Expand Up @@ -193,6 +187,8 @@ def __init__(
self.encapsulating_funcs = pbv.encapsulating_funcs
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.module_import_groups = pbv.module_import_groups

self.singledispatch_impls = singledispatch_impls

self.visitor = visitor
Expand Down Expand Up @@ -394,22 +390,6 @@ def add_to_non_ext_dict(
key_unicode = self.load_str(key)
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

def gen_import_from(
    self, id: str, globals_dict: Value, imported: list[str], line: int
) -> Value:
    """Emit the import_extra_args_op call for module `id` and return its result.

    The module id is recorded in self.imports. The call receives the module
    name, `globals_dict`, a zero-valued dict_rprimitive Integer, a list of
    the names in `imported`, and a zero-valued c_int_rprimitive Integer. The
    result is then stored with InitStatic under the module namespace.
    """
    self.imports[id] = None

    dict_null = Integer(0, dict_rprimitive, line)
    # Load each name in order before the list itself is built.
    name_values = []
    for name in imported:
        name_values.append(self.load_str(name))
    name_list = self.new_list_op(name_values, line)
    int_zero = Integer(0, c_int_rprimitive, line)

    call_args = [self.load_str(id), globals_dict, dict_null, name_list, int_zero]
    module_value = self.call_c(import_extra_args_op, call_args, line)
    self.add(InitStatic(module_value, id, namespace=NAMESPACE_MODULE))
    return module_value

def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None

Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def gen_func_ns(builder: IRBuilder) -> str:
return "_".join(
info.name + ("" if not info.class_name else "_" + info.class_name)
for info in builder.fn_infos
if info.name and info.name != "<top level>"
if info.name and info.name != "<module>"
)


Expand Down
Loading