Skip to content

Commit 962be95

Browse files
authored
Merge pull request #79511 from swiftlang/78848-fix
Ensure we're commuting instructions only when internal one has a single use.
2 parents c433242 + 42e530f commit 962be95

File tree

2 files changed

+139
-5
lines changed

2 files changed

+139
-5
lines changed

lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
790790
//
791791
// This unblocks the `thin_to_thick_function` peephole optimization below.
792792
if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getOperand())) {
793-
if (CFI->getSingleUse()) {
793+
if (hasOneNonDebugUse(CFI)) {
794794
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getOperand())) {
795795
if (TTTFI->getSingleUse()) {
796796
auto convertedThickType = CFI->getType().castTo<SILFunctionType>();
@@ -836,7 +836,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
836836
// %vjp' = convert_escape_to_noescape %vjp
837837
// %y = differentiable_function(%orig', %jvp', %vjp')
838838
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getOperand())) {
839-
if (DFI->hasOneUse()) {
839+
if (hasOneNonDebugUse(DFI)) {
840840
auto createConvertEscapeToNoEscape =
841841
[&](NormalDifferentiableFunctionTypeComponent extractee) {
842842
if (!DFI->hasExtractee(extractee))
@@ -1020,9 +1020,7 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
10201020
// %vjp' = convert_function %vjp
10211021
// %y = differentiable_function(%orig', %jvp', %vjp')
10221022
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getOperand())) {
1023-
// Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78848
1024-
// TODO: remove this if-statement once the underlying problem is fixed.
1025-
if (cfi->getFunction()->hasOwnership())
1023+
if (!hasOneNonDebugUse(DFI))
10261024
return nullptr;
10271025

10281026
auto createConvertFunctionOfComponent =

test/AutoDiff/sil_combine.sil

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float):
5959
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
6060
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
6161
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
62+
63+
// MARK: `convert_function` hoisting
64+
65+
// This should optimize down single partial_apply that escapes
66+
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
67+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
68+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
69+
70+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
71+
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
72+
73+
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
74+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
75+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
76+
}
77+
78+
debug_value %diff_fn, let, name "f", argno 1
79+
80+
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
81+
%conv_orig = differentiable_function_extract [original] %conv_diff
82+
return %conv_orig
83+
}
84+
85+
// CHECK-LABEL: sil @differential_function_convert_single_use
86+
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
87+
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
88+
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
89+
// CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
90+
91+
sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
92+
93+
// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
94+
95+
sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
96+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
97+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
98+
99+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
100+
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
101+
102+
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
103+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
104+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
105+
}
106+
107+
debug_value %diff_fn, let, name "f", argno 1
108+
109+
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
110+
%conv_orig = differentiable_function_extract [original] %conv_diff
111+
112+
%blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
113+
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
114+
115+
return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
116+
}
117+
118+
// CHECK-LABEL: sil @differential_function_convert_multiple_use
119+
// CHECK: convert_function
120+
// CHECK: differentiable_function
121+
// CHECK: convert_function
122+
// CHECK: differentiable_function_extract
123+
124+
// MARK: `convert_escape_to_noescape` hoisting
125+
126+
sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
127+
128+
// Here we should be able to unfold partial_apply down to direct function call
129+
130+
sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
131+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
132+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
133+
134+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
135+
136+
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
137+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
138+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
139+
}
140+
141+
debug_value %diff_fn, let, name "f", argno 1
142+
143+
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
144+
%conv_orig = differentiable_function_extract [original] %conv_diff
145+
146+
%arg = alloc_stack $Float
147+
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
148+
149+
dealloc_stack %arg : $*Float
150+
strong_release %pa
151+
152+
%res = tuple ()
153+
return %res : $()
154+
}
155+
156+
// CHECK-LABEL: sil @differential_function_noescape_single_use
157+
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
158+
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
159+
// CHECK: %[[ARG:.*]] = alloc_stack $Float
160+
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
161+
162+
163+
// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
164+
165+
sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
166+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
167+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
168+
169+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
170+
171+
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
172+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
173+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
174+
}
175+
176+
debug_value %diff_fn, let, name "f", argno 1
177+
178+
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
179+
%conv_orig = differentiable_function_extract [original] %conv_diff
180+
181+
%arg = alloc_stack $Float
182+
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
183+
184+
%blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
185+
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
186+
187+
dealloc_stack %arg : $*Float
188+
strong_release %pa
189+
190+
%res = tuple ()
191+
return %res : $()
192+
}
193+
194+
// CHECK-LABEL: sil @differential_function_noescape_multiple_use
195+
// CHECK: differentiable_function
196+
// CHECK: convert_escape_to_noescape
197+
// CHECK: differentiable_function_extract

0 commit comments

Comments
 (0)