Skip to content

Commit 42e530f

Browse files
committed
Add tests
1 parent 4310507 commit 42e530f

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

test/AutoDiff/sil_combine.sil

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float):
5959
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
6060
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
6161
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
62+
63+
// MARK: `convert_function` hoisting
64+
65+
// This should optimize down single partial_apply that escapes
66+
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
67+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
68+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
69+
70+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
71+
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
72+
73+
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
74+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
75+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
76+
}
77+
78+
debug_value %diff_fn, let, name "f", argno 1
79+
80+
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
81+
%conv_orig = differentiable_function_extract [original] %conv_diff
82+
return %conv_orig
83+
}
84+
85+
// CHECK-LABEL: sil @differential_function_convert_single_use
86+
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
87+
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
88+
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
89+
// CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
90+
91+
sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
92+
93+
// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
94+
95+
sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
96+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
97+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
98+
99+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
100+
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
101+
102+
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
103+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
104+
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
105+
}
106+
107+
debug_value %diff_fn, let, name "f", argno 1
108+
109+
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
110+
%conv_orig = differentiable_function_extract [original] %conv_diff
111+
112+
%blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
113+
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
114+
115+
return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
116+
}
117+
118+
// CHECK-LABEL: sil @differential_function_convert_multiple_use
119+
// CHECK: convert_function
120+
// CHECK: differentiable_function
121+
// CHECK: convert_function
122+
// CHECK: differentiable_function_extract
123+
124+
// MARK: `convert_escape_to_noescape` hoisting
125+
126+
sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
127+
128+
// Here we should be able to unfold partial_apply down to direct function call
129+
130+
sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
131+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
132+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
133+
134+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
135+
136+
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
137+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
138+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
139+
}
140+
141+
debug_value %diff_fn, let, name "f", argno 1
142+
143+
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
144+
%conv_orig = differentiable_function_extract [original] %conv_diff
145+
146+
%arg = alloc_stack $Float
147+
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
148+
149+
dealloc_stack %arg : $*Float
150+
strong_release %pa
151+
152+
%res = tuple ()
153+
return %res : $()
154+
}
155+
156+
// CHECK-LABEL: sil @differential_function_noescape_single_use
157+
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
158+
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
159+
// CHECK: %[[ARG:.*]] = alloc_stack $Float
160+
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
161+
162+
163+
// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
164+
165+
sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
166+
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
167+
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
168+
169+
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
170+
171+
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
172+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
173+
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
174+
}
175+
176+
debug_value %diff_fn, let, name "f", argno 1
177+
178+
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
179+
%conv_orig = differentiable_function_extract [original] %conv_diff
180+
181+
%arg = alloc_stack $Float
182+
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
183+
184+
%blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
185+
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
186+
187+
dealloc_stack %arg : $*Float
188+
strong_release %pa
189+
190+
%res = tuple ()
191+
return %res : $()
192+
}
193+
194+
// CHECK-LABEL: sil @differential_function_noescape_multiple_use
195+
// CHECK: differentiable_function
196+
// CHECK: convert_escape_to_noescape
197+
// CHECK: differentiable_function_extract

0 commit comments

Comments
 (0)