[export] backed_size_oblivious tutorial #3400
@@ -489,6 +489,7 @@ def forward(self, w, x, y, z):
# specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens
# at runtime when we export this linear layer:

torch._logging.set_logs(dynamic=0)
ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
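A minimal sketch (not part of this diff) of the hardcoding behavior described above, assuming the `torch.export` `Dim.AUTO` / `Dim.STATIC` markers used elsewhere in this tutorial and using `nn.Linear` purely for illustration:

import torch
from torch.export import Dim, export

# A size-1 batch in the sample input is specialized: the exported graph hardcodes batch == 1.
ep_specialized = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes=((Dim.AUTO, Dim.STATIC),),
)
# ep_specialized.module()(torch.randn(7, 4))  # expected to fail: batch was hardcoded to 1

# A non-0/1 batch in the sample input keeps the batch dimension symbolic.
ep_dynamic = export(
    torch.nn.Linear(4, 3),
    (torch.randn(2, 4),),
    dynamic_shapes=((Dim.AUTO, Dim.STATIC),),
)
ep_dynamic.module()(torch.randn(7, 4))  # works: the batch dimension stayed dynamic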
@@ -591,6 +592,30 @@ def forward(self, x, y):
    "bool_val": None,
}

######################################################################
# (experimental) Avoiding 0/1 specialization
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Export provides an experimental option to avoid specializing on size 0/1 sample inputs. Users can set `torch.fx.experimental._config.backed_size_oblivious = True` to enable this behavior.
# This allows the compiler to allocate a [0, inf] range for symbols, and to assume general-case semantics wherever compiler decisions would differ between size 0/1 and sizes >= 2.
# This can lead to behavior divergence between eager mode and the exported program on size 0/1 inputs - for example, in broadcasting decisions, we will assume input shapes are not 1-specialized,
# and therefore assume broadcasting does not apply (even if it does on the particular sample inputs). The same logic applies to other semantics (e.g. contiguity) and to size-0 tensors.
#
# The exact semantics under this flag are a work in progress, and usage is recommended only when the user is certain their model does not rely on 0/1-specialized semantics.
# For now, export users can enable this with:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x + y  # nothing special about size 0/1 here

x = torch.randn(0, 1)
y = torch.randn(1)
dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO), "y": (Dim.AUTO,)}
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
Review comment: print ep?

ep.module()(torch.randn(8, 1), torch.randn(1))
ep.module()(torch.randn(5, 6), torch.randn(6))
Review comment: I wonder if it would be good to add an example where the behavior is weird after setting backed_size_oblivious; your example here is a little simple.
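A hedged sketch (not part of this diff) of the eager/export divergence described above, continuing from the `Foo` and `ep` defined in the snippet: with backed_size_oblivious enabled, export assumed the last dimensions of `x` and `y` are equal rather than broadcastable, so an input that genuinely broadcasts in eager mode is expected to be rejected by the exported program's runtime shape checks:

# In eager mode, y's size-1 last dim broadcasts against x's last dim of 6.
Foo()(torch.randn(5, 6), torch.randn(1))  # fine in eager, result shape (5, 6)

# The exported program assumed no broadcasting (the last dims must match),
# so the same inputs are expected to fail (the exact error type/message may vary).
try:
    ep.module()(torch.randn(5, 6), torch.randn(1))
except Exception as e:
    print("diverged from eager:", e)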

######################################################################
# Data-dependent errors
# ---------------------
Review comment: should this be added after the section on 0/1 specialization, or should the section on 0/1 specialization move to right before here?