@@ -49,6 +49,18 @@ def geninp():
     return input_dict
 
 
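+# Test helper: compute the strides expected for a contiguous tensor of `shape`
+# when every non-innermost stride is rounded up to a multiple of
+# `alignment_bytes` (with pad_output=False it returns the ordinary contiguous
+# strides). For example, shape=(32, 30) with alignment_bytes=64 and fp32
+# (itemsize 4) gives align=16 elements, so the row stride pads from 30 up to
+# 32 and the result is (32, 1).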
+def get_padded_stride(shape, alignment_bytes, pad_output, itemsize):
+    align = alignment_bytes // itemsize
+    new_strides = [0 for _ in range(len(shape))]
+    new_strides[len(shape) - 1] = 1
+    for i in range(len(shape) - 1, 0, -1):
+        stride = shape[i] * new_strides[i]
+        if pad_output and stride % align != 0:
+            stride = (stride + align - 1) // align * align
+        new_strides[i - 1] = stride
+    return tuple(new_strides)
+
+
 class LinearAndSoftmax(nn.Module):
     """
     It's very common that a transformer model will do a matmul and then
@@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)]
 
         config_patches = {
-            "compile_threads": 1,
             "comprehensive_padding": pad_output,
             "cpu_backend": "triton",
-            "disable_padding_cpu": False,
-            "implicit_fallbacks": False,
-            "inplace_buffers": False,
             "padding_alignment_bytes": alignment_bytes,
-            "pad_channels_last": True,
             "pad_outputs": True,
             "padding_stride_threshold": 0,
-            "triton.prefer_nd_tiling": True,
-            "triton.use_block_ptr": True,
-            "triton.codegen_upcast_to_fp32": False,
-            "unroll_reductions_threshold": 1,
         }
         with config.patch(config_patches):
             compiled = torch.compile(torch.cat)
@@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         output_shape = (shape[0] * num_inputs, shape[1])
         output_stride = input_tensors[0].stride()
         output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)"
-        self.assertTrue(any(output_line in line for line in code))
+        self.assertTrue(output_line in code[0])
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((512, 1), 32, False),
+            ((512, 1), 32, True),
+            ((32, 30), 64, False),
+            ((32, 30), 64, True),
+            ((512, 100, 1), 32, False),
+            ((512, 100, 1), 32, True),
+            ((32, 50, 30), 64, False),
+            ((32, 50, 30), 64, True),
+        ],
+    )
+    def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be
+        padded according to the padding configuration.
+        """
+        num_inputs = 2
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 0)
+            torch._dynamo.mark_dynamic(input_tensors[1], 0)
+            compiled = torch.compile(torch.add)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((500, 10, 1), 32, False),
+            ((500, 20, 1), 32, True),
+            ((30, 10, 20), 64, True),
+            ((30, 10, 20), 64, False),
+        ],
+    )
+    def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be
+        padded according to the padding configuration. Test the case where this
+        occurs after a permute op.
+        """
+
+        def permute_contig(x):
+            return torch.transpose(x, 0, 2).contiguous()
+
+        num_inputs = 1
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+            "triton.use_block_ptr": True,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 2)
+            compiled = torch.compile(permute_contig)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
 
 
 if __name__ == "__main__":