Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ function set_reactant_abi(
)
end


current_interpreter = Ref{Enzyme.Compiler.Interpreter.EnzymeInterpreter{typeof(Reactant.set_reactant_abi)}}()
@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE
struct ReactantCacheToken end

function ReactantInterpreter(; world::UInt=Base.get_world_counter())
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
Expand All @@ -108,7 +110,7 @@ else
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
Expand Down
2 changes: 1 addition & 1 deletion src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end

# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
function precompilation_supported()
return (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-")
return false && (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-")
end

if Reactant_jll.is_available()
Expand Down
3 changes: 2 additions & 1 deletion src/TracedRange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,10 @@ function Base._reshape(parent::TracedUnitRange, dims::Dims)
return Base.__reshape((parent, IndexStyle(parent)), dims)
end

function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T}
#=function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T}
return TracedUnitRange(start, stop)
end
=#
function (C::Base.Colon)(start::TracedRNumber{T}, stop::T) where {T}
return C(start, TracedRNumber{T}(stop))
end
Expand Down
6 changes: 6 additions & 0 deletions src/auto_cf/AutoCF.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using Debugger
include("new_inference.jl")
include("code_info_mut.jl")
include("code_ir_utils.jl")
include("mlir_utils.jl")
include("ir_control_flow_transform.jl")
83 changes: 83 additions & 0 deletions src/auto_cf/analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
@enum UpgradeSlot NoUpgrade UpgradeLocally UpgradeDefinition UpgradeDefinitionGlobal

@enum State Traced Upgraded Maybe NotTraced

@enum LoopKind For While

mutable struct LoopStructure
kind::LoopKind
accus::Tuple
header_bb::Int
latch_bb::Int
terminal_bb::Int
body_bbs::Set{Int}
state::State
end

struct IfStructure
ssa_cond
header_bb::Int
terminal_bb::Int
true_bbs::Set{Int}
false_bbs::Set{Int}
owned_true_bbs::Set{Int}
owned_false_bbs::Set{Int}
legalize::Ref{Bool} #inform that the if traced GotoIfNot can pass type inference
unbalanced_slots::Set{Core.SlotNumber}
end

mutable struct SlotAnalysis
slot_stmt_def::Vector{Integer} #0 for argument
slot_bb_usage::Vector{Set{Int}}
end


CFStructure = Union{IfStructure,LoopStructure}
mutable struct Tree
node::Union{Nothing,Base.uniontypes(CFStructure)...}
children::Vector{Tree}
parent::Ref{Tree}
end

Base.isempty(tree::Tree) = isnothing(tree.node) && length(tree.children) == 0

Base.show(io::IO, t::Tree) = begin
Base.print(io, '(')
Base.show(io, t.node)
Base.print(io, ',')
Base.show(io, t.children)
Base.print(io, ')')
end

mutable struct Analysis
tree::Tree
domtree::Union{Nothing,Vector{CC.DomTreeNode}}
postdomtree::Union{Nothing,Vector{CC.DomTreeNode}}
slotanalysis::Union{Nothing,SlotAnalysis}
pending_tree::Union{Nothing,Tree}
end


#leak each argument to a global variable
macro lk(args...)
quote
$([:(
let val = $(esc(p))
global $(esc(p)) = val
end
) for p in args]...)
end
end

MethodInstanceKey = Vector{Type}
function mi_key(mi::Core.MethodInstance)
return collect(Base.unwrap_unionall(mi.specTypes).parameters)
end
@kwdef struct MetaData
traced_tree_map::Dict{MethodInstanceKey,Tree} = Dict()
end

meta = Ref(MetaData())
function get_meta(_::Reactant.ReactantInterp)::MetaData
meta[]
end
148 changes: 148 additions & 0 deletions src/auto_cf/code_info_mut.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
struct ShiftedSSA
e::Int
end

struct ShiftedCF
e::Int
end

"""
offset_stmt!(stmt, index, next_bb=true)

Recursively offsets SSA and control flow references in a statement `stmt` by `index`.
Used to update SSA and control flow when inserting instructions.
"""
function offset_stmt!(stmt, index, next_bb = true)
if stmt isa Expr
Expr(
stmt.head, (offset_stmt!(a, index) for a in stmt.args)...)
elseif stmt isa Core.ReturnNode
Core.ReturnNode(offset_stmt!(stmt.val, index))
elseif stmt isa Core.SSAValue
Core.SSAValue(offset_stmt!(ShiftedSSA(stmt.id), index))
elseif stmt isa Core.GotoIfNot
Core.GotoIfNot(offset_stmt!(stmt.cond, index), offset_stmt!(ShiftedCF(stmt.dest), index, next_bb))
elseif stmt isa Core.GotoNode
Core.GotoNode(offset_stmt!(ShiftedCF(stmt.label), index, next_bb))
elseif stmt isa ShiftedSSA
stmt.e + (stmt.e < index ? 0 : 1)
elseif stmt isa ShiftedCF
stmt.e + (stmt.e < index + (next_bb ? 1 : 0) ? 0 : 1)
else
stmt
end
end

"""
add_instruction!(frame, index, stmt; type=CC.NotFound(), next_bb=true)

Insert `stmt` into `frame` after `index`, updating frame field accordingly
Returns the new SSAValue for the inserted instruction.
"""
function add_instruction!(frame, index, stmt; type=CC.NotFound(), next_bb = true)
add_instruction!(frame.src, index, stmt; type, next_bb)
frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: more fine graine change here
insert!(frame.stmt_info, index + 1, CC.NoCallInfo())
insert!(frame.stmt_edges, index + 1, nothing)
insert!(frame.handler_at, index + 1, (0,0))
frame.cfg = CC.compute_basic_blocks(frame.src.code)
Core.SSAValue(index + 1)
end

"""
modify_instruction!(frame, index, stmt)

Modify the instruction at `index` in `frame` to `stmt` and update `frame` SSA value uses.
"""
function modify_instruction!(frame, index, stmt)
frame.src.code[index] = stmt
frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: refine this
end

"""
add_instruction!(ir::CC.CodeInfo, index, stmt; type=CC.NotFound(), next_bb=true)

Insert `stmt` into `ir.code` after `index`, offsetting SSA and control flow in all instructions after the insertion point.
"""
function add_instruction!(ir::CC.CodeInfo, index, stmt; type=CC.NotFound(), next_bb=true)
for (i, c) in enumerate(ir.code)
ir.code[i] = offset_stmt!(c, index + 1, next_bb)
end
insert!(ir.code, index + 1, stmt)
insert!(ir.codelocs, index + 1, 0)
insert!(ir.ssaflags, index + 1, 0x00000000)
if ir.ssavaluetypes isa Int
ir.ssavaluetypes = ir.ssavaluetypes + 1
else
insert!(ir.ssavaluetypes, index + 1, type)
end
end

"""
create_slot!(ir::CC.CodeInfo)::Core.SlotNumber

Create a new slot in `ir` and return its SlotNumber.
"""
function create_slot!(ir::CC.CodeInfo)::Core.SlotNumber
push!(ir.slotflags, 0x00)
push!(ir.slotnames, Symbol(""))
Core.SlotNumber(length(ir.slotflags))
end

"""
create_slot!(frame)::Core.SlotNumber

Create a new slot in `frame` and return its SlotNumber.
"""
function create_slot!(frame)::Core.SlotNumber
push!(frame.slottypes, Union{})
for s in frame.bb_vartables
isnothing(s) && continue
push!(s, CC.VarState(Union{}, true))
end
create_slot!(frame.src)
end

"""
add_slot_change!(ir::CC.CodeInfo, index, old_slot::Int)
add_slot_change!(ir::CC.CodeInfo, index, old_slot::Core.SlotNumber)

Add a slot change at `index` in `ir`, upgrading `old_slot` to a new slot and inserting the upgrade instruction.
"""
add_slot_change!(ir::CC.CodeInfo, index, old_slot::Int) = add_slot_change!(ir, index, Core.SlotNumber(old_slot))

function add_slot_change!(ir::CC.CodeInfo, index, old_slot::Core.SlotNumber)
push!(ir.slotflags, 0x00)
push!(ir.slotnames, Symbol(""))
new_slot = Core.SlotNumber(length(ir.slotflags))
add_instruction!(frame, index, Expr(:(=), new_slot, Expr(:call, GlobalRef(@__MODULE__, :upgrade), old_slot)))
update_ir_new_slot(ir, index, old_slot, new_slot)
end

"""
update_ir_new_slot(ir, index, old_slot, new_slot)

Replace all occurrences of `old_slot` with `new_slot` in `ir.code` after `index`.
"""
function update_ir_new_slot(ir, index, old_slot, new_slot)
for i in index+2:length(ir.code) #TODO: probably need to refine this
ir.code[i] = replace_slot_stmt(ir.code[i], old_slot, new_slot)
end
end

"""
replace_slot_stmt(stmt, old_slot, new_slot)

Recursively replace `old_slot` with `new_slot` in a statement `stmt`.
"""
function replace_slot_stmt(stmt, old_slot, new_slot)
if stmt isa Core.NewvarNode
stmt
elseif stmt isa Expr
Expr(stmt.head, (replace_slot_stmt(e, old_slot, new_slot) for e in stmt.args)...)
elseif stmt isa Core.SlotNumber
stmt == old_slot ? new_slot : stmt
else
stmt
end
end
Loading
Loading