
Conversation

@avik-pal (Collaborator) commented Sep 7, 2025

Not yet finished, but this tries to tame the codegen for mapreduce when we need to unroll the computation.

fixes #1616
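To make the motivation concrete, here is a hypothetical reproducer (names and sizes are invented, not taken from the PR): a mapreduce over a Base.Generator whose length is known at trace time gets fully unrolled, and without outlining, every unrolled iteration re-emits the body of the mapped function into the module. Routing each unrolled call through traced_call, as the PR title describes, instead emits one shared func.func and a cheap func.call per iteration.

```julia
using Reactant

# The mapped function; without outlining, its traced body would be
# duplicated once per unrolled iteration.
f(x) = x .* 2.0 .+ 1.0

# A mapreduce over a Base.Generator; the range length is known at
# trace time, so tracing unrolls all 16 iterations.
g(xs) = mapreduce(f, +, (xs for _ in 1:16))

xs = Reactant.to_rarray(rand(Float64, 2))
g_compiled = @compile g(xs)  # with this PR: one func.func for `f`,
                             # sixteen func.call ops
```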

@avik-pal avik-pal requested a review from jumerckx September 7, 2025 20:02
@avik-pal avik-pal force-pushed the ap/generators branch 2 times, most recently from 0448561 to 5f0bd8f on September 7, 2025 20:07
@avik-pal avik-pal changed the title to feat: use traced_call when unrolling iterators and generators Sep 7, 2025
@avik-pal avik-pal force-pushed the ap/generators branch 3 times, most recently from aa8ffdd to 639eb7b on September 12, 2025 20:13
avik-pal added commits that referenced this pull request on Sep 12 and Sep 13, 2025
@avik-pal avik-pal changed the base branch from main to ap/traced_call_v2 September 13, 2025 02:48
@avik-pal (Collaborator, Author) commented:
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<2xf64>, tensor<f64>) -> (tensor<2xf64>, tensor<f64>), sym_name = "##Base.Fix1{typeof(===), Reactant.TracedRArray{Float64, 1}}(===, TracedRArray{Float64,1N}(((:args, 1, 1), (:resargs, 1, 1)), size=(2,)))#2600", sym_visibility = "private"}> ({
  ^bb0(%arg2: tensor<2xf64> loc("arg1.x (path=(Symbol(\22##callarg#2601\22), 1, 2))"), %arg3: tensor<f64> loc("arg2 (path=(Symbol(\22##callarg#2601\22), 2))")):
    "func.return"(%arg2, %arg3) : (tensor<2xf64>, tensor<f64>) -> () loc(#loc)
  }) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
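  // NOTE: the func.call ops below are emitted directly in the module body and
  // consume SSA values (%14, %3, %0#1) that are only defined inside @update!
  // further down; this is the misplaced-insertion-point bug discussed in the
  // comments after this dump.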
  %0:2 = "func.call"(%14, %3) <{callee = @"##Base.Fix1{typeof(===), Reactant.TracedRArray{Float64, 1}}(===, TracedRArray{Float64,1N}(((:args, 1, 1), (:resargs, 1, 1)), size=(2,)))#2600"}> : (tensor<2xf64>, tensor<f64>) -> (tensor<2xf64>, tensor<f64>) loc(#loc13)
  "func.func"() <{function_type = () -> (), sym_name = "##identity#2604", sym_visibility = "private"}> ({
    "func.return"() : () -> () loc(#loc)
  }) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
  "func.call"() <{callee = @"##identity#2604"}> : () -> () loc(#loc13)
  "func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##Base.Fix1{typeof(===), Reactant.TracedRNumber{Float64}}(===, TracedRNumber{Float64}(((:args, 1, 2), (:resargs, 1, 2), (Symbol(\22##callarg#2601\22), 2))))#2608", sym_visibility = "private"}> ({
  ^bb0(%arg1: tensor<f64> loc("arg1.x (path=(Symbol(\22##callarg#2609\22), 1, 2))")):
    "func.return"(%arg1) : (tensor<f64>) -> () loc(#loc)
  }) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
  %1 = "func.call"(%0#1) <{callee = @"##Base.Fix1{typeof(===), Reactant.TracedRNumber{Float64}}(===, TracedRNumber{Float64}(((:args, 1, 2), (:resargs, 1, 2), (Symbol(\22##callarg#2601\22), 2))))#2608"}> : (tensor<f64>) -> tensor<f64> loc(#loc13)
  "func.func"() <{function_type = () -> (), sym_name = "##identity#2612", sym_visibility = "private"}> ({
    "func.return"() : () -> () loc(#loc)
  }) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
  "func.call"() <{callee = @"##identity#2612"}> : () -> () loc(#loc13)
  "func.func"() <{function_type = (tensor<2xf64>) -> tensor<2xf64>, sym_name = "update!"}> ({
  ^bb0(%arg0: tensor<2xf64> loc("arg1.data (path=(:args, 1, 1))")):
    %2 = "stablehlo.constant"() <{value = dense<2.700000e+00> : tensor<f64>}> : () -> tensor<f64> loc(#loc14)
    %3 = "stablehlo.convert"(%2) : (tensor<f64>) -> tensor<f64> loc(#loc15)
    %4 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc)
    %5 = "stablehlo.broadcast_in_dim"(%3) <{broadcast_dimensions = array<i64>}> : (tensor<f64>) -> tensor<1xf64> loc(#loc16)
    %6 = "stablehlo.constant"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32> loc(#loc14)
    %7 = "stablehlo.convert"(%6) : (tensor<i32>) -> tensor<i32> loc(#loc15)
    %8 = "stablehlo.constant"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32> loc(#loc14)
    %9 = "stablehlo.convert"(%8) : (tensor<i32>) -> tensor<i32> loc(#loc15)
    %10 = "stablehlo.subtract"(%7, %9) : (tensor<i32>, tensor<i32>) -> tensor<i32> loc(#loc17)
    %11 = "stablehlo.dynamic_update_slice"(%4, %5, %10) : (tensor<2xf64>, tensor<1xf64>, tensor<i32>) -> tensor<2xf64> loc(#loc18)
    %12 = "stablehlo.transpose"(%11) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc19)
    %13 = "stablehlo.reshape"(%12) : (tensor<2xf64>) -> tensor<2xf64> loc(#loc20)
    %14 = "stablehlo.transpose"(%13) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc19)
    %15 = "stablehlo.transpose"(%0#0) <{permutation = array<i64: 0>}> : (tensor<2xf64>) -> tensor<2xf64> loc(#loc)
    "func.return"(%15) : (tensor<2xf64>) -> () loc(#loc)
  }) {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} : () -> () loc(#loc)
}) : () -> () loc(#loc)

@avik-pal (Collaborator, Author) commented:
Huh, the call gets inserted directly into the module body rather than at the call site...

@avik-pal (Collaborator, Author) commented:
It looks like we are messing up the insertion points somehow.
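A library-free sketch of the invariant being violated (all names hypothetical; Reactant drives MLIR through its own bindings, not this toy builder): when the tracer outlines a callee at module scope, it must save the current insertion point, emit the func.func into the module body, and restore the saved point before emitting the func.call. Skipping the restore leaves the call, and everything after it, at module scope, which is exactly what the dump above shows.

```julia
# Toy model of an IR builder, for illustration only.
mutable struct Builder
    module_body::Vector{String}  # ops at module scope
    current::Vector{String}      # active insertion point
end

function outline_and_call!(b::Builder, name::String)
    saved = b.current                            # save the insertion point
    b.current = b.module_body
    push!(b.current, "func.func @$name {...}")   # emit the callee at module scope
    b.current = saved                            # restore BEFORE emitting the call
    push!(b.current, "func.call @$name")
    return nothing
end

func_body = String[]
b = Builder(String[], func_body)
outline_and_call!(b, "callee")
@assert b.module_body == ["func.func @callee {...}"]
@assert func_body == ["func.call @callee"]  # the call stays at the call site
```

Forgetting the restore step leaves `current` pointing at the module body, so the func.call lands there instead, matching the misplaced calls above.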

@avik-pal avik-pal marked this pull request as ready for review September 13, 2025 04:11
codecov bot commented Sep 13, 2025

Codecov Report

❌ Patch coverage is 93.75000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 68.72%. Comparing base (9549a6e) to head (4f1524b).
⚠️ Report is 1 commit behind head on ap/traced_call_v2.

Files with missing lines   Patch %   Lines
src/Reactant.jl            92.30%    1 Missing ⚠️
Additional details and impacted files
@@                  Coverage Diff                  @@
##           ap/traced_call_v2    #1642      +/-   ##
=====================================================
+ Coverage              68.58%   68.72%   +0.13%     
=====================================================
  Files                    103      103              
  Lines                  11585    11596      +11     
=====================================================
+ Hits                    7946     7969      +23     
+ Misses                  3639     3627      -12     


@avik-pal avik-pal requested a review from wsmoses September 13, 2025 13:35
@avik-pal avik-pal merged commit c038c6c into ap/traced_call_v2 Sep 13, 2025
66 of 68 checks passed
@avik-pal avik-pal deleted the ap/generators branch September 13, 2025 14:13
avik-pal added commits that referenced this pull request on Sep 15, 2025
* feat: better support for Base.Generators

* feat: use traced_call when unrolling iterators and generators

* fix: closure with call working

* fix: try removing nospecialize

* fix: use a looped version of any to avoid inference issues (see the sketch after this list)

* fix: dont overlay inside compiler call
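On the "looped version of any" item above, a minimal sketch of the pattern (hypothetical predicate and container; the PR's actual code differs): any plus a capturing closure can be harder for inference than a plain loop with an early return, which always yields a concrete Bool.

```julia
# Closure-based version: `any` with a capturing closure over a
# heterogeneous container can be harder for inference.
has_array(xs) = any(x -> x isa AbstractArray, xs)

# Looped version: an explicit loop with an early return is simpler
# for the compiler and always infers as Bool.
function has_array_looped(xs)
    for x in xs
        x isa AbstractArray && return true
    end
    return false
end

@assert has_array_looped((1, [2.0], "a"))
@assert !has_array_looped((1, "a"))
```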

Development

Successfully merging this pull request may close this issue:

Could not find unique name for Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar when compiling function with lots of matrix multiplications
