Skip to content

[mlir][python] enable registering dialects with the default Context #72488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions mlir/python/mlir/_mlir_libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,24 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.


def get_registry():
if not hasattr(get_registry, "__registry"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember if this is still a thing, but didn't it used to be that double underscore prefixed attributes were lexically mangled? I've got it stuck in my "never do that" category, but may be a legacy lint check in my brain :)

Copy link
Contributor Author

@makslevental makslevental Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true but that's only for class fields - here I'm doing something even "dirtier" and setting it on the function object so it doesn't get mangled.

But just this morning was thinking I'd refactor this to be a module global behind threading.local() instead of this hackery. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just move it up a level and use nonlocal. No need for dirty tricks that someone will need to grok later. I'm not sure it needs to be thread local. This is a really basic facility in the same vein as site_initialize, which is global-global.

Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.

Copy link
Contributor Author

@makslevental makslevental Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.

remove_site_initialize_1 changed things around a lot (didn't exactly remove but large refactor) and then I had the lightbulb moment that all I needed was this one helper and hence remove_site_initialize_2.

from ._mlir import ir

get_registry.__registry = ir.DialectRegistry()

return get_registry.__registry


def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir

logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
disable_multithreading = False

Expand All @@ -84,7 +94,7 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
m.register_dialects(get_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
Expand All @@ -110,7 +120,7 @@ def process_initializer_module(module_name):
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
self.append_dialect_registry(get_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
Expand Down
4 changes: 2 additions & 2 deletions mlir/python/mlir/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)


def register_python_test_dialect(context, load=True):
def register_python_test_dialect(registry):
from .._mlir_libs import _mlirPythonTest

_mlirPythonTest.register_python_test_dialect(context, load)
_mlirPythonTest.register_dialect(registry)
1 change: 1 addition & 0 deletions mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs import get_registry


# Convenience decorator for registering user-friendly Attribute builders.
Expand Down
16 changes: 2 additions & 14 deletions mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith

test.register_python_test_dialect(get_registry())


def run(f):
print("\nTEST:", f.__name__)
Expand All @@ -17,7 +19,6 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
Expand Down Expand Up @@ -138,7 +139,6 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
Expand Down Expand Up @@ -215,7 +215,6 @@ def attrBuilder():
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp()
Expand Down Expand Up @@ -260,7 +259,6 @@ def inferReturnTypes():
@run
def resultTypesDefinedByTraits():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
Expand Down Expand Up @@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

module = Module.create()
with InsertionPoint(module.body):
op1 = test.OptionalOperandOp()
Expand All @@ -312,7 +308,6 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
Expand Down Expand Up @@ -350,7 +345,6 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestType.get()
# CHECK: !python_test.test_type
print(a)
Expand Down Expand Up @@ -397,8 +391,6 @@ def testCustomType():
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

i8 = IntegerType.get_signless(8)

class Tensor(test.TestTensorValue):
Expand Down Expand Up @@ -436,7 +428,6 @@ def __str__(self):
@run
def inferReturnTypeComponents():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
Expand Down Expand Up @@ -488,8 +479,6 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

a = test.TestType.get()
assert a.typeid is not None

Expand Down Expand Up @@ -542,7 +531,6 @@ def type_caster(pytype):
@run
def testInferTypeOpInterface():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
i64 = IntegerType.get_signless(64)
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/lib/PythonTestModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);

m.def(
"register_dialect",
[](MlirDialectRegistry registry) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleInsertDialect(pythonTestDialect, registry);
},
py::arg("registry"));

mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
Expand Down