Commit cde5e26

added small N PL estimator
1 parent 02bfaf5 commit cde5e26

File tree

8 files changed: +1726 −6 lines

Changelog.txt

Lines changed: 3 additions & 0 deletions

@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.7.6]
+Added small_N power law estimator
+
 ## [0.7.5.5]
 Fixed bug in reading large safetensors files
 

README.md

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ It can be used to:
 
 And in the notebooks provided in the [examples](https://github.com/CalculatedContent/WeightWatcher/tree/master/examples) directory
 
-## Installation: Version 0.7.5.5
+## Installation: Version 0.7.6
 
 ```sh
 pip install weightwatcher
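
For orientation, a minimal usage sketch after installing. The toy `nn.Sequential` model and its layer sizes are illustrative assumptions, not part of this commit; `WeightWatcher(model=...)` and `analyze()` are the library's standard entry points:

```python
import torch.nn as nn
import weightwatcher as ww

# Toy model for illustration only; any supported torch/keras model works
model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))

watcher = ww.WeightWatcher(model=model)
details = watcher.analyze()  # per-layer metrics, including power-law alpha
print(details)
```

With the MIN_EVALS defaults lowered in this commit, smaller layers reach the spectral analysis, and layers whose eigenvalue count falls below the new small-N cutoff are fit by the small-N estimator described below.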

examples/MLP3-MNIST-AdamW.ipynb

Lines changed: 725 additions & 0 deletions
Large diffs are not rendered by default.

examples/MLP3-MNIST-Muon.ipynb

Lines changed: 808 additions & 0 deletions
Large diffs are not rendered by default.

tests/test.py

Lines changed: 74 additions & 2 deletions

@@ -4539,8 +4539,7 @@ def test_compute_alphas(self):
         self.assertAlmostEqual(a[1],1.66595, places=3)
         self.assertAlmostEqual(a[3],1.43459, places=3)
 
-
-
+
         #
         # TODO: check if xmax='force' does anything ?
         #
@@ -5594,6 +5593,79 @@ def test_smooth_W_numpy_singular_values(self):
 
 
 
+from weightwatcher.WW_powerlaw import WWFit
+
+
+def sample_pareto(alpha: float, xmin: float, n: int, seed: int = 123) -> np.ndarray:
+    """
+    Sample n values from a continuous Pareto distribution with exponent alpha
+    and minimum value xmin, using inverse-CDF sampling.
+    """
+    rng = np.random.default_rng(seed)
+    u = rng.random(n)
+    # Pareto(xmin, alpha): X = xmin * U^(-1 / (alpha - 1))
+    return xmin * (u ** (-1.0 / (alpha - 1.0)))
+
+
+class TestSmallNPowerLaw(unittest.TestCase):
+    """
+    Unit tests for the small-N power-law fitter (fit_powerlaw_smallN).
+    """
+
+    def test_smalln_path_is_used(self):
+        """
+        Ensure that for small N, WWFit routes through fit_powerlaw_smallN.
+        """
+        # small-N sample, should be below SMALL_N_CUTOFF in WWFit
+        data = sample_pareto(alpha=2.2, xmin=1.0, n=10, seed=42)
+
+        called = {"hit": False}
+
+        # Monkeypatch fit_powerlaw_smallN to detect that it was called
+        original_smallN = WWFit.fit_powerlaw_smallN
+
+        def wrapped_smallN(self, *args, **kwargs):
+            called["hit"] = True
+            return original_smallN(self, *args, **kwargs)
+
+        WWFit.fit_powerlaw_smallN = wrapped_smallN
+        try:
+            _ = WWFit(data, distribution="power_law")
+        finally:
+            # Always restore the original method
+            WWFit.fit_powerlaw_smallN = original_smallN
+
+        self.assertTrue(called["hit"], "fit_powerlaw_smallN was not called for small N")
+
+    def test_smalln_alpha_reasonable_on_pareto(self):
+        """
+        For a small-N Pareto sample with known alpha, the fitted alpha
+        should be in the right ballpark (within a loose tolerance).
+        """
+        true_alpha = 2.2
+        xmin = 1.0
+        n = 10  # small-N regime
+
+        data = sample_pareto(true_alpha, xmin, n, seed=123)
+        fit = WWFit(data, distribution="power_law")  # will use small-N path
+
+        est_alpha = fit.alpha
+        est_xmin = fit.xmin
+
+        print(f" estimated small N alpha {est_alpha:0.2f}")
+
+        # Sanity checks
+        self.assertTrue(np.isfinite(est_alpha), "Estimated alpha is not finite")
+        self.assertGreater(est_alpha, 1.0, "Estimated alpha must be > 1 for a valid power law")
+        self.assertTrue(np.isfinite(est_xmin), "Estimated xmin is not finite")
+
+        # Loose accuracy check: small N is noisy, so don't demand perfection
+        self.assertLess(
+            abs(est_alpha - true_alpha),
+            0.4,
+            f"Estimated alpha {est_alpha:.3f} too far from true alpha {true_alpha:.3f}",
+        )
 
 
 # TODO
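
For reference, the inverse-CDF step in `sample_pareto` follows directly from the continuous Pareto CDF used throughout this commit (a standard identity, not anything specific to this code):

```latex
F(x) = 1 - \left(\frac{x}{x_{\min}}\right)^{1-\alpha},
\qquad x \ge x_{\min},\ \alpha > 1 .
```

Setting F(X) = U with U ~ Uniform(0, 1) and solving gives X = x_min (1 − U)^(−1/(α−1)); since 1 − U is itself uniform on (0, 1), the helper draws X = x_min · U^(−1/(α−1)) directly, which is exactly the formula in its comment.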

weightwatcher/WW_powerlaw.py

Lines changed: 112 additions & 0 deletions

@@ -23,6 +23,7 @@
     'lognormal_positive': powerlaw.Lognormal_Positive,
 }
 
+SMALL_N_CUTOFF = 20
 
 import logging
 logger = logging.getLogger(WW_NAME)
@@ -65,6 +66,15 @@ def __str__(self):
         return f"WWFit({self.distribution} xmin: {self.xmin:0.04f}, alpha: {self.alpha:0.04f}, sigma: {self.sigma:0.04f}, data: {len(self.data)})"
 
     def fit_power_law(self):
+        if self.N < SMALL_N_CUTOFF:
+            print("SMALL N PL FIT")
+            logger.info("SMALL N PL FIT")
+            self.fit_powerlaw_smallN()
+            return
+
+        return self.fit_power_law_standard()
+
+    def fit_power_law_standard(self):
         log_data = np. log(self.data, dtype=np.float64)
         self.alphas = np.zeros(self.N-1, dtype=np.float64)
         self.Ds = np. ones(self.N-1, dtype=np.float64)
@@ -80,6 +90,108 @@ def fit_power_law(self):
         ))
 
         self.sigmas = (self.alphas - 1) / np.sqrt(self.N - np.arange(self.N-1))
+
+
+    def fit_powerlaw_smallN(self, k_min: int = 8, lambda_prior: float = 0.0):
+        """
+        Small-N continuous power-law fit:
+
+        - Bias-corrected MLE: alpha_bc = 1 + (n - 1) / sum_j log(x_j / xmin)
+        - Objective for xmin selection:
+              J = D_ks - 0.868 / sqrt(n_tail) + lambda_prior * prior_pen
+          where prior_pen = (alpha_bc - 2)^2 (ultra-local prior, off if lambda_prior = 0)
+
+        No trace-log gate, no eigenvalue rescaling, no lock-to-2.
+        """
+        log_data = np.log(self.data, dtype=np.float64)
+
+        # Arrays analogous to those in fit_power_law_standard
+        self.alphas = np.zeros(self.N - 1, dtype=np.float64)
+        self.Ds = np.ones(self.N - 1, dtype=np.float64)
+        # Objective values (for internal selection)
+        self.Js = np.full(self.N - 1, np.inf, dtype=np.float64)
+
+        for i, xmin in enumerate(self.data[:-1]):
+            n_int = self.N - i  # tail size as int
+            if n_int < k_min:
+                continue
+            n = float(n_int)
+
+            # sum_j log(x_j / xmin) for j >= i
+            s = np.sum(log_data[i:]) - n * log_data[i]
+            if s <= 1e-12:
+                # degenerate tail; skip
+                continue
+
+            # --- bias-corrected MLE (n-1 correction) ---
+            alpha_bc = 1.0 + (n - 1.0) / s
+            self.alphas[i] = alpha_bc
+
+            if alpha_bc <= 1.0:
+                # invalid exponent for a continuous power law; skip
+                continue
+
+            # Tail data for this xmin
+            tail = self.data[i:]
+
+            # Theoretical CDF for a continuous power law on [xmin, inf):
+            #   F_fit(x) = 1 - (x/xmin)^(1 - alpha), x >= xmin
+            F_fit = 1.0 - (tail / xmin) ** (1.0 - alpha_bc)
+
+            # Empirical CDF: 0, 1/n, ..., (n-1)/n (matches the standard fit's convention)
+            F_emp = np.arange(n_int, dtype=np.float64) / n
+            Dks = float(np.max(np.abs(F_emp - F_fit)))
+            self.Ds[i] = Dks
+
+            # --- Objective: KS distance with a tail-size encouragement term ---
+            prior_pen = (alpha_bc - 2.0) ** 2  # ultra-local prior (active if lambda_prior > 0)
+            J = Dks - 0.868 / np.sqrt(n) + lambda_prior * prior_pen
+            self.Js[i] = J
+
+        # Sigma as in the standard fit (for reporting)
+        self.sigmas = (self.alphas - 1.0) / np.sqrt(self.N - np.arange(self.N - 1))
+
+        # ----- Choose the best xmin by J; no fallback to fit_power_law_standard -----
+        if np.isfinite(self.Js).any():
+            j_best = int(np.nanargmin(self.Js))
+        else:
+            # If k_min was too strict and no candidate survived, use all data as the tail (i = 0)
+            j_best = 0
+            xmin = self.data[0]
+            n_int = self.N
+            n = float(n_int)
+            s = np.sum(log_data) - n * log_data[0]
+            if s <= 1e-12:
+                # pathological case; keep trivial defaults
+                self.xmin = xmin
+                self.alpha = 1.0
+                self.sigma = 0.0
+                self.D = 1.0
+                self.data = self.data[self.data >= self.xmin]
+                return
+
+            alpha_bc = 1.0 + (n - 1.0) / s
+            self.alphas[j_best] = alpha_bc
+
+            tail = self.data
+            F_fit = 1.0 - (tail / xmin) ** (1.0 - alpha_bc)
+            F_emp = np.arange(n_int, dtype=np.float64) / n
+            Dks = float(np.max(np.abs(F_emp - F_fit)))
+            self.Ds[j_best] = Dks
+
+            prior_pen = (alpha_bc - 2.0) ** 2
+            self.Js[j_best] = Dks - 0.868 / np.sqrt(n) + lambda_prior * prior_pen
+
+        # Commit the winner (mirrors what __init__ does after fit_power_law)
+        self.xmin = self.data[j_best]
+        self.alpha = self.alphas[j_best]
+        self.sigma = self.sigmas[j_best]
+        self.D = self.Ds[j_best]
+
+        # Match powerlaw package behavior: restrict data to data >= xmin
+        self.data = self.data[self.data >= self.xmin]
 
     def __getattr__(self, item):
         """ Needed for replicating the behavior of the powerlaw.Fit class"""
weightwatcher/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 
 
 __name__ = "weightwatcher"
-__version__ = "0.7.5.5"
+__version__ = "0.7.6"
 __license__ = "Apache License, Version 2.0"
 __description__ = "Diagnostic Tool for Deep Neural Networks"
 __url__ = "https://calculationconsulting.com/"

weightwatcher/constants.py

Lines changed: 2 additions & 2 deletions

@@ -141,8 +141,8 @@
 
 
 MIN_EVALS = 'min_evals'
-DEFAULT_MIN_EVALS = 10
-MIN_NUM_EVALS = 10
+DEFAULT_MIN_EVALS = 8
+MIN_NUM_EVALS = 8
 
 MAX_EVALS = 'max_evals'
 DEFAULT_MAX_EVALS = 15000
