11from __future__ import annotations
22
3+ from typing import Callable , Sequence
4+
35import torch
46import torch .nn as nn
57import torch .nn .functional as F
@@ -35,6 +37,7 @@ def __init__(
3537 trunk_hidden : list [int ] = [1024 , 512 ],
3638 out_hidden : list [int ] = [256 , 128 , 64 ],
3739 robust : bool = False ,
40+ embedding_aggregations : Sequence [str ] = ("mean" ,),
3841 ** kwargs ,
3942 ) -> None :
4043 """Initialize the Wrenformer model.
@@ -57,6 +60,9 @@ def __init__(
5760 target will be an estimate for the aleatoric uncertainty (uncertainty inherent to
5861 the sample) which can be used with a robust loss function to attenuate the weighting
5962 of uncertain samples.
63+ embedding_aggregations (list[str]): Aggregations to apply to the learned embedding
64+ returned by the transformer encoder before passing into the ResidualNetwork. One or
65+ more of ['mean', 'std', 'sum', 'min', 'max']. Defaults to ['mean'].
6066 """
6167 super ().__init__ (robust = robust , ** kwargs )
6268
@@ -73,9 +79,10 @@ def __init__(
7379 if self .robust :
7480 n_targets = [2 * n for n in n_targets ]
7581
76- n_aggregators = 2 # number of embedding aggregation functions
82+ self . embedding_aggregations = embedding_aggregations
7783 self .trunk_nn = ResidualNetwork (
78- input_dim = n_aggregators * d_model ,
84+ # len(embedding_aggregations) = number of catted tensors in aggregated_embeddings below
85+ input_dim = len (embedding_aggregations ) * d_model ,
7986 output_dim = out_hidden [0 ],
8087 hidden_layer_dims = trunk_hidden ,
8188 )
@@ -123,18 +130,25 @@ def forward( # type: ignore
123130 # into a single vector Wyckoff embedding
124131 # careful to ignore padded values when taking the mean
125132 inv_mask : torch .BoolTensor = ~ mask [..., None ]
126- # sum_agg = (embeddings * inv_mask).sum(dim=1)
127-
128- # # replace padded values with +/-inf to exclude them from min/max
129- # min_agg, _ = torch.where(inv_mask, embeddings, float("inf")).min(dim=1)
130- # max_agg, _ = torch.where(inv_mask, embeddings, float("-inf")).max(dim=1)
131- mean_agg = masked_mean (embeddings , inv_mask , dim = 1 )
132- std_agg = masked_std (embeddings , inv_mask , dim = 1 )
133133
134- # Sum+Std+Min+Max+Mean: we call this S2M3 aggregation
135- aggregated_embeddings = torch .cat ([mean_agg , std_agg ], dim = 1 )
134+ aggregation_funcs = [aggregators [key ] for key in self .embedding_aggregations ]
135+ aggregated_embeddings = torch .cat (
136+ [func (embeddings , inv_mask , 1 ) for func in aggregation_funcs ], dim = 1
137+ )
136138
137139 # main body of the feed-forward NN jointly used by all multitask objectives
138140 predictions = F .relu (self .trunk_nn (aggregated_embeddings ))
139141
140142 return tuple (output_nn (predictions ) for output_nn in self .output_nns )
143+
144+
# Using all five together is what we call S2M3 aggregation
# (Sum + Std + Min + Max + Mean).
def _masked_sum(x: Tensor, mask: BoolTensor, dim: int) -> Tensor:
    """Sum x over dim, zeroing out positions where mask is False (padding)."""
    return (x * mask).sum(dim=dim)


def _masked_min(x: Tensor, mask: BoolTensor, dim: int) -> Tensor:
    """Min of x over dim; padded positions become +inf so they never win."""
    filled = torch.where(mask, x, float("inf"))
    # min(dim=...) returns (values, indices); keep only the values
    return filled.min(dim=dim)[0]


def _masked_max(x: Tensor, mask: BoolTensor, dim: int) -> Tensor:
    """Max of x over dim; padded positions become -inf so they never win."""
    filled = torch.where(mask, x, float("-inf"))
    return filled.max(dim=dim)[0]


# Maps aggregation name -> callable(embeddings, inverse_padding_mask, dim).
aggregators: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
    "mean": masked_mean,
    "sum": _masked_sum,
    "std": masked_std,
    "min": _masked_min,
    "max": _masked_max,
}
0 commit comments