Skip to content

Mandatory optimizations: constant fold boolean literals before the DefiniteInitialization pass #70787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//===--- BooleanLiteralFolding.swift ---------------------------------------==//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import SIL

/// Constant folds conditional branches with boolean literals as operands.
///
/// ```
/// %1 = integer_literal -1
/// %2 = apply %bool_init(%1) // Bool.init(_builtinBooleanLiteral:)
/// %3 = struct_extract %2, #Bool._value
/// cond_br %3, bb1, bb2
/// ```
/// ->
/// ```
/// ...
/// br bb1
/// ```
///
/// This pass is intended to run before DefiniteInitialization, where mandatory inlining and
/// constant folding didn't run, yet (which would perform this kind of optimization).
///
/// This optimization is required to let DefiniteInitialization handle boolean literals correctly.
/// For example in infinite loops:
///
/// ```
/// init() {
/// while true { // DI need to know that there is no loop exit from this while-statement
/// if some_condition {
/// member_field = init_value
/// break
/// }
/// }
/// }
/// ```
///
let booleanLiteralFolding = FunctionPass(name: "boolean-literal-folding") {
(function: Function, context: FunctionPassContext) in

for block in function.blocks {
if let condBr = block.terminator as? CondBranchInst {
fold(condBranch: condBr, context)
}
}
}

private func fold(condBranch: CondBranchInst, _ context: FunctionPassContext) {
guard let structExtract = condBranch.condition as? StructExtractInst,
let initApply = structExtract.struct as? ApplyInst,
initApply.hasSemanticsAttribute("bool.literal_init"),
initApply.arguments.count == 2,
let literal = initApply.arguments[0] as? IntegerLiteralInst,
let literalValue = literal.value else
{
return
}

let builder = Builder(before: condBranch, context)
builder.createBranch(to: literalValue == 0 ? condBranch.falseBlock : condBranch.trueBlock)
context.erase(instruction: condBranch)
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ swift_compiler_sources(Optimizer
AllocVectorLowering.swift
AssumeSingleThreaded.swift
AsyncDemotion.swift
BooleanLiteralFolding.swift
CleanupDebugSteps.swift
ComputeEscapeEffects.swift
ComputeSideEffects.swift
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ private func registerSwiftPasses() {
// Function passes
registerPass(allocVectorLowering, { allocVectorLowering.run($0) })
registerPass(asyncDemotion, { asyncDemotion.run($0) })
registerPass(booleanLiteralFolding, { booleanLiteralFolding.run($0) })
registerPass(letPropertyLowering, { letPropertyLowering.run($0) })
registerPass(mergeCondFailsPass, { mergeCondFailsPass.run($0) })
registerPass(computeEscapeEffects, { computeEscapeEffects.run($0) })
Expand Down
2 changes: 2 additions & 0 deletions include/swift/SILOptimizer/PassManager/Passes.def
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ PASS(BasicInstructionPropertyDumper, "basic-instruction-property-dump",
"Print SIL Instruction MemBehavior and ReleaseBehavior Information")
PASS(BasicCalleePrinter, "basic-callee-printer",
"Print Basic Callee Analysis for Testing")
SWIFT_FUNCTION_PASS(BooleanLiteralFolding, "boolean-literal-folding",
"Constant folds initializers of boolean literals")
PASS(CFGPrinter, "view-cfg",
"View Function CFGs")
PASS(COWArrayOpts, "cowarray-opt",
Expand Down
1 change: 1 addition & 0 deletions lib/SILOptimizer/PassManager/PassPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ static void addMandatoryDiagnosticOptPipeline(SILPassPipelinePlan &P) {

P.addAllocBoxToStack();
P.addNoReturnFolding();
P.addBooleanLiteralFolding();
addDefiniteInitialization(P);

P.addAddressLowering();
Expand Down
1 change: 1 addition & 0 deletions stdlib/public/core/Bool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public struct Bool: Sendable {

extension Bool: _ExpressibleByBuiltinBooleanLiteral, ExpressibleByBooleanLiteral {
@_transparent
@_semantics("bool.literal_init")
public init(_builtinBooleanLiteral value: Builtin.Int1) {
self._value = value
}
Expand Down
3 changes: 2 additions & 1 deletion test/AutoDiff/SILOptimizer/activity_analysis.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: %target-swift-emit-sil -verify -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s
// TODO: re-enable the boolean-literal-folding pass and fix the test accordingly
// RUN: %target-swift-emit-sil -Xllvm -sil-disable-pass=boolean-literal-folding -verify -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s
// REQUIRES: asserts

import _Differentiation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ struct MyModel: Differentiable {
property2 = localVar

// `false` may instead be any expression that returns a `Bool`.
if false {
// TODO: cannot use literal `false` because it crashes
if 1 == 0 {
localVar = member3
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct BatchNorm<Scalar>: Layer { // Crash requires conformance to `Layer`
@differentiable(reverse)
func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
var offset = self.offset
if true { // Crash requires `if true`
// TODO: cannot use literal `true` because it crashes
if 1 == 1 { // Crash requires `if true`
offset += offset // Using `offset = offset + offset` stops the crash
}
return offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ AddressOnlyTangentVectorTests.test("LoadableClassAddressOnlyTangentVector") {
@differentiable(reverse)
func conditional<T: Differentiable>(_ s: LoadableClass<T>) -> T {
var tuple = (s, (s, s))
if false {}
// TODO: cannot use literal `false` because it crashes
if 1 == 0 {}
return tuple.1.0.stored
}
expectEqual(.init(stored: 1), gradient(at: LoadableClass<Float>(10), of: conditional))
Expand Down
22 changes: 14 additions & 8 deletions test/AutoDiff/validation-test/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ ControlFlowTests.test("Conditionals") {

func cond4_var(_ x: Float) -> Float {
var outer = x
outerIf: if true {
// TODO: cannot use literal `true` because it crashes
outerIf: if 1 == 1 {
var inner = outer
inner = inner * x
if false {
// TODO: cannot use literal `false` because it crashes
if 1 == 0 {
break outerIf
}
outer = inner
Expand Down Expand Up @@ -386,8 +388,9 @@ ControlFlowTests.test("NestedConditionals") {
@differentiable(reverse, wrt: self) // wrt only self is important
func callAsFunction(_ input: Float) -> Float {
var x = input
if true {
if true {
// TODO: cannot use literal `true` because it crashes
if 1 == 1 {
if 1 == 1 {
// Function application below should make `self` have non-zero
// derivative.
x = x * w
Expand All @@ -405,8 +408,9 @@ ControlFlowTests.test("NestedConditionals") {
@differentiable(reverse, wrt: x)
func TF_781(_ x: Float, _ y: Float) -> Float {
var result = y
if true {
if true {
// TODO: cannot use literal `true` because it crashes
if 1 == 1 {
if 1 == 1 {
result = result * x
}
}
Expand Down Expand Up @@ -791,7 +795,8 @@ ControlFlowTests.test("ThrowingCalls") {
func testComplexControlFlow(_ x: Float) -> Float {
rethrowing({})
for _ in 0..<Int(x) {
if true {
// TODO: cannot use literal `true` because it crashes
if 1 == 1 {
rethrowing({})
}
rethrowing({}) // non-active `try_apply`
Expand All @@ -805,7 +810,8 @@ ControlFlowTests.test("ThrowingCalls") {
func testComplexControlFlowGeneric<T: Differentiable>(_ x: T) -> T {
rethrowing({})
for _ in 0..<10 {
if true {
// TODO: cannot use literal `true` because it crashes
if 1 == 1 {
rethrowing({})
}
rethrowing({}) // non-active `try_apply`
Expand Down
2 changes: 1 addition & 1 deletion test/Macros/Inputs/syntax_macro_definitions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1772,7 +1772,7 @@ public struct SimpleCodeItemMacro: CodeItemMacro {
}
""")),
.init(item: .stmt("""
if false {
if 1 == 0 {
print("impossible")
}
""")),
Expand Down
2 changes: 0 additions & 2 deletions test/Macros/macro_expand_codeitems.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ func testFreestandingMacroExpansion() {
// CHECK: from stmt
// CHECK: from usedInExpandedStmt
// CHECK: from expr
// CHECK-DIAGS: note: condition always evaluates to false
// CHECK-DIAGS: CONTENTS OF FILE @__swiftmacro_9MacroUser016testFreestandingA9ExpansionyyF9codeItemsfMf0_.swift:
// CHECK-DIAGS: struct $s9MacroUser016testFreestandingA9ExpansionyyF9codeItemsfMf0_3foofMu_ {
// CHECK-DIAGS: END CONTENTS OF FILE
#codeItems
Expand Down
6 changes: 3 additions & 3 deletions test/SILGen/availability_query.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func testUnreachableVersionAvailable(condition: Bool) {
if true {
doThing() // no-warning
}
if false { // expected-note {{condition always evaluates to false}}
if 1 == 0 { // expected-note {{condition always evaluates to false}}
doThing() // expected-warning {{will never be executed}}
}
}
Expand All @@ -144,8 +144,8 @@ func testUnreachablePlatformAvailable(condition: Bool) {
if true {
doThing() // no-warning
}
if false { // expected-note {{condition always evaluates to false}}
doThing() // expected-warning {{will never be executed}}
if false {
doThing()
}
}

Expand Down
136 changes: 136 additions & 0 deletions test/SILOptimizer/boolean-literal-folding.sil
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// RUN: %target-sil-opt -enable-sil-verify-all %s -boolean-literal-folding | %FileCheck %s

// REQUIRES: swift_in_compiler

sil_stage canonical

import Builtin
import Swift

sil public_external [_semantics "bool.literal_init"] @bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
sil public_external @no_bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
sil public_external [_semantics "bool.literal_init"] @wrong_bool_literal_init : $@convention(thin) () -> Bool

// CHECK-LABEL: sil [ossa] @replace_true :
// CHECK: struct_extract
// CHECK-NEXT: br bb1
// CHECK: bb1:
// CHECK: } // end sil function 'replace_true'
sil [ossa] @replace_true : $@convention(thin) () -> () {
bb0:
%0 = integer_literal $Builtin.Int1, -1
%1 = metatype $@thin Bool.Type
%2 = function_ref @bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%3 = apply %2(%0, %1) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%4 = struct_extract %3 : $Bool, #Bool._value
cond_br %4, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

// CHECK-LABEL: sil [ossa] @replace_false :
// CHECK: struct_extract
// CHECK-NEXT: br bb2
// CHECK: bb1:
// CHECK: } // end sil function 'replace_false'
sil [ossa] @replace_false : $@convention(thin) () -> () {
bb0:
%0 = integer_literal $Builtin.Int1, 0
%1 = metatype $@thin Bool.Type
%2 = function_ref @bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%3 = apply %2(%0, %1) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%4 = struct_extract %3 : $Bool, #Bool._value
cond_br %4, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

// CHECK-LABEL: sil [ossa] @dont_replace_non_literal :
// CHECK: cond_br %0, bb1, bb2
// CHECK: bb1:
// CHECK: } // end sil function 'dont_replace_non_literal'
sil [ossa] @dont_replace_non_literal : $@convention(thin) (Builtin.Int1) -> () {
bb0(%0 : $Builtin.Int1):
cond_br %0, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

// CHECK-LABEL: sil [ossa] @dont_replace_non_init_func :
// CHECK: [[B:%.*]] = struct_extract
// CHECK: cond_br [[B]], bb1, bb2
// CHECK: bb1:
// CHECK: } // end sil function 'dont_replace_non_init_func'
sil [ossa] @dont_replace_non_init_func : $@convention(thin) () -> () {
bb0:
%0 = integer_literal $Builtin.Int1, -1
%1 = metatype $@thin Bool.Type
%2 = function_ref @no_bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%3 = apply %2(%0, %1) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%4 = struct_extract %3 : $Bool, #Bool._value
cond_br %4, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

// CHECK-LABEL: sil [ossa] @dont_replace_wrong_init_func :
// CHECK: [[B:%.*]] = struct_extract
// CHECK: cond_br [[B]], bb1, bb2
// CHECK: bb1:
// CHECK: } // end sil function 'dont_replace_wrong_init_func'
sil [ossa] @dont_replace_wrong_init_func : $@convention(thin) () -> () {
bb0:
%2 = function_ref @wrong_bool_literal_init : $@convention(thin) () -> Bool
%3 = apply %2() : $@convention(thin) () -> Bool
%4 = struct_extract %3 : $Bool, #Bool._value
cond_br %4, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

// CHECK-LABEL: sil [ossa] @dont_replace_non_literal_init :
// CHECK: [[B:%.*]] = struct_extract
// CHECK: cond_br [[B]], bb1, bb2
// CHECK: bb1:
// CHECK: } // end sil function 'dont_replace_non_literal_init'
sil [ossa] @dont_replace_non_literal_init : $@convention(thin) (Builtin.Int1) -> () {
bb0(%0 : $Builtin.Int1):
%1 = metatype $@thin Bool.Type
%2 = function_ref @bool_literal_init : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%3 = apply %2(%0, %1) : $@convention(method) (Builtin.Int1, @thin Bool.Type) -> Bool
%4 = struct_extract %3 : $Bool, #Bool._value
cond_br %4, bb1, bb2
bb1:
br bb3
bb2:
br bb3
bb3:
%r = tuple ()
return %r : $()
}

Loading