Skip to content

Commit 95f34eb

Browse files
authored
[AutoDiff] Fix derivative for array literal with tuple_element_addr elts (#78355)
The `adjIndex` was not incremented due to missed `remapType`. Fixes #54214
1 parent 6c29d2f commit 95f34eb

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3284,7 +3284,7 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
32843284
auto adjSource = getAdjointBuffer(origBB, source);
32853285
if (!adjSource->getType().is<TupleType>())
32863286
return adjSource;
3287-
auto origTupleTy = source->getType().castTo<TupleType>();
3287+
auto origTupleTy = remapType(source->getType()).castTo<TupleType>();
32883288
unsigned adjIndex = 0;
32893289
for (unsigned i : range(teai->getFieldIndex())) {
32903290
if (getTangentSpace(

test/AutoDiff/validation-test/array.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,7 @@ ArrayAutoDiffTests.test("ArrayLiteralTuple") {
192192
return [tuple.0, tuple.1]
193193
}
194194
let pb = pullback(at: Float(3), 4, of: { tupleElementGeneric($0, $1) })
195-
// FIXME(TF-977): Fix incorrect derivative for array literal with
196-
// `tuple_element_addr` elements.
197-
// expectEqual((1, 1), pb(FloatArrayTan([1, 1])))
198-
expectEqual((0, 2), pb(FloatArrayTan([1, 1])))
195+
expectEqual((1, 1), pb(FloatArrayTan([1, 1])))
199196
}
200197
}
201198

0 commit comments

Comments
 (0)