Skip to content

Commit 5292abf

Browse files
committed
fix masked_mean() error expected x and y to be on same device
arises due to torch.tensor(float("nan")) defaulting to CPU
1 parent b72d328 commit 5292abf

File tree

3 files changed

+4
-11
lines changed

3 files changed

+4
-11
lines changed

aviary/core.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -608,14 +608,11 @@ def masked_mean(x: torch.Tensor, mask: torch.BoolTensor, dim: int = 0) -> torch.
608608
x (torch.Tensor): Tensor to compute standard deviation of.
609609
mask (torch.BoolTensor): Same shape as x with True where x is valid and False
610610
where x should be masked. Mask should not be all False in any column of
611-
dimension dim to avoid NaNs.
611+
dimension dim to avoid NaNs from zero division.
612612
dim (int, optional): Dimension to take mean of. Defaults to 0.
613613
614614
Returns:
615615
torch.Tensor: Same shape as x, except dimension dim reduced.
616616
"""
617-
# mask should be True where x is valid and False where x should be masked
618-
x_nan = torch.where(mask, x, torch.tensor(float("nan")))
619-
# torch.tensor(float("nan")) can be simplified to torch.nan in torch>=1.12
620-
617+
x_nan = x.masked_fill(~mask, float("nan"))
621618
return x_nan.nanmean(dim=dim)

examples/mat_bench/slurm_submit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@
4343
print(f"{{job_id=}}")
4444
print("{model_name=}")
4545
46-
job_array_id = os.environ.get("SLURM_ARRAY_TASK_ID")
47-
if job_array_id is not None:
48-
job_array_id = int(job_array_id)
46+
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID"), 0)
4947
print(f"{{job_array_id=}}")
5048
5149
dataset_name, fold = list(product({datasets}, {folds}))[job_array_id]

examples/mp_wbm/slurm_submit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@
3939
print("{model_name=}")
4040
print("{data_path=}")
4141
42-
job_array_id = os.environ.get("SLURM_ARRAY_TASK_ID")
43-
if job_array_id is not None:
44-
job_array_id = int(job_array_id)
42+
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID"), 0)
4543
print(f"{{job_array_id=}}")
4644
4745
run_wrenformer_on_mp_wbm(

0 commit comments

Comments
 (0)