@@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float):
59
59
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
60
60
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
61
61
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
62
+
63
+ // MARK: `convert_function` hoisting
64
+
65
+ // This should optimize down single partial_apply that escapes
66
+ sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
67
+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
68
+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
69
+
70
+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
71
+ %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
72
+
73
+ %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
74
+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
75
+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
76
+ }
77
+
78
+ debug_value %diff_fn, let, name "f", argno 1
79
+
80
+ %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
81
+ %conv_orig = differentiable_function_extract [original] %conv_diff
82
+ return %conv_orig
83
+ }
84
+
85
+ // CHECK-LABEL: sil @differential_function_convert_single_use
86
+ // CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
87
+ // CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
88
+ // CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
89
+ // CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
90
+
91
+ sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
92
+
93
+ // differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
94
+
95
+ sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
96
+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
97
+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
98
+
99
+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
100
+ %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
101
+
102
+ %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
103
+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
104
+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
105
+ }
106
+
107
+ debug_value %diff_fn, let, name "f", argno 1
108
+
109
+ %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
110
+ %conv_orig = differentiable_function_extract [original] %conv_diff
111
+
112
+ %blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
113
+ apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
114
+
115
+ return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
116
+ }
117
+
118
+ // CHECK-LABEL: sil @differential_function_convert_multiple_use
119
+ // CHECK: convert_function
120
+ // CHECK: differentiable_function
121
+ // CHECK: convert_function
122
+ // CHECK: differentiable_function_extract
123
+
124
+ // MARK: `convert_escape_to_noescape` hoisting
125
+
126
+ sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
127
+
128
+ // Here we should be able to unfold partial_apply down to direct function call
129
+
130
+ sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
131
+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
132
+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
133
+
134
+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
135
+
136
+ %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
137
+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
138
+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
139
+ }
140
+
141
+ debug_value %diff_fn, let, name "f", argno 1
142
+
143
+ %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
144
+ %conv_orig = differentiable_function_extract [original] %conv_diff
145
+
146
+ %arg = alloc_stack $Float
147
+ apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
148
+
149
+ dealloc_stack %arg : $*Float
150
+ strong_release %pa
151
+
152
+ %res = tuple ()
153
+ return %res : $()
154
+ }
155
+
156
+ // CHECK-LABEL: sil @differential_function_noescape_single_use
157
+ // CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
158
+ // CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
159
+ // CHECK: %[[ARG:.*]] = alloc_stack $Float
160
+ // CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
161
+
162
+
163
+ // differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
164
+
165
+ sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
166
+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
167
+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
168
+
169
+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
170
+
171
+ %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
172
+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
173
+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
174
+ }
175
+
176
+ debug_value %diff_fn, let, name "f", argno 1
177
+
178
+ %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
179
+ %conv_orig = differentiable_function_extract [original] %conv_diff
180
+
181
+ %arg = alloc_stack $Float
182
+ apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
183
+
184
+ %blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
185
+ apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
186
+
187
+ dealloc_stack %arg : $*Float
188
+ strong_release %pa
189
+
190
+ %res = tuple ()
191
+ return %res : $()
192
+ }
193
+
194
+ // CHECK-LABEL: sil @differential_function_noescape_multiple_use
195
+ // CHECK: differentiable_function
196
+ // CHECK: convert_escape_to_noescape
197
+ // CHECK: differentiable_function_extract
0 commit comments