Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 63 additions & 15 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,58 @@

logger: logging.Logger = logging.getLogger(__name__)

# Keys that _populate_res_params looks up in a config's fused_params to
# configure raw embedding streaming (RES).
# Value is a comma-separated string of table names (it is split on ",").
RES_ENABLED_TABLES_STR = "res_enabled_tables"
# Value is copied verbatim into RESParams.res_store_shards.
RES_STORE_SHARDS_STR = "res_store_shards"
# Value is the on/off flag; None or False disables raw embedding streaming.
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"


def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
    """
    Build the ``RESParams`` used for raw embedding streaming (RES).

    Only the values available from ``config.fused_params`` and from the
    embedding table configs are populated here.

    Side effect: the ``res_store_shards`` and ``res_enabled_tables`` entries
    are removed from ``config.fused_params`` when present. They are not
    constructor arguments (presumably the remaining ``fused_params`` are later
    forwarded as keyword arguments to the TBE constructor -- confirm against
    the callers). The ``enable_raw_embedding_streaming`` entry is read but
    left in place.

    Args:
        config: grouped embedding config whose ``fused_params`` and
            ``embedding_tables`` are inspected.

    Returns:
        A ``(enabled, res_params)`` tuple:

        * streaming flag missing, ``None`` or ``False``: ``(False, res_params)``
          with at most ``res_store_shards`` populated;
        * a non-empty ``res_enabled_tables`` list that shares no name with the
          tables of ``config``: ``(False, res_params)`` with ``table_names``
          populated but ``table_offsets`` untouched;
        * otherwise: the flag value and ``res_params`` with ``table_names``
          and ``table_offsets`` populated.
    """
    res_params: RESParams = RESParams()
    fused_params = config.fused_params or {}

    # Read and clean up the fused_params that are not in the constructor.
    if RES_STORE_SHARDS_STR in fused_params:
        res_params.res_store_shards = fused_params.pop(RES_STORE_SHARDS_STR)

    res_enabled_tables: Optional[List[str]] = None
    raw_enabled_tables = fused_params.pop(RES_ENABLED_TABLES_STR, None)
    if raw_enabled_tables is not None:
        # Trim each entry so that "a, b" matches tables "a" and "b"; without
        # this, " b" never equals a table name and streaming is silently
        # disabled for it.
        res_enabled_tables = [
            name.strip() for name in raw_enabled_tables.split(",")
        ]

    enable_raw_embedding_streaming: Optional[bool] = fused_params.get(
        ENABLE_RAW_EMBEDDING_STREAMING_STR
    )
    if (
        enable_raw_embedding_streaming is None
        or enable_raw_embedding_streaming is False
    ):
        return (False, res_params)

    res_params.table_names = [table.name for table in config.embedding_tables]

    # A non-empty allow-list that matches none of this config's tables turns
    # streaming off for the whole group; a missing/empty list means no filter.
    if res_enabled_tables and set(res_enabled_tables).isdisjoint(
        res_params.table_names
    ):
        logger.info(
            f"No table is enabled for raw embedding streaming, "
            f"raw embedding streaming is disabled, {res_enabled_tables=} {res_params.table_names=}"
        )
        return (False, res_params)

    # Record the first shard offset of every table that carries one; tables
    # without local metadata or shard offsets contribute no entry.
    res_params.table_offsets = []
    for emb_tbl in config.embedding_tables:
        local_metadata = emb_tbl.local_metadata
        if (
            local_metadata is not None
            and local_metadata.shard_offsets is not None
            and len(local_metadata.shard_offsets) >= 1
        ):
            res_params.table_offsets.append(local_metadata.shard_offsets[0])
    return (enable_raw_embedding_streaming, res_params)


def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -186,22 +238,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]

# populate res_params, which is used for raw embedding streaming
# here only populates the params available in fused_params and TBE configs
res_params: RESParams = RESParams()
res_params.table_names = [table.name for table in config.embedding_tables]
res_params.table_offsets = []
for emb_tbl in config.embedding_tables:
local_metadata = emb_tbl.local_metadata
if (
local_metadata is not None
and local_metadata.shard_offsets is not None
and len(local_metadata.shard_offsets) >= 1
):
res_params.table_offsets.append(local_metadata.shard_offsets[0])
if "res_store_shards" in fused_params:
res_params.res_store_shards = fused_params.get("res_store_shards")
enable_res, res_params = _populate_res_params(config)
ssd_tbe_params["res_params"] = res_params
ssd_tbe_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res

return ssd_tbe_params

Expand Down Expand Up @@ -2190,6 +2229,9 @@ def __init__(
if "cache_precision" not in fused_params:
fused_params["cache_precision"] = weights_precision

enable_res, res_params = _populate_res_params(config)
fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res

self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=list(
Expand All @@ -2208,6 +2250,7 @@ def __init__(
self._col_offset,
)
),
res_params=res_params,
**fused_params,
)
)
Expand Down Expand Up @@ -3041,6 +3084,10 @@ def __init__(
fused_params["cache_precision"] = weights_precision
if weights_precision == SparseType.NFP8:
fused_params["cache_precision"] = SparseType.FP16

enable_res, res_params = _populate_res_params(config)
fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res

self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=list(
Expand All @@ -3059,6 +3106,7 @@ def __init__(
self._col_offset,
)
),
res_params=res_params,
**fused_params,
)
)
Expand Down
Loading