diff --git a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift index ae4a34a852455..74b18fc5e8bfd 100644 --- a/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift @@ -105,9 +105,9 @@ import SILBridging private let verbose = false -private func log(_ message: @autoclosure () -> String) { +private func log(prefix: Bool = true, _ message: @autoclosure () -> String) { if verbose { - print("### \(message())") + debugLog(prefix: prefix, message()) } } @@ -128,47 +128,48 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special } var remainingSpecializationRounds = 5 - var callerModified = false repeat { + // TODO: Names here are pretty misleading. We are looking for a place where + // the pullback closure is created (so for `partial_apply` instruction). var callSites = gatherCallSites(in: function, context) + guard !callSites.isEmpty else { + return + } - if !callSites.isEmpty { - for callSite in callSites { - var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context) - - if !alreadyExists { - context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee) - } + for callSite in callSites { + var (specializedFunction, alreadyExists) = getOrCreateSpecializedFunction(basedOn: callSite, context) - rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context) + if !alreadyExists { + context.notifyNewFunction(function: specializedFunction, derivedFrom: callSite.applyCallee) } - var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in - callSite.closureArgDescriptors - .map { $0.closure } - .forEach { deadClosures.pushIfNotVisited($0) } - } + rewriteApplyInstruction(using: specializedFunction, callSite: callSite, context) + } - defer { - deadClosures.deinitialize() - } + var deadClosures: InstructionWorklist = callSites.reduce(into: InstructionWorklist(context)) { deadClosures, callSite in + callSite.closureArgDescriptors + .map { $0.closure } + .forEach { deadClosures.pushIfNotVisited($0) } + } - while let deadClosure = deadClosures.pop() { - let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction) - if isDeleted { - context.notifyInvalidatedStackNesting() - } - } + defer { + deadClosures.deinitialize() + } - if context.needFixStackNesting { - function.fixStackNesting(context) + while let deadClosure = deadClosures.pop() { + let isDeleted = context.tryDeleteDeadClosure(closure: deadClosure as! SingleValueInstruction) + if isDeleted { + context.notifyInvalidatedStackNesting() } } - callerModified = callSites.count > 0 + if context.needFixStackNesting { + function.fixStackNesting(context) + } + remainingSpecializationRounds -= 1 - } while callerModified && remainingSpecializationRounds > 0 + } while remainingSpecializationRounds > 0 } // =========== Top-level functions ========== // @@ -503,12 +504,6 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap: continue } - // Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78847 - // TODO: remove this if-statement once the underlying problem is fixed. - if callee.hasOwnership { - continue - } - if callee.isDefinedExternally { continue } @@ -779,13 +774,13 @@ private extension SpecializationCloner { let clonedRootClosure = builder.cloneRootClosure(representedBy: closureArgDesc, capturedArguments: clonedClosureArgs) - let (finalClonedReabstractedClosure, releasableClonedReabstractedClosures) = + let finalClonedReabstractedClosure = builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure, reabstractedClosure: callSite.appliedArgForClosure(at: closureArgDesc.closureArgIndex)!, origToClonedValueMap: origToClonedValueMap, self.context) - let allClonedReleasableClosures = [clonedRootClosure] + releasableClonedReabstractedClosures + let allClonedReleasableClosures = [ finalClonedReabstractedClosure ]; return (finalClonedReabstractedClosure, allClonedReleasableClosures) } @@ -935,10 +930,9 @@ private extension Builder { func cloneRootClosureReabstractions(rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value, origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext) - -> (finalClonedReabstractedClosure: SingleValueInstruction, releasableClonedReabstractedClosures: [PartialApplyInst]) + -> SingleValueInstruction { func inner(_ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value, - _ releasableClonedReabstractedClosures: inout [PartialApplyInst], _ origToClonedValueMap: inout [HashableValue: Value]) -> Value { switch reabstractedClosure { case let reabstractedClosure where reabstractedClosure == rootClosure: @@ -947,7 +941,7 @@ private extension Builder { case let cvt as ConvertFunctionInst: let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction, - &releasableClonedReabstractedClosures, &origToClonedValueMap) + &origToClonedValueMap) let reabstracted = self.createConvertFunction(originalFunction: toBeReabstracted, resultType: cvt.type, withoutActuallyEscaping: cvt.withoutActuallyEscaping) origToClonedValueMap[cvt] = reabstracted @@ -955,7 +949,7 @@ private extension Builder { case let cvt as ConvertEscapeToNoEscapeInst: let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction, - &releasableClonedReabstractedClosures, &origToClonedValueMap) + &origToClonedValueMap) let reabstracted = self.createConvertEscapeToNoEscape(originalFunction: toBeReabstracted, resultType: cvt.type, isLifetimeGuaranteed: true) origToClonedValueMap[cvt] = reabstracted @@ -963,7 +957,7 @@ private extension Builder { case let pai as PartialApplyInst: let toBeReabstracted = inner(rootClosure, clonedRootClosure, pai.arguments[0], - &releasableClonedReabstractedClosures, &origToClonedValueMap) + &origToClonedValueMap) guard let function = pai.referencedFunction else { log("Parent function of callSite: \(rootClosure.parentFunction)") @@ -978,13 +972,11 @@ private extension Builder { calleeConvention: pai.calleeConvention, hasUnknownResultIsolation: pai.hasUnknownResultIsolation, isOnStack: pai.isOnStack) - releasableClonedReabstractedClosures.append(reabstracted) origToClonedValueMap[pai] = reabstracted return reabstracted case let mdi as MarkDependenceInst: - let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &releasableClonedReabstractedClosures, - &origToClonedValueMap) + let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &origToClonedValueMap) let base = origToClonedValueMap[mdi.base]! let reabstracted = self.createMarkDependence(value: toBeReabstracted, base: base, kind: .Escaping) origToClonedValueMap[mdi] = reabstracted @@ -998,11 +990,10 @@ private extension Builder { } } - var releasableClonedReabstractedClosures: [PartialApplyInst] = [] var origToClonedValueMap = origToClonedValueMap let finalClonedReabstractedClosure = inner(rootClosure, clonedRootClosure, reabstractedClosure, - &releasableClonedReabstractedClosures, &origToClonedValueMap) - return (finalClonedReabstractedClosure as! SingleValueInstruction, releasableClonedReabstractedClosures) + &origToClonedValueMap) + return (finalClonedReabstractedClosure as! SingleValueInstruction) } func destroyPartialApply(pai: PartialApplyInst, _ context: FunctionPassContext){ diff --git a/test/AutoDiff/SILOptimizer/BuildingSimulation.swift b/test/AutoDiff/SILOptimizer/BuildingSimulation.swift new file mode 100644 index 0000000000000..b0b0dfd2f0d88 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/BuildingSimulation.swift @@ -0,0 +1,228 @@ +// RUN: %target-swift-frontend -emit-sil -verify -O %s | %FileCheck %s +// REQUIRES: swift_in_compiler + +import _Differentiation + +// Simulation parameters +let trials = 100 +let timesteps = 20 +let dTime: Float = 0.1 + +// Definitions +let π = Float.pi + +struct SimParams: Differentiable { + var tube: TubeType = .init() + var slab: SlabType = .init() + var quanta: QuantaType = .init() + var tank: TankType = .init() + var startingTemp: Float +} + +struct TubeType: Differentiable { + var tubeSpacing: Float = 0.50292 // meters + var diameter: Float = 0.019 // m (3/4") + var thickness: Float = 0.001588 // m (1/16") + var resistivity: Float = 2.43 // (K/W)m +} + +struct SlabType: Differentiable { + var temp: Float = 21.1111111 // °C + var area: Float = 100.0 // m^2 + var Cp: Float = 0.2 + var density: Float = 2242.58 // kg/m^3 + var thickness: Float = 0.101 // m +} + +struct QuantaType: Differentiable { + var power: Float = 0.0 // Watt + var temp: Float = 60.0 // °C + var flow: Float = 0.0006309 // m^3/sec + var density: Float = 1000.0 // kg/m^3 + var Cp: Float = 4180.0 // ws/(kg • K) +} + +struct TankType: Differentiable { + var temp: Float = 70.0 + var volume: Float = 0.0757082 + var Cp: Float = 4180.000 + var density: Float = 1000.000 + var mass: Float = 75.708 +} + +// Computations + +@differentiable(reverse) +func computeResistance(floor: SlabType, tube: TubeType, quanta _: QuantaType) -> Float { + let geometry_coeff: Float = 10.0 + // let f_coff = 0.3333333 + + let tubingSurfaceArea = (floor.area / tube.tubeSpacing) * π * tube.diameter + let resistance_abs = tube.resistivity * tube.thickness / tubingSurfaceArea + + let resistance_corrected = resistance_abs * geometry_coeff // * (quanta.flow * f_coff) + + return resistance_corrected +} + +struct QuantaAndPower: Differentiable { + var quanta: QuantaType + var power: Float +} + + +extension Differentiable { + /// Applies the given closure to the derivative of `self`. + /// + /// Returns `self` like an identity function. When the return value is used in + /// a context where it is differentiated with respect to, applies the given + /// closure to the derivative of the return value. + @inlinable + @differentiable(reverse, wrt: self) + func withDerivative(_: @escaping (inout TangentVector) -> Void) -> Self { + return self + } + + @inlinable + @derivative(of: withDerivative) + func _vjpWithDerivative( + _ body: @escaping (inout TangentVector) -> Void + ) -> (value: Self, pullback: (TangentVector) -> TangentVector) { + return (self, { grad in + var grad = grad + body(&grad) + return grad + }) + } +} + +@differentiable(reverse) +func computeLoadPower(floor: SlabType, tube: TubeType, quanta: QuantaType) -> QuantaAndPower { + let resistance_abs = computeResistance(floor: floor, tube: tube, quanta: quanta) + + let conductance: Float = 1 / resistance_abs + let dTemp = floor.temp - quanta.temp + let power = dTemp * conductance + + var updatedQuanta = quanta + updatedQuanta.power = power + let loadPower = -power + + return QuantaAndPower(quanta: updatedQuanta, power: loadPower) +} + +@differentiable(reverse) +func updateQuanta(quanta: QuantaType) -> QuantaType { + let workingVolume = (quanta.flow * dTime) + let workingMass = (workingVolume * quanta.density) + let workingEnergy = quanta.power * dTime + let TempRise = workingEnergy / quanta.Cp / workingMass + var updatedQuanta = quanta + updatedQuanta.temp = quanta.temp + TempRise + + updatedQuanta.power = 0 + return updatedQuanta +} + +@differentiable(reverse) +func updateBuildingModel(power: Float, floor: SlabType) -> SlabType { + var updatedFloor = floor + + let floorVolume = floor.area * floor.thickness + let floorMass = floorVolume * floor.density + + updatedFloor.temp = floor.temp + ((power * dTime) / floor.Cp / floorMass) + return updatedFloor +} + +struct TankAndQuanta: Differentiable { + var tank: TankType + var quanta: QuantaType +} + +@differentiable(reverse) +func updateSourceTank(store: TankType, quanta: QuantaType) -> TankAndQuanta { + var updatedStore = store + var updatedQuanta = quanta + + let massPerTime = quanta.flow * quanta.density + let dTemp = store.temp - quanta.temp + let power = dTemp * massPerTime * quanta.Cp + + updatedQuanta.power = power + + let tankMass = store.volume * store.density + let TempRise = (power * dTime) / store.Cp / tankMass + updatedStore.temp = store.temp + TempRise + + return TankAndQuanta(tank: updatedStore, quanta: updatedQuanta) +} + +var simParams = SimParams(startingTemp: 33.3) + +@differentiable(reverse) +@inlinable public func absDifferentiable(_ value: Float) -> Float { + if value < 0 { + return -value + } + return value +} + +func lossCalc(pred: Float, gt: Float) -> Float { + let diff = pred - gt + return absDifferentiable(diff) +} + +// Simulations + +// Ensure things are properly specialized +// CHECK-LABEL: sil hidden @$s18BuildingSimulation8simulate9simParamsSfAA03SimE0V_tFTJrSpSr +// CHECK: function_ref specialized pullback of updateSourceTank(store:quanta:) +// CHECK-NOT: function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:) +// CHECK-NOT: function_ref pullback of updateQuanta(quanta:) +// CHECK: function_ref specialized pullback of updateQuanta(quanta:) +@differentiable(reverse) +func simulate(simParams: SimParams) -> Float { + let pexTube = simParams.tube + var slab = simParams.slab + var tank = simParams.tank + var quanta = simParams.quanta + + slab.temp = simParams.startingTemp + for _ in 0 ..< timesteps { + let tankAndQuanta = updateSourceTank(store: tank, quanta: quanta) + tank = tankAndQuanta.tank + quanta = tankAndQuanta.quanta + + quanta = updateQuanta(quanta: quanta) + + let quantaAndPower = computeLoadPower(floor: slab, tube: pexTube, quanta: quanta) + quanta = quantaAndPower.quanta + let powerToBuilding = quantaAndPower.power + quanta = updateQuanta(quanta: quanta) + + slab = updateBuildingModel(power: powerToBuilding, floor: slab) + } + return slab.temp +} + +var blackHole: Any? +@inline(never) +func dontLetTheCompilerOptimizeThisAway(_ x: T) { + blackHole = x +} + +@differentiable(reverse) +func fullPipe(simParams: SimParams) -> Float { + let pred = simulate(simParams: simParams) + let loss = lossCalc(pred: pred, gt: 27.344767) + return loss +} + +for _ in 0 ..< trials { + let forwardOnly = fullPipe(simParams: simParams) + dontLetTheCompilerOptimizeThisAway(forwardOnly) + + let grad = gradient(at: simParams, of: fullPipe) + dontLetTheCompilerOptimizeThisAway(grad) +}