
Commit 47e607a

fix RuntimeError: expected scalar type float but found double in core.py masked_(min|max)

raised by "min": lambda x, mask, dim: torch.where(mask, x, float("inf")).min(dim=dim)[0] (and the matching "max" lambda) when x is a double tensor
1 parent 7dbaf04 commit 47e607a
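
For context, a minimal reproduction of the error (a sketch assuming the PyTorch version this commit targets, where torch.where does not type-promote a Python float scalar against a float64 tensor; newer releases may promote silently):

import torch

x = torch.randn(2, 3, dtype=torch.float64)  # double tensor, e.g. fresh from numpy
mask = torch.tensor([[True, False, True], [True, True, True]])

# the scalar inf is wrapped as a float32 tensor, so combining it with the
# float64 input raises: RuntimeError: expected scalar type float but found double
torch.where(mask, x, float("inf")).min(dim=0)[0]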

File tree

3 files changed, +47 -24 lines


aviary/core.py

Lines changed: 34 additions & 12 deletions
@@ -12,7 +12,7 @@
 import torch
 import torch.nn as nn
 from sklearn.metrics import f1_score
-from torch import Tensor
+from torch import BoolTensor, Tensor
 from torch.nn.functional import softmax
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
@@ -577,22 +577,20 @@ def np_one_hot(targets: np.ndarray, n_classes: int = None) -> np.ndarray:
     return np.eye(n_classes)[targets]
 
 
-def masked_std(
-    x: torch.Tensor, mask: torch.BoolTensor, dim: int = 0, eps: float = 1e-12
-) -> torch.Tensor:
+def masked_std(x: Tensor, mask: BoolTensor, dim: int = 0, eps: float = 1e-12) -> Tensor:
     """Compute the standard deviation of a tensor, ignoring masked values.
 
     Args:
-        x (torch.Tensor): Tensor to compute standard deviation of.
-        mask (torch.BoolTensor): Same shape as x with True where x is valid and False
+        x (Tensor): Tensor to compute standard deviation of.
+        mask (BoolTensor): Same shape as x with True where x is valid and False
             where x should be masked. Mask should not be all False in any column of
             dimension dim to avoid NaNs.
         dim (int, optional): Dimension to take std of. Defaults to 0.
         eps (float, optional): Small positive number to ensure std is differentiable.
             Defaults to 1e-12.
 
     Returns:
-        torch.Tensor: Same shape as x, except dimension dim reduced.
+        Tensor: Same shape as x, except dimension dim reduced.
     """
     mean = masked_mean(x, mask, dim=dim)
     squared_diff = (x - mean.unsqueeze(dim=dim)) ** 2
@@ -601,18 +599,42 @@ def masked_std(
     return std
 
 
-def masked_mean(x: torch.Tensor, mask: torch.BoolTensor, dim: int = 0) -> torch.Tensor:
+def masked_mean(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
     """Compute the mean of a tensor, ignoring masked values.
 
     Args:
-        x (torch.Tensor): Tensor to compute standard deviation of.
-        mask (torch.BoolTensor): Same shape as x with True where x is valid and False
+        x (Tensor): Tensor to compute mean of.
+        mask (BoolTensor): Same shape as x with True where x is valid and False
             where x should be masked. Mask should not be all False in any column of
             dimension dim to avoid NaNs from zero division.
         dim (int, optional): Dimension to take mean of. Defaults to 0.
 
     Returns:
-        torch.Tensor: Same shape as x, except dimension dim reduced.
+        Tensor: Same shape as x, except dimension dim reduced.
     """
-    x_nan = x.masked_fill(~mask, float("nan"))
+    # for safety, we could add this assert but might impact performance
+    # assert (
+    #     mask.sum(dim=dim).ne(0).all()
+    # ), "mask should not be all False in any column, causes zero division"
+    x_nan = x.float().masked_fill(~mask, float("nan"))
     return x_nan.nanmean(dim=dim)
+
+
+def masked_max(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
+    """Compute the max of a tensor along dimension dim, ignoring values at indices where
+    mask is False. See masked_mean docstring for Args details.
+    """
+    # replace padded values with +/-inf to make sure min()/max() ignore them
+    x_inf = x.float().masked_fill(~mask, float("-inf"))
+    # 1st ret val = max, 2nd ret val = max indices
+    x_max, _ = x_inf.max(dim=dim)
+    return x_max
+
+
+def masked_min(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
+    """Compute the min of a tensor along dimension dim, ignoring values at indices where
+    mask is False. See masked_mean docstring for Args details.
+    """
+    x_inf = x.float().masked_fill(~mask, float("inf"))
+    x_min, _ = x_inf.min(dim=dim)
+    return x_min
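
A short usage sketch of the new helpers (values are illustrative). Note that the internal .float() call casts inputs to float32, which is what sidesteps the float/double mismatch, at the cost of always returning single precision:

import torch
from aviary.core import masked_max, masked_mean, masked_min

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64)
mask = torch.tensor([[True, True, False], [True, False, False]])

masked_mean(x, mask, dim=1)  # tensor([1.5000, 4.0000])
masked_min(x, mask, dim=1)  # tensor([1., 4.])
masked_max(x, mask, dim=1)  # tensor([2., 4.])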

aviary/wrenformer/model.py

Lines changed: 5 additions & 7 deletions
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from torch import BoolTensor, Tensor
 
-from aviary.core import BaseModelClass, masked_mean, masked_std
+from aviary.core import BaseModelClass, masked_max, masked_mean, masked_min, masked_std
 from aviary.networks import ResidualNetwork
 
 
@@ -142,13 +142,11 @@ def forward(  # type: ignore
         return tuple(output_nn(predictions) for output_nn in self.output_nns)
 
 
-# using all at once we call this S2M3 aggregation
+# map aggregation types to functions
 aggregators: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
     "mean": masked_mean,
-    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
     "std": masked_std,
-    # replace padded values with +/-inf to make sure min()/max() ignore them
-    "min": lambda x, mask, dim: torch.where(mask, x, float("inf")).min(dim=dim)[0],
-    # 1st ret val = max, 2nd ret val = max indices
-    "max": lambda x, mask, dim: torch.where(mask, x, float("-inf")).max(dim=dim)[0],
+    "max": masked_max,
+    "min": masked_min,
+    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
 }
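
A sketch of how the mapping can be applied, e.g. pooling per-token embeddings over the sequence dimension under a padding mask (shapes and variable names here are illustrative, not taken from the model code; the import assumes aggregators is defined at module level as the diff suggests):

import torch

from aviary.wrenformer.model import aggregators

embeddings = torch.randn(5, 2, 8)  # (seq_len, batch, d_model)
mask = torch.ones(5, 2, 8, dtype=torch.bool)
mask[3:, 0] = False  # first sequence in the batch has only 3 valid tokens

# concatenate the selected aggregations along the feature dimension
pooled = torch.cat(
    [aggregators[name](embeddings, mask, 0) for name in ("mean", "std", "min", "max")],
    dim=-1,
)  # -> (batch, 4 * d_model)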

examples/wrenformer.py

Lines changed: 8 additions & 5 deletions
@@ -166,12 +166,15 @@ def run_wrenformer(
     # the element type (usually 200-dim matscholar embeddings) and Wyckoff position (see
     # 'bra-alg-off.json') + 1 for the weight of that element/Wyckoff position in the
     # material's composition
-    n_features = features[0].shape[-1]
-    assert n_features in (200 + 1, 200 + 1 + 444)  # Roost and Wren embedding size resp.
+    embedding_len = features[0].shape[-1]
+    assert embedding_len in (
+        200 + 1,
+        200 + 1 + 444,
+    )  # Roost and Wren embedding size resp.
 
     model = Wrenformer(
         n_targets=[1 if task_type == reg_key else 2],
-        n_features=n_features,
+        n_features=embedding_len,
         task_dict={target_col: task_type},  # e.g. {'exfoliation_en': 'regression'}
         n_attn_layers=n_attn_layers,
         robust=robust,
@@ -201,14 +204,14 @@ def run_wrenformer(
         "target": target_col,
         "warmup_steps": warmup_steps,
         "robust": robust,
-        "n_features": n_features,  # embedding size
+        "embedding_len": embedding_len,
         "losses": str(loss_dict),
         "training_samples": len(train_df),
         "test_samples": len(test_df),
         "trainable_params": model.num_params,
         "swa_start": swa_start,
         "timestamp": timestamp,
-        "embedding_aggregations": embedding_aggregations,
+        "embedding_aggregations": ",".join(embedding_aggregations),
         **(run_params or {}),
     }
 
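Joining the aggregation names into a flat string keeps the logged run param a primitive value rather than a Python list (the motivation is an assumption; the commit does not state it). Assuming a hypothetical default:

embedding_aggregations = ("mean", "std", "min", "max")  # hypothetical value
",".join(embedding_aggregations)  # -> 'mean,std,min,max'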
