Skip to content

Commit dec6b31

Browse files
malfetguilhermeleobas
authored andcommitted
Update Trition pin (pytorch#115743)
To include a cherry-pick of triton-lang/triton#2771 that should fix cuda-11.8 runtime issues Also, tweak build wheel script to update both ROCm and vanilla Trition builds version to 2.2 (even though on trunk it should probably be 3.3 already) TODO: Remove `ROCM_TRITION_VERSION` once both trunk and ROCM version are in sync again Pull Request resolved: pytorch#115743 Approved by: https://github.com/davidberard98
1 parent c088b9a commit dec6b31

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

.ci/docker/ci_commit_pins/triton.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
bcad9dabe15021c53b6a88296e9d7a210044f108
1+
e28a256d71f3cf2bcc7b69d6bda73a9b855e385e

.ci/docker/triton_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.1.0
1+
2.2.0

.github/scripts/build_triton_wheel.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
SCRIPT_DIR = Path(__file__).parent
1111
REPO_DIR = SCRIPT_DIR.parent.parent
1212

13+
# TODO: Remove me once Triton version is again in sync for vanilla and ROCm
14+
ROCM_TRITION_VERSION = "2.1.0"
15+
1316

1417
def read_triton_pin(rocm_hash: bool = False) -> str:
1518
triton_file = "triton.txt" if not rocm_hash else "triton-rocm.txt"
@@ -29,25 +32,37 @@ def check_and_replace(inp: str, src: str, dst: str) -> str:
2932
return inp.replace(src, dst)
3033

3134

32-
def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
35+
def patch_setup_py(
36+
path: Path,
37+
*,
38+
version: str,
39+
name: str = "triton",
40+
expected_version: Optional[str] = None,
41+
) -> None:
3342
with open(path) as f:
3443
orig = f.read()
3544
# Replace name
3645
orig = check_and_replace(orig, 'name="triton",', f'name="{name}",')
3746
# Replace version
47+
if not expected_version:
48+
expected_version = read_triton_version()
3849
orig = check_and_replace(
39-
orig, f'version="{read_triton_version()}",', f'version="{version}",'
50+
orig, f'version="{expected_version}",', f'version="{version}",'
4051
)
4152
with open(path, "w") as f:
4253
f.write(orig)
4354

4455

45-
def patch_init_py(path: Path, *, version: str) -> None:
56+
def patch_init_py(
57+
path: Path, *, version: str, expected_version: Optional[str] = None
58+
) -> None:
59+
if not expected_version:
60+
expected_version = read_triton_version()
4661
with open(path) as f:
4762
orig = f.read()
4863
# Replace version
4964
orig = check_and_replace(
50-
orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"'
65+
orig, f"__version__ = '{expected_version}'", f'__version__ = "{version}"'
5166
)
5267
with open(path, "w") as f:
5368
f.write(orig)
@@ -140,6 +155,7 @@ def build_triton(
140155
patch_init_py(
141156
triton_pythondir / "triton" / "__init__.py",
142157
version=f"{version}",
158+
expected_version=ROCM_TRITION_VERSION if build_rocm else None,
143159
)
144160

145161
if build_rocm:
@@ -148,6 +164,7 @@ def build_triton(
148164
triton_pythondir / "setup.py",
149165
name=triton_pkg_name,
150166
version=f"{version}",
167+
expected_version=ROCM_TRITION_VERSION,
151168
)
152169
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
153170
print("ROCm libraries setup for triton installation...")

0 commit comments

Comments
 (0)