[float8] fuse abs/max with torch.linalg.vector_norm #905


Draft · wants to merge 1 commit into main

Conversation

@weifengpy (Contributor) commented Sep 19, 2024

torch.linalg.vector_norm is faster
[Screenshot: profiler trace for torch.linalg.vector_norm]

torch.max(torch.abs(tensor)) is slower because it launches 3 kernels and pays the extra launch overhead.
[Screenshot: profiler trace for torch.max(torch.abs(tensor))]
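
As a quick sanity check (my own snippet, not part of the PR), the two expressions compute the same amax value, since the infinity norm of a flattened tensor is max(abs(x)):

import torch

x = torch.rand(256, 1024, device="cuda")
amax_fused = torch.linalg.vector_norm(x, ord=float("inf"))  # inf-norm == max(abs(x))
amax_eager = torch.max(torch.abs(x))                        # materializes abs(x), then reduces
assert torch.equal(amax_fused, amax_eager)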

Profiler traces were generated with the following code:

import torch
import os
import contextlib
@contextlib.contextmanager
def enable_profiling(enable=False):
    if not enable:
        # profiling disabled: run the wrapped code without a profiler
        yield None
    else:
        trace_dir = "./profilers"
        rank = 0
        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
            # torch.distributed.barrier()
        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)
        warmup, active = 1, 2
        wait = 1
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler

with enable_profiling(True) as torch_profiler:
    tensor = torch.rand(256, 1024, device="cuda")
    for i in range(10):
        with torch.profiler.record_function("vector_norm"):
            # swap in the commented line to profile the 3-kernel variant
            # amax = torch.max(torch.abs(tensor))
            torch.linalg.vector_norm(tensor, ord=float("inf"))
        torch_profiler.step()
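
A rough wall-clock comparison of the two variants (my own sketch, not from the PR; numbers depend on GPU and tensor shape) can be done with torch.utils.benchmark:

import torch
import torch.utils.benchmark as benchmark

x = torch.rand(256, 1024, device="cuda")
timer_norm = benchmark.Timer(
    stmt='torch.linalg.vector_norm(x, ord=float("inf"))',
    globals={"torch": torch, "x": x},
)
timer_absmax = benchmark.Timer(
    stmt="torch.max(torch.abs(x))",
    globals={"torch": torch, "x": x},
)
print(timer_norm.timeit(100))    # fused abs/max reduction
print(timer_absmax.timeit(100))  # separate abs and max kernels in eager mode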
        

@pytorch-bot (bot) commented Sep 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/905

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bf49f41 with merge base 53b6b78:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 19, 2024
@weifengpy requested review from vkuzo and awgu September 19, 2024 03:21
@vkuzo (Contributor) commented Sep 19, 2024

torch.compile performance is important in this code path; any thoughts on how this PR impacts performance when torch.compile is on?

@weifengpy (Contributor, Author) replied:

> torch.compile performance is important in this code path; any thoughts on how this PR impacts performance when torch.compile is on?

Good question! I can benchmark the Inductor output code and update again.
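
One way to check this (my own sketch, not from the PR) is to compile both variants and inspect the Inductor output code, e.g. by running with TORCH_LOGS=output_code, to see whether abs+max already fuses into a single reduction kernel under torch.compile:

import torch

@torch.compile
def amax_absmax(x):
    return torch.max(torch.abs(x))

@torch.compile
def amax_norm(x):
    return torch.linalg.vector_norm(x, ord=float("inf"))

x = torch.rand(256, 1024, device="cuda")
# run with TORCH_LOGS=output_code to dump the generated Triton kernels for both graphs
amax_absmax(x)
amax_norm(x)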

@weifengpy marked this pull request as draft September 19, 2024 18:51
@y-sq (Contributor) commented Oct 4, 2024

When float8 is used, we usually compile the model. Was the issue of three kernels initially observed in the un-compiled fsdp_pre_all_gather region of a compiled model?

Besides the additional kernels for the amax/scaling-factor computation, do we also see more kernels for the cast-to-fp8 part?

Is it possible to add an option in fp8-all-gather to explicitly compile the casting parts?
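
As a rough illustration of that last question (hypothetical helper, not the actual torchao API), explicitly compiling just the casting function could look like:

import torch

def cast_to_fp8(x, scale):
    # hypothetical cast helper: scale, clamp to the float8_e4m3fn range, then cast
    x_scaled = (x * scale).clamp(min=-448.0, max=448.0)
    return x_scaled.to(torch.float8_e4m3fn)

# compile only the casting part so it stays fused even if the surrounding
# fsdp_pre_all_gather code runs in eager mode
cast_to_fp8_compiled = torch.compile(cast_to_fp8)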
