Skip to content

Commit 2743e18

Browse files
authored
fix: disable ForwardDiff tag checking with custom backend tags (#631)
* fix: disable ForwardDiff tag checking with custom backend tags * Fix sparse
1 parent fd60623 commit 2743e18

File tree

4 files changed

+88
-45
lines changed

4 files changed

+88
-45
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -305,44 +305,56 @@ function DI.value_and_gradient!(
305305
f::F,
306306
grad,
307307
prep::ForwardDiffGradientPrep,
308-
::AutoForwardDiff,
308+
backend::AutoForwardDiff,
309309
x,
310310
contexts::Vararg{Constant,C},
311311
) where {F,C}
312312
fc = with_contexts(f, contexts...)
313313
result = DiffResult(zero(eltype(x)), (grad,))
314-
result = gradient!(result, fc, x, prep.config)
314+
CHK = tag_type(backend) === Nothing
315+
result = gradient!(result, fc, x, prep.config, Val(CHK))
315316
y = DR.value(result)
316317
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
317318
return y, grad
318319
end
319320

320321
function DI.value_and_gradient(
321-
f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C}
322+
f::F,
323+
prep::ForwardDiffGradientPrep,
324+
backend::AutoForwardDiff,
325+
x,
326+
contexts::Vararg{Constant,C},
322327
) where {F,C}
323328
fc = with_contexts(f, contexts...)
324329
result = GradientResult(x)
325-
result = gradient!(result, fc, x, prep.config)
330+
CHK = tag_type(backend) === Nothing
331+
result = gradient!(result, fc, x, prep.config, Val(CHK))
326332
return DR.value(result), DR.gradient(result)
327333
end
328334

329335
function DI.gradient!(
330336
f::F,
331337
grad,
332338
prep::ForwardDiffGradientPrep,
333-
::AutoForwardDiff,
339+
backend::AutoForwardDiff,
334340
x,
335341
contexts::Vararg{Constant,C},
336342
) where {F,C}
337343
fc = with_contexts(f, contexts...)
338-
return gradient!(grad, fc, x, prep.config)
344+
CHK = tag_type(backend) === Nothing
345+
return gradient!(grad, fc, x, prep.config, Val(CHK))
339346
end
340347

341348
function DI.gradient(
342-
f::F, prep::ForwardDiffGradientPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C}
349+
f::F,
350+
prep::ForwardDiffGradientPrep,
351+
backend::AutoForwardDiff,
352+
x,
353+
contexts::Vararg{Constant,C},
343354
) where {F,C}
344355
fc = with_contexts(f, contexts...)
345-
return gradient(fc, x, prep.config)
356+
CHK = tag_type(backend) === Nothing
357+
return gradient(fc, x, prep.config, Val(CHK))
346358
end
347359

348360
## Jacobian
@@ -422,14 +434,15 @@ function DI.value_and_jacobian!(
422434
f::F,
423435
jac,
424436
prep::ForwardDiffOneArgJacobianPrep,
425-
::AutoForwardDiff,
437+
backend::AutoForwardDiff,
426438
x,
427439
contexts::Vararg{Constant,C},
428440
) where {F,C}
429441
fc = with_contexts(f, contexts...)
430442
y = fc(x)
431443
result = DiffResult(y, (jac,))
432-
result = jacobian!(result, fc, x, prep.config)
444+
CHK = tag_type(backend) === Nothing
445+
result = jacobian!(result, fc, x, prep.config, Val(CHK))
433446
y = DR.value(result)
434447
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
435448
return y, jac
@@ -438,35 +451,38 @@ end
438451
function DI.value_and_jacobian(
439452
f::F,
440453
prep::ForwardDiffOneArgJacobianPrep,
441-
::AutoForwardDiff,
454+
backend::AutoForwardDiff,
442455
x,
443456
contexts::Vararg{Constant,C},
444457
) where {F,C}
445458
fc = with_contexts(f, contexts...)
446-
return fc(x), jacobian(fc, x, prep.config)
459+
CHK = tag_type(backend) === Nothing
460+
return fc(x), jacobian(fc, x, prep.config, Val(CHK))
447461
end
448462

449463
function DI.jacobian!(
450464
f::F,
451465
jac,
452466
prep::ForwardDiffOneArgJacobianPrep,
453-
::AutoForwardDiff,
467+
backend::AutoForwardDiff,
454468
x,
455469
contexts::Vararg{Constant,C},
456470
) where {F,C}
457471
fc = with_contexts(f, contexts...)
458-
return jacobian!(jac, fc, x, prep.config)
472+
CHK = tag_type(backend) === Nothing
473+
return jacobian!(jac, fc, x, prep.config, Val(CHK))
459474
end
460475

461476
function DI.jacobian(
462477
f::F,
463478
prep::ForwardDiffOneArgJacobianPrep,
464-
::AutoForwardDiff,
479+
backend::AutoForwardDiff,
465480
x,
466481
contexts::Vararg{Constant,C},
467482
) where {F,C}
468483
fc = with_contexts(f, contexts...)
469-
return jacobian(fc, x, prep.config)
484+
CHK = tag_type(backend) === Nothing
485+
return jacobian(fc, x, prep.config, Val(CHK))
470486
end
471487

472488
## Second derivative
@@ -681,44 +697,56 @@ function DI.hessian!(
681697
f::F,
682698
hess,
683699
prep::ForwardDiffHessianPrep,
684-
::AutoForwardDiff,
700+
backend::AutoForwardDiff,
685701
x,
686702
contexts::Vararg{Constant,C},
687703
) where {F,C}
688704
fc = with_contexts(f, contexts...)
689-
return hessian!(hess, fc, x, prep.array_config)
705+
CHK = tag_type(backend) === Nothing
706+
return hessian!(hess, fc, x, prep.array_config, Val(CHK))
690707
end
691708

692709
function DI.hessian(
693-
f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C}
710+
f::F,
711+
prep::ForwardDiffHessianPrep,
712+
backend::AutoForwardDiff,
713+
x,
714+
contexts::Vararg{Constant,C},
694715
) where {F,C}
695716
fc = with_contexts(f, contexts...)
696-
return hessian(fc, x, prep.array_config)
717+
CHK = tag_type(backend) === Nothing
718+
return hessian(fc, x, prep.array_config, Val(CHK))
697719
end
698720

699721
function DI.value_gradient_and_hessian!(
700722
f::F,
701723
grad,
702724
hess,
703725
prep::ForwardDiffHessianPrep,
704-
::AutoForwardDiff,
726+
backend::AutoForwardDiff,
705727
x,
706728
contexts::Vararg{Constant,C},
707729
) where {F,C}
708730
fc = with_contexts(f, contexts...)
709731
result = DiffResult(one(eltype(x)), (grad, hess))
710-
result = hessian!(result, fc, x, prep.result_config)
732+
CHK = tag_type(backend) === Nothing
733+
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
711734
y = DR.value(result)
712735
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
713736
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
714737
return (y, grad, hess)
715738
end
716739

717740
function DI.value_gradient_and_hessian(
718-
f::F, prep::ForwardDiffHessianPrep, ::AutoForwardDiff, x, contexts::Vararg{Constant,C}
741+
f::F,
742+
prep::ForwardDiffHessianPrep,
743+
backend::AutoForwardDiff,
744+
x,
745+
contexts::Vararg{Constant,C},
719746
) where {F,C}
720747
fc = with_contexts(f, contexts...)
721748
result = HessianResult(x)
722-
result = hessian!(result, fc, x, prep.result_config)
749+
CHK = tag_type(backend) === Nothing
750+
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
723751
return (DR.value(result), DR.gradient(result), DR.hessian(result))
724752
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,14 @@ function DI.value_and_derivative(
181181
f!::F,
182182
y,
183183
prep::ForwardDiffTwoArgDerivativePrep,
184-
::AutoForwardDiff,
184+
backend::AutoForwardDiff,
185185
x,
186186
contexts::Vararg{Constant,C},
187187
) where {F,C}
188188
fc! = with_contexts(f!, contexts...)
189189
result = MutableDiffResult(y, (similar(y),))
190-
result = derivative!(result, fc!, y, x, prep.config)
190+
CHK = tag_type(backend) === Nothing
191+
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
191192
return DiffResults.value(result), DiffResults.derivative(result)
192193
end
193194

@@ -196,39 +197,42 @@ function DI.value_and_derivative!(
196197
y,
197198
der,
198199
prep::ForwardDiffTwoArgDerivativePrep,
199-
::AutoForwardDiff,
200+
backend::AutoForwardDiff,
200201
x,
201202
contexts::Vararg{Constant,C},
202203
) where {F,C}
203204
fc! = with_contexts(f!, contexts...)
204205
result = MutableDiffResult(y, (der,))
205-
result = derivative!(result, fc!, y, x, prep.config)
206+
CHK = tag_type(backend) === Nothing
207+
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
206208
return DiffResults.value(result), DiffResults.derivative(result)
207209
end
208210

209211
function DI.derivative(
210212
f!::F,
211213
y,
212214
prep::ForwardDiffTwoArgDerivativePrep,
213-
::AutoForwardDiff,
215+
backend::AutoForwardDiff,
214216
x,
215217
contexts::Vararg{Constant,C},
216218
) where {F,C}
217219
fc! = with_contexts(f!, contexts...)
218-
return derivative(fc!, y, x, prep.config)
220+
CHK = tag_type(backend) === Nothing
221+
return derivative(fc!, y, x, prep.config, Val(CHK))
219222
end
220223

221224
function DI.derivative!(
222225
f!::F,
223226
y,
224227
der,
225228
prep::ForwardDiffTwoArgDerivativePrep,
226-
::AutoForwardDiff,
229+
backend::AutoForwardDiff,
227230
x,
228231
contexts::Vararg{Constant,C},
229232
) where {F,C}
230233
fc! = with_contexts(f!, contexts...)
231-
return derivative!(der, fc!, y, x, prep.config)
234+
CHK = tag_type(backend) === Nothing
235+
return derivative!(der, fc!, y, x, prep.config, Val(CHK))
232236
end
233237

234238
## Jacobian
@@ -308,14 +312,15 @@ function DI.value_and_jacobian(
308312
f!::F,
309313
y,
310314
prep::ForwardDiffTwoArgJacobianPrep,
311-
::AutoForwardDiff,
315+
backend::AutoForwardDiff,
312316
x,
313317
contexts::Vararg{Constant,C},
314318
) where {F,C}
315319
fc! = with_contexts(f!, contexts...)
316320
jac = similar(y, length(y), length(x))
317321
result = MutableDiffResult(y, (jac,))
318-
result = jacobian!(result, fc!, y, x, prep.config)
322+
CHK = tag_type(backend) === Nothing
323+
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
319324
return DiffResults.value(result), DiffResults.jacobian(result)
320325
end
321326

@@ -324,37 +329,40 @@ function DI.value_and_jacobian!(
324329
y,
325330
jac,
326331
prep::ForwardDiffTwoArgJacobianPrep,
327-
::AutoForwardDiff,
332+
backend::AutoForwardDiff,
328333
x,
329334
contexts::Vararg{Constant,C},
330335
) where {F,C}
331336
fc! = with_contexts(f!, contexts...)
332337
result = MutableDiffResult(y, (jac,))
333-
result = jacobian!(result, fc!, y, x, prep.config)
338+
CHK = tag_type(backend) === Nothing
339+
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
334340
return DiffResults.value(result), DiffResults.jacobian(result)
335341
end
336342

337343
function DI.jacobian(
338344
f!::F,
339345
y,
340346
prep::ForwardDiffTwoArgJacobianPrep,
341-
::AutoForwardDiff,
347+
backend::AutoForwardDiff,
342348
x,
343349
contexts::Vararg{Constant,C},
344350
) where {F,C}
345351
fc! = with_contexts(f!, contexts...)
346-
return jacobian(fc!, y, x, prep.config)
352+
CHK = tag_type(backend) === Nothing
353+
return jacobian(fc!, y, x, prep.config, Val(CHK))
347354
end
348355

349356
function DI.jacobian!(
350357
f!::F,
351358
y,
352359
jac,
353360
prep::ForwardDiffTwoArgJacobianPrep,
354-
::AutoForwardDiff,
361+
backend::AutoForwardDiff,
355362
x,
356363
contexts::Vararg{Constant,C},
357364
) where {F,C}
358365
fc! = with_contexts(f!, contexts...)
359-
return jacobian!(jac, fc!, y, x, prep.config)
366+
CHK = tag_type(backend) === Nothing
367+
return jacobian!(jac, fc!, y, x, prep.config, Val(CHK))
360368
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ function get_tag(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksi
3131
return Tag(f, eltype(x))
3232
end
3333

34+
tag_type(::AutoForwardDiff{chunksize,T}) where {chunksize,T} = T
3435
tag_type(f::F, backend::AutoForwardDiff, x) where {F} = typeof(get_tag(f, backend, x))
3536

3637
function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T,B}

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ using Test
1010

1111
LOGGING = get(ENV, "CI", "false") == "false"
1212

13+
struct MyTag end
14+
1315
backends = [
14-
AutoForwardDiff(), AutoForwardDiff(; tag=:hello), AutoForwardDiff(; chunksize=5)
16+
AutoForwardDiff(),
17+
AutoForwardDiff(; chunksize=5),
18+
AutoForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)),
1519
]
1620

1721
for backend in backends
@@ -54,10 +58,10 @@ test_differentiation(
5458

5559
## Sparse
5660

57-
test_differentiation(MyAutoSparse.(backends[1:2]), default_scenarios(); logging=LOGGING);
61+
test_differentiation(MyAutoSparse(AutoForwardDiff()), default_scenarios(); logging=LOGGING);
5862

5963
test_differentiation(
60-
MyAutoSparse.(backends[1:2]),
64+
MyAutoSparse(AutoForwardDiff()),
6165
sparse_scenarios(; include_constantified=true);
6266
sparsity=true,
6367
logging=LOGGING,
@@ -78,7 +82,9 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
7882
excluded=[:hessian, :pullback], # TODO: figure this out
7983
logging=LOGGING,
8084
)
81-
@testset "$(row[:scenario])" for row in eachrow(data)
82-
@test row[:allocs] == 0
85+
@testset "Analyzing benchmark results" begin
86+
@testset "$(row[:scenario])" for row in eachrow(data)
87+
@test row[:allocs] == 0
88+
end
8389
end
8490
end;

0 commit comments

Comments
 (0)