diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index cad7319eb..00e085f31 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -98,6 +98,58 @@ logger: logging.Logger = logging.getLogger(__name__) +RES_ENABLED_TABLES_STR = "res_enabled_tables" +RES_STORE_SHARDS_STR = "res_store_shards" +ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming" + + +def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]: + # populate res_params, which is used for raw embedding streaming + # here only populates the params available in fused_params and TBE configs + res_params: RESParams = RESParams() + fused_params = config.fused_params or {} + # read and clean up the fused_params that are not in the constructor + if RES_STORE_SHARDS_STR in fused_params: + res_params.res_store_shards = fused_params.get(RES_STORE_SHARDS_STR) + del fused_params[RES_STORE_SHARDS_STR] + res_enabled_tables: Optional[List[str]] = None + if RES_ENABLED_TABLES_STR in fused_params: + res_enabled_tables = ( + fused_params.get(RES_ENABLED_TABLES_STR).split(",") + if fused_params.get(RES_ENABLED_TABLES_STR) is not None + else None + ) + del fused_params[RES_ENABLED_TABLES_STR] + enable_raw_embedding_streaming: Optional[bool] = None + if ENABLE_RAW_EMBEDDING_STREAMING_STR in fused_params: + enable_raw_embedding_streaming = fused_params.get( + ENABLE_RAW_EMBEDDING_STREAMING_STR + ) + + if ( + enable_raw_embedding_streaming is None + or enable_raw_embedding_streaming is False + ): + return (False, res_params) + res_params.table_names = [table.name for table in config.embedding_tables] + if res_enabled_tables is not None and len(res_enabled_tables) != 0: + if len(set(res_enabled_tables) & set(res_params.table_names)) == 0: + logger.info( + f"No table is enabled for raw embedding streaming, " + f"raw embedding streaming is disabled, {res_enabled_tables=} {res_params.table_names=}" + ) + return (False, res_params) + res_params.table_offsets = [] + for emb_tbl in config.embedding_tables: + local_metadata = emb_tbl.local_metadata + if ( + local_metadata is not None + and local_metadata.shard_offsets is not None + and len(local_metadata.shard_offsets) >= 1 + ): + res_params.table_offsets.append(local_metadata.shard_offsets[0]) + return (enable_raw_embedding_streaming, res_params) + def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: """ @@ -186,22 +238,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ssd_tbe_params["cache_sets"] = int(max_cache_sets) ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables] - # populate res_params, which is used for raw embedding streaming - # here only populates the params available in fused_params and TBE configs - res_params: RESParams = RESParams() - res_params.table_names = [table.name for table in config.embedding_tables] - res_params.table_offsets = [] - for emb_tbl in config.embedding_tables: - local_metadata = emb_tbl.local_metadata - if ( - local_metadata is not None - and local_metadata.shard_offsets is not None - and len(local_metadata.shard_offsets) >= 1 - ): - res_params.table_offsets.append(local_metadata.shard_offsets[0]) - if "res_store_shards" in fused_params: - res_params.res_store_shards = fused_params.get("res_store_shards") + enable_res, res_params = _populate_res_params(config) ssd_tbe_params["res_params"] = res_params + ssd_tbe_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res return ssd_tbe_params @@ -2190,6 +2229,9 @@ def __init__( if "cache_precision" not in fused_params: fused_params["cache_precision"] = weights_precision + enable_res, res_params = _populate_res_params(config) + fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res + self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( SplitTableBatchedEmbeddingBagsCodegen( embedding_specs=list( @@ -2208,6 +2250,7 @@ def __init__( self._col_offset, ) ), + res_params=res_params, **fused_params, ) ) @@ -3041,6 +3084,10 @@ def __init__( fused_params["cache_precision"] = weights_precision if weights_precision == SparseType.NFP8: fused_params["cache_precision"] = SparseType.FP16 + + enable_res, res_params = _populate_res_params(config) + fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res + self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( SplitTableBatchedEmbeddingBagsCodegen( embedding_specs=list( @@ -3059,6 +3106,7 @@ def __init__( self._col_offset, ) ), + res_params=res_params, **fused_params, ) )