Skip to content

Commit e5c15ea

Browse files
committed
Added new field detX_delta in apply_detX()
1 parent e870579 commit e5c15ea

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

weightwatcher/weightwatcher.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2918,6 +2918,14 @@ def apply_detX(self, ww_layer, params=None):
29182918

29192919
alpha = ww_layer.alpha
29202920

2921+
# Calculate the value of xmin - detX_val in normalized units.
2922+
detX_delta = None
2923+
if ww_layer.xmin:
2924+
xmin = ww_layer.xmin * np.max(evals)/ww_layer.xmax
2925+
detX_delta = xmin - detX_val
2926+
else:
2927+
xmin = None
2928+
29212929

29222930
if WW_PLOT_DETX in plot:
29232931
name = ww_layer.name
@@ -2932,7 +2940,6 @@ def apply_detX(self, ww_layer, params=None):
29322940
plt.axvline(np.log10(detX_val), color='purple', label=r"detX$=1$")
29332941

29342942
if ww_layer.xmin:
2935-
xmin = ww_layer.xmin * np.max(evals)/ww_layer.xmax
29362943
plt.axvline(np.log10(xmin), color='red', label=r"PL $\lambda_{min}$")
29372944

29382945
plt.legend()
@@ -2949,6 +2956,7 @@ def apply_detX(self, ww_layer, params=None):
29492956

29502957
ww_layer.add_column('detX_val', detX_val)
29512958
ww_layer.add_column('detX_val_unrescaled', detX_val_unrescaled)
2959+
ww_layer.add_column('detX_delta', detX_delta)
29522960

29532961
return ww_layer
29542962

0 commit comments

Comments
 (0)