|
| 1 | +// Copyright 2023 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package inlheur |
| 6 | + |
| 7 | +import ( |
| 8 | + "cmd/compile/internal/ir" |
| 9 | + "fmt" |
| 10 | + "os" |
| 11 | +) |
| 12 | + |
| 13 | +// paramsAnalyzer holds state information for the phase that computes |
| 14 | +// flags for a Go functions parameters, for use in inline heuristics. |
| 15 | +// Note that the params slice below includes entries for blanks. |
| 16 | +type paramsAnalyzer struct { |
| 17 | + fname string |
| 18 | + values []ParamPropBits |
| 19 | + params []*ir.Name |
| 20 | + top []bool |
| 21 | + *condLevelTracker |
| 22 | +} |
| 23 | + |
| 24 | +// dclParams returns a slice containing the non-blank, named params |
| 25 | +// for the specific function (plus rcvr as well if applicable) in |
| 26 | +// declaration order. |
| 27 | +func dclParams(fn *ir.Func) []*ir.Name { |
| 28 | + params := []*ir.Name{} |
| 29 | + for _, n := range fn.Dcl { |
| 30 | + if n.Op() != ir.ONAME { |
| 31 | + continue |
| 32 | + } |
| 33 | + if n.Class != ir.PPARAM { |
| 34 | + continue |
| 35 | + } |
| 36 | + params = append(params, n) |
| 37 | + } |
| 38 | + return params |
| 39 | +} |
| 40 | + |
| 41 | +// getParams returns an *ir.Name slice containing all params for the |
| 42 | +// function (plus rcvr as well if applicable). Note that this slice |
| 43 | +// includes entries for blanks; entries in the returned slice corresponding |
| 44 | +// to blanks or unnamed params will be nil. |
| 45 | +func getParams(fn *ir.Func) []*ir.Name { |
| 46 | + dclparms := dclParams(fn) |
| 47 | + dclidx := 0 |
| 48 | + recvrParms := fn.Type().RecvParams() |
| 49 | + params := make([]*ir.Name, len(recvrParms)) |
| 50 | + for i := range recvrParms { |
| 51 | + var v *ir.Name |
| 52 | + if recvrParms[i].Sym != nil && |
| 53 | + !recvrParms[i].Sym.IsBlank() { |
| 54 | + v = dclparms[dclidx] |
| 55 | + dclidx++ |
| 56 | + } |
| 57 | + params[i] = v |
| 58 | + } |
| 59 | + return params |
| 60 | +} |
| 61 | + |
| 62 | +func makeParamsAnalyzer(fn *ir.Func) *paramsAnalyzer { |
| 63 | + params := getParams(fn) // includes receiver if applicable |
| 64 | + vals := make([]ParamPropBits, len(params)) |
| 65 | + top := make([]bool, len(params)) |
| 66 | + for i, pn := range params { |
| 67 | + if pn == nil { |
| 68 | + continue |
| 69 | + } |
| 70 | + pt := pn.Type() |
| 71 | + if !pt.IsScalar() && !pt.HasNil() { |
| 72 | + // existing properties not applicable here (for things |
| 73 | + // like structs, arrays, slices, etc). |
| 74 | + continue |
| 75 | + } |
| 76 | + // If param is reassigned, skip it. |
| 77 | + if ir.Reassigned(pn) { |
| 78 | + continue |
| 79 | + } |
| 80 | + top[i] = true |
| 81 | + } |
| 82 | + |
| 83 | + if debugTrace&debugTraceParams != 0 { |
| 84 | + fmt.Fprintf(os.Stderr, "=-= param analysis of func %v:\n", |
| 85 | + fn.Sym().Name) |
| 86 | + for i := range vals { |
| 87 | + n := "_" |
| 88 | + if params[i] != nil { |
| 89 | + n = params[i].Sym().String() |
| 90 | + } |
| 91 | + fmt.Fprintf(os.Stderr, "=-= %d: %q %s\n", |
| 92 | + i, n, vals[i].String()) |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + return ¶msAnalyzer{ |
| 97 | + fname: fn.Sym().Name, |
| 98 | + values: vals, |
| 99 | + params: params, |
| 100 | + top: top, |
| 101 | + condLevelTracker: new(condLevelTracker), |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +func (pa *paramsAnalyzer) setResults(fp *FuncProps) { |
| 106 | + fp.ParamFlags = pa.values |
| 107 | +} |
| 108 | + |
| 109 | +// paramsAnalyzer invokes function 'testf' on the specified expression |
| 110 | +// 'x' for each parameter, and if the result is TRUE, or's 'flag' into |
| 111 | +// the flags for that param. |
| 112 | +func (pa *paramsAnalyzer) checkParams(x ir.Node, flag ParamPropBits, mayflag ParamPropBits, testf func(x ir.Node, param *ir.Name) bool) { |
| 113 | + for idx, p := range pa.params { |
| 114 | + if !pa.top[idx] && pa.values[idx] == ParamNoInfo { |
| 115 | + continue |
| 116 | + } |
| 117 | + result := testf(x, p) |
| 118 | + if debugTrace&debugTraceParams != 0 { |
| 119 | + fmt.Fprintf(os.Stderr, "=-= test expr %v param %s result=%v flag=%s\n", x, p.Sym().Name, result, flag.String()) |
| 120 | + } |
| 121 | + if result { |
| 122 | + v := flag |
| 123 | + if pa.condLevel != 0 { |
| 124 | + v = mayflag |
| 125 | + } |
| 126 | + pa.values[idx] |= v |
| 127 | + pa.top[idx] = false |
| 128 | + } |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +// foldCheckParams checks expression 'x' (an 'if' condition or |
| 133 | +// 'switch' stmt expr) to see if the expr would fold away if a |
| 134 | +// specific parameter had a constant value. |
| 135 | +func (pa *paramsAnalyzer) foldCheckParams(x ir.Node) { |
| 136 | + pa.checkParams(x, ParamFeedsIfOrSwitch, ParamMayFeedIfOrSwitch, |
| 137 | + func(x ir.Node, p *ir.Name) bool { |
| 138 | + return ShouldFoldIfNameConstant(x, []*ir.Name{p}) |
| 139 | + }) |
| 140 | +} |
| 141 | + |
| 142 | +// callCheckParams examines the target of call expression 'ce' to see |
| 143 | +// if it is making a call to the value passed in for some parameter. |
| 144 | +func (pa *paramsAnalyzer) callCheckParams(ce *ir.CallExpr) { |
| 145 | + switch ce.Op() { |
| 146 | + case ir.OCALLINTER: |
| 147 | + if ce.Op() != ir.OCALLINTER { |
| 148 | + return |
| 149 | + } |
| 150 | + sel := ce.X.(*ir.SelectorExpr) |
| 151 | + r := ir.StaticValue(sel.X) |
| 152 | + if r.Op() != ir.ONAME { |
| 153 | + return |
| 154 | + } |
| 155 | + name := r.(*ir.Name) |
| 156 | + if name.Class != ir.PPARAM { |
| 157 | + return |
| 158 | + } |
| 159 | + pa.checkParams(r, ParamFeedsInterfaceMethodCall, |
| 160 | + ParamMayFeedInterfaceMethodCall, |
| 161 | + func(x ir.Node, p *ir.Name) bool { |
| 162 | + name := x.(*ir.Name) |
| 163 | + return name == p |
| 164 | + }) |
| 165 | + case ir.OCALLFUNC: |
| 166 | + if ce.X.Op() != ir.ONAME { |
| 167 | + return |
| 168 | + } |
| 169 | + called := ir.StaticValue(ce.X) |
| 170 | + if called.Op() != ir.ONAME { |
| 171 | + return |
| 172 | + } |
| 173 | + name := called.(*ir.Name) |
| 174 | + if name.Class != ir.PPARAM { |
| 175 | + return |
| 176 | + } |
| 177 | + pa.checkParams(called, ParamFeedsIndirectCall, |
| 178 | + ParamMayFeedIndirectCall, |
| 179 | + func(x ir.Node, p *ir.Name) bool { |
| 180 | + name := x.(*ir.Name) |
| 181 | + return name == p |
| 182 | + }) |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +func (pa *paramsAnalyzer) nodeVisitPost(n ir.Node) { |
| 187 | + if len(pa.values) == 0 { |
| 188 | + return |
| 189 | + } |
| 190 | + pa.condLevelTracker.post(n) |
| 191 | + switch n.Op() { |
| 192 | + case ir.OCALLFUNC: |
| 193 | + ce := n.(*ir.CallExpr) |
| 194 | + pa.callCheckParams(ce) |
| 195 | + case ir.OCALLINTER: |
| 196 | + ce := n.(*ir.CallExpr) |
| 197 | + pa.callCheckParams(ce) |
| 198 | + case ir.OIF: |
| 199 | + ifst := n.(*ir.IfStmt) |
| 200 | + pa.foldCheckParams(ifst.Cond) |
| 201 | + case ir.OSWITCH: |
| 202 | + swst := n.(*ir.SwitchStmt) |
| 203 | + if swst.Tag != nil { |
| 204 | + pa.foldCheckParams(swst.Tag) |
| 205 | + } |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +func (pa *paramsAnalyzer) nodeVisitPre(n ir.Node) { |
| 210 | + if len(pa.values) == 0 { |
| 211 | + return |
| 212 | + } |
| 213 | + pa.condLevelTracker.pre(n) |
| 214 | +} |
| 215 | + |
| 216 | +// condLevelTracker helps keeps track very roughly of "level of conditional |
| 217 | +// nesting", e.g. how many "if" statements you have to go through to |
| 218 | +// get to the point where a given stmt executes. Example: |
| 219 | +// |
| 220 | +// cond nesting level |
| 221 | +// func foo() { |
| 222 | +// G = 1 0 |
| 223 | +// if x < 10 { 0 |
| 224 | +// if y < 10 { 1 |
| 225 | +// G = 0 2 |
| 226 | +// } |
| 227 | +// } |
| 228 | +// } |
| 229 | +// |
| 230 | +// The intent here is to provide some sort of very abstract relative |
| 231 | +// hotness metric, e.g. "G = 1" above is expected to be executed more |
| 232 | +// often than "G = 0" (in the aggregate, across large numbers of |
| 233 | +// functions). |
| 234 | +type condLevelTracker struct { |
| 235 | + condLevel int |
| 236 | +} |
| 237 | + |
| 238 | +func (c *condLevelTracker) pre(n ir.Node) { |
| 239 | + // Increment level of "conditional testing" if we see |
| 240 | + // an "if" or switch statement, and decrement if in |
| 241 | + // a loop. |
| 242 | + switch n.Op() { |
| 243 | + case ir.OIF, ir.OSWITCH: |
| 244 | + c.condLevel++ |
| 245 | + case ir.OFOR, ir.ORANGE: |
| 246 | + c.condLevel-- |
| 247 | + } |
| 248 | +} |
| 249 | + |
| 250 | +func (c *condLevelTracker) post(n ir.Node) { |
| 251 | + switch n.Op() { |
| 252 | + case ir.OFOR, ir.ORANGE: |
| 253 | + c.condLevel++ |
| 254 | + case ir.OIF: |
| 255 | + c.condLevel-- |
| 256 | + case ir.OSWITCH: |
| 257 | + c.condLevel-- |
| 258 | + } |
| 259 | +} |
0 commit comments