|
8 | 8 | from torch_tensorrt.dynamo.conversion import impl |
9 | 9 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
10 | 10 | from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 11 | + cast_trt_tensor, |
11 | 12 | get_positive_dim, |
12 | 13 | get_trt_tensor, |
13 | 14 | to_numpy, |
@@ -440,3 +441,70 @@ def get_softmax_dim(ndim: int) -> int: |
440 | 441 | layer.axes = 1 << dim |
441 | 442 | set_layer_name(layer, target, name, source_ir) |
442 | 443 | return layer.get_output(0) |
| 444 | + |
| 445 | + |
def pdist(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    p: float = 2,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Compute pairwise p-norm distances between the rows of a 2D input.

    Mirrors ``torch.nn.functional.pdist``: for an (N, M) input, produces the
    N*(N-1)/2 distances for all row pairs (i, j) with i < j, in flattened
    upper-triangular order.

    Args:
        ctx: Conversion context holding the TensorRT network being built.
        target: FX node target being converted (used for layer naming).
        source_ir: Source IR of the op (used for layer naming).
        name: Base name; each TensorRT layer created here gets a suffix of it.
        input: 2D input tensor of shape (N, M).
            NOTE(review): ``np.triu_indices(shape[0])`` below requires a
            concrete N — confirm behavior under dynamic shapes.
        p: Norm degree; must satisfy 0 <= p <= inf.

    Returns:
        TensorRT tensor of the flattened pairwise distances.

    Raises:
        RuntimeError: If ``p`` is outside [0, inf] (e.g. negative or NaN).
    """
    shape = input.shape
    # Reshape (N, M) -> (N, 1, M) so broadcasting against the original (N, M)
    # yields x with x[i, j] = input[i] - input[j] of shape (N, N, M).
    extend_input = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape",
        input,
        shape=shape[0:1] + (1,) + shape[1:],
    )
    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)

    if p == 0:
        # Hamming-style "norm": count of coordinates that differ.
        # norm = torch.sum(x != 0, dim=2)
        nonzero_val = impl.elementwise.ne(ctx, target, source_ir, f"{name}_ne", x, 0)
        norm = impl.reduce.sum(
            ctx, target, source_ir, f"{name}_sum", nonzero_val, dim=2, keepdim=False
        )
        # The sum over boolean comparisons is integer-typed; cast to float32
        # so the output dtype matches torch's floating-point pdist result.
        norm = cast_trt_tensor(
            ctx, norm, torch.float32, f"{name}_cast", target, source_ir
        )
    elif p == 1:
        # Manhattan distance: norm = torch.sum(torch.abs(x), dim=2)
        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
        norm = impl.reduce.sum(
            ctx, target, source_ir, f"{name}_sum", abs_val, dim=2, keepdim=False
        )
    elif 0 < p < float("inf"):
        # General finite p (p == 0 and p == 1 were handled above, so the
        # original `0 < p < 1 or 1 < p < inf` check simplifies to this).
        # norm = torch.pow(torch.sum(torch.pow(torch.abs(x), p), dim=2), 1/p)
        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
        pow_val = impl.elementwise.pow(
            ctx, target, source_ir, f"{name}_pow1", abs_val, p
        )
        sum_val = impl.reduce.sum(
            ctx, target, source_ir, f"{name}_sum", pow_val, dim=2, keepdim=False
        )
        norm = impl.elementwise.pow(
            ctx, target, source_ir, f"{name}_pow2", sum_val, 1 / p
        )
    elif p == float("inf"):
        # Chebyshev distance: norm = torch.max(torch.abs(x), dim=2)
        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs", x)
        norm = impl.reduce.max(
            ctx,
            target,
            source_ir,
            f"{name}_max",
            abs_val,
            dim=2,
            keepdim=False,
            return_indices=False,
        )
    else:
        # Reached for p < 0 or NaN.
        raise RuntimeError(
            f"p should be between [0, inf], currently p={p} is not supported!"
        )
    # norm is the full (N, N) distance matrix; select only the strictly
    # upper-triangular entries (i < j) to match torch.pdist's output order.
    indices = np.triu_indices(shape[0], k=1)
    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
0 commit comments