From 60f684ce5e2c31284b041f13922c57ca151a8270 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 21 Aug 2025 21:17:47 +0530 Subject: [PATCH 1/3] chore: prefer ArrayPartitionStyle to preserve nesting structure --- src/array_partition.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 89261f4b..b9b436bc 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -336,9 +336,10 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle{Style}, } ArrayPartitionStyle{Style}() end -function Broadcast.BroadcastStyle(::ArrayPartitionStyle, - ::Broadcast.DefaultArrayStyle{N}) where {N} - Broadcast.DefaultArrayStyle{N}() +function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, + ::Broadcast.DefaultArrayStyle{N}) where {AStyle, N} + pick = Broadcast.BroadcastStyle(AStyle(), Broadcast.DefaultArrayStyle{N}()) + ArrayPartitionStyle(pick, Val(N)) end combine_styles(::Type{Tuple{}}) = Broadcast.DefaultArrayStyle{0}() From c88e61ba6858667b4c7377c3a5c12aa1c681b59b Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 21 Aug 2025 21:22:22 +0530 Subject: [PATCH 2/3] test: broadcasting preserves nested types --- test/adjoints.jl | 4 ++++ test/basic_indexing.jl | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/test/adjoints.jl b/test/adjoints.jl index af2abd42..1682ded1 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -92,3 +92,7 @@ loss(x) VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) @test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) @test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x) + +x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) +g = Zygote.gradient(norm, x)[1] +@test g isa typeof(x) \ No newline at end of file diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 442e761f..24b5c863 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -280,3 +280,9 @@ x = VectorOfArray(StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))) y = 2 * x @. x = y @test all(all.(y .== x)) + + +x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) +@test (x .* 1.2) isa ArrayPartition{<:Any, <:ArrayPartition} + +g = Zygote.gradient(norm, x)[1] \ No newline at end of file From 1b3238fdf90f2b198ad8697023f51d83e36c7efa Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 25 Aug 2025 19:06:11 +0530 Subject: [PATCH 3/3] test: improve variable names; rm unused loc --- test/basic_indexing.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 24b5c863..f2e0dc59 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -282,7 +282,5 @@ y = 2 * x @test all(all.(y .== x)) -x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) -@test (x .* 1.2) isa ArrayPartition{<:Any, <:ArrayPartition} - -g = Zygote.gradient(norm, x)[1] \ No newline at end of file +x_ap = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) +@test (x_ap .* 1.2) isa typeof(x_ap)