feat: use traced_call when unrolling iterators and generators (#1642)
Conversation
https://github.com/EnzymeAD/Reactant.jl/actions/runs/17690716713/job/50283550550?pr=1642#step:11:803

Can't seem to repro this locally 😢
"builtin.module"() ({
"func.func"() <{function_type = (tensor<2xf64>, tensor<f64>) -> (tensor<2xf64>, tensor<f64>), sym_name = "##Base.Fix1{typeof(===), Reactant.TracedRArray{Float64, 1}}(===, TracedRArray{Float64,1N}(((:args, 1, 1), (:resargs, 1, 1)), size=(2,)))#2600", sym_visibility = "private"}> ({
^bb0(%arg2: tensor<2xf64> loc("arg1.x (path=(Symbol(\22##callarg#2601\22), 1, 2))"), %arg3: tensor<f64> loc("arg2 (path=(Symbol(\22##callarg#2601\22), 2))")):
"func.return"(%arg2, %arg3) : (tensor<2xf64>, tensor<f64>) -> () loc(#loc)
}) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
%0:2 = "func.call"(%14, %3) <{callee = @"##Base.Fix1{typeof(===), Reactant.TracedRArray{Float64, 1}}(===, TracedRArray{Float64,1N}(((:args, 1, 1), (:resargs, 1, 1)), size=(2,)))#2600"}> : (tensor<2xf64>, tensor<f64>) -> (tensor<2xf64>, tensor<f64>) loc(#loc13)
"func.func"() <{function_type = () -> (), sym_name = "##identity#2604", sym_visibility = "private"}> ({
"func.return"() : () -> () loc(#loc)
}) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
"func.call"() <{callee = @"##identity#2604"}> : () -> () loc(#loc13)
"func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##Base.Fix1{typeof(===), Reactant.TracedRNumber{Float64}}(===, TracedRNumber{Float64}(((:args, 1, 2), (:resargs, 1, 2), (Symbol(\22##callarg#2601\22), 2))))#2608", sym_visibility = "private"}> ({
^bb0(%arg1: tensor<f64> loc("arg1.x (path=(Symbol(\22##callarg#2609\22), 1, 2))")):
"func.return"(%arg1) : (tensor<f64>) -> () loc(#loc)
}) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
%1 = "func.call"(%0#1) <{callee = @"##Base.Fix1{typeof(===), Reactant.TracedRNumber{Float64}}(===, TracedRNumber{Float64}(((:args, 1, 2), (:resargs, 1, 2), (Symbol(\22##callarg#2601\22), 2))))#2608"}> : (tensor<f64>) -> tensor<f64> loc(#loc13)
"func.func"() <{function_type = () -> (), sym_name = "##identity#2612", sym_visibility = "private"}> ({
"func.return"() : () -> () loc(#loc)
}) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
"func.call"() <{callee = @"##identity#2612"}> : () -> () loc(#loc13)
"func.func"() <{function_type = (tensor<2xf64>) -> tensor<2xf64>, sym_name = "update!"}> ({
^bb0(%arg0: tensor<2xf64> loc("arg1.data (path=(:args, 1, 1))")):
%2 = "stablehlo.constant"() <{value = dense<2.700000e+00> : tensor<f64>}> : () -> tensor<f64> loc(#loc14)
%3 = "stablehlo.convert"(%2) : (tensor<f64>) -> tensor<f64> loc(#loc15)
%4 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc)
%5 = "stablehlo.broadcast_in_dim"(%3) <{broadcast_dimensions = array<i64>}> : (tensor<f64>) -> tensor<1xf64> loc(#loc16)
%6 = "stablehlo.constant"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32> loc(#loc14)
%7 = "stablehlo.convert"(%6) : (tensor<i32>) -> tensor<i32> loc(#loc15)
%8 = "stablehlo.constant"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32> loc(#loc14)
%9 = "stablehlo.convert"(%8) : (tensor<i32>) -> tensor<i32> loc(#loc15)
%10 = "stablehlo.subtract"(%7, %9) : (tensor<i32>, tensor<i32>) -> tensor<i32> loc(#loc17)
%11 = "stablehlo.dynamic_update_slice"(%4, %5, %10) : (tensor<2xf64>, tensor<1xf64>, tensor<i32>) -> tensor<2xf64> loc(#loc18)
%12 = "stablehlo.transpose"(%11) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc19)
%13 = "stablehlo.reshape"(%12) : (tensor<2xf64>) -> tensor<2xf64> loc(#loc20)
%14 = "stablehlo.transpose"(%13) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc19)
%15 = "stablehlo.transpose"(%0#0) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc)
"func.return"(%15) : (tensor<2xf64>) -> () loc(#loc)
}) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
}) : () -> () loc(#loc)

Huh, the call gets inserted in the module...

It looks like we are messing up the insertion points somehow.
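To make the suspected failure mode concrete, here is a hypothetical mini IR builder (plain Python, not Reactant's or MLIR's actual API) showing how insertion-point handling decides where ops land: functions are always emitted at module scope, and if the saved insertion point is not restored after building a nested function, subsequent ops (like the `func.call`s in the dump above) end up in the wrong region.

```python
# Hypothetical mini IR builder (illustrative only, not Reactant/MLIR API).
# Demonstrates why restoring the insertion point after emitting a nested
# function matters: without the restore, later ops land in the wrong region.

class Builder:
    def __init__(self):
        self.module = []          # top-level region
        self.point = self.module  # current insertion list

    def emit(self, op):
        # Append an op at the current insertion point.
        self.point.append(op)

    def define_function(self, name, body_ops):
        # Functions always go at module scope, like `func.func` above.
        func = {"func": name, "body": []}
        self.module.append(func)
        saved = self.point        # remember where the caller was emitting
        self.point = func["body"]
        for op in body_ops:
            self.emit(op)
        self.point = saved        # forgetting this restore reproduces the bug

b = Builder()
caller_body = []                  # stand-in for the body of `update!`
b.point = caller_body
b.define_function("identity", ["return"])
b.emit("call identity")           # lands back in the caller's body
print(caller_body)
```

With the restore in place, `call identity` ends up in `caller_body`; dropping the `self.point = saved` line would leave the call inside the just-emitted function's region instead, which is the kind of misplacement visible in the module dump.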
Codecov Report
❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##  ap/traced_call_v2    #1642    +/- ##
=====================================================
+ Coverage    68.58%    68.72%   +0.13%
=====================================================
  Files          103       103
  Lines        11585     11596      +11
=====================================================
+ Hits          7946      7969      +23
+ Misses        3639      3627      -12
* feat: better support for Base.Generators
* feat: use traced_call when unrolling iterators and generators
* fix: closure with call working
* fix: try removing nospecialize
* fix: use a looped version of any to avoid inference issues
* fix: dont overlay inside compiler call
Not yet finished, but this tries to tame the codegen for `mapreduce` when we need to unroll the computation.

Fixes #1616
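For context, a minimal sketch (plain Python, not Reactant's actual API) of why routing each unrolled iteration through a traced call tames the generated code: with plain unrolling the mapped function's op list is copied into the module once per element, while with a traced call the body is materialized once as a function and each element contributes only a one-line call, mirroring the `func.func` plus `func.call` pattern in the dump above.

```python
# Illustrative sketch (not Reactant's API): module size under plain
# unrolling vs. traced_call-style unrolling of a mapped function.

def unroll_inline(f_ops, xs):
    # Plain unrolling: the body of `f` is duplicated for every element.
    module = []
    for x in xs:
        module.extend(f"{op}({x})" for op in f_ops)
    return module

def unroll_traced_call(f_ops, xs, name="f"):
    # traced_call-style: emit `f` once as a function, then one call per element.
    module = [f"func {name}: " + "; ".join(f_ops)]
    module += [f"call {name}({x})" for x in xs]
    return module

ops = ["mul", "add"]        # a toy 2-op function body
xs = [0, 1, 2, 3]
print(len(unroll_inline(ops, xs)))       # 8: body duplicated per element
print(len(unroll_traced_call(ops, xs)))  # 5: one function + 4 calls
```

The inlined module grows as `len(f_ops) * len(xs)`, while the traced-call module grows as `1 + len(xs)`, which is the point of reusing `traced_call` when the loop must be unrolled.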