Skip to content

Commit 6c1f90d

Browse files
authored
Merge branch 'main' into fix_logging
2 parents 14c252b + c178637 commit 6c1f90d

File tree

9 files changed

+161
-9
lines changed

9 files changed

+161
-9
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,6 +2110,102 @@ def call_operator(
21102110
return super().call_operator(op, args, kwargs, meta)
21112111

21122112

2113+
@register_cadence_pass(CadencePassAttribute(opt_level=2))
2114+
class ReplaceGeluWithApproximateGeluPass(ExportPass):
2115+
"""
2116+
Replace the gelu op with an approximate gelu op. The approximate gelu op
2117+
is more efficient on DSP backends.
2118+
"""
2119+
2120+
def call_operator(
2121+
self,
2122+
op,
2123+
args: Tuple[Argument, ...],
2124+
kwargs: Dict[str, Argument],
2125+
meta: NodeMetadata,
2126+
) -> ProxyValue:
2127+
if op not in {
2128+
exir_ops.edge.aten.gelu.default,
2129+
}:
2130+
return super().call_operator(op, args, kwargs, meta)
2131+
2132+
# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
2133+
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
2134+
2135+
# Get 0.5 * x
2136+
half = super().call_operator(
2137+
exir_ops.edge.aten.mul.Tensor,
2138+
(args[0], 0.5),
2139+
{},
2140+
meta,
2141+
)
2142+
2143+
scaled = super().call_operator(
2144+
exir_ops.edge.aten.mul.Tensor,
2145+
(args[0], 0.044715),
2146+
{},
2147+
meta,
2148+
)
2149+
2150+
# Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
2151+
# it is much more efficient on DSP backends)
2152+
scaled_square = super().call_operator(
2153+
exir_ops.edge.aten.mul.Tensor,
2154+
(scaled, args[0]),
2155+
{},
2156+
meta,
2157+
)
2158+
2159+
# Get 0.044715 * x^3
2160+
scaled_cubed = super().call_operator(
2161+
exir_ops.edge.aten.mul.Tensor,
2162+
(scaled_square, args[0]),
2163+
{},
2164+
meta,
2165+
)
2166+
2167+
# Get x + 0.044715 * x^3
2168+
inner_sum = super().call_operator(
2169+
exir_ops.edge.aten.add.Tensor,
2170+
(scaled_cubed, args[0]),
2171+
{},
2172+
meta,
2173+
)
2174+
2175+
# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
2176+
scaled_sum = super().call_operator(
2177+
exir_ops.edge.aten.mul.Tensor,
2178+
(inner_sum, 0.7978845608028654),
2179+
{},
2180+
meta,
2181+
)
2182+
2183+
# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
2184+
tanh = super().call_operator(
2185+
exir_ops.edge.aten.tanh.default,
2186+
(scaled_sum,),
2187+
{},
2188+
meta,
2189+
)
2190+
2191+
# Get 1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
2192+
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
2193+
outer_sum = super().call_operator(
2194+
exir_ops.edge.aten.add.Tensor,
2195+
(tanh, 1.0),
2196+
{},
2197+
meta,
2198+
)
2199+
2200+
# Return the final result
2201+
return super().call_operator(
2202+
exir_ops.edge.aten.mul.Tensor,
2203+
(half, outer_sum),
2204+
{},
2205+
meta,
2206+
)
2207+
2208+
21132209
# This class encapsulates all the functions that replace/switch one op in the
21142210
# graph with another.
21152211
class CadenceReplaceOpsInGraph:
@@ -2149,4 +2245,5 @@ class CadenceReplaceOpsInGraph:
21492245
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
21502246
ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
21512247
ReplaceWhereWithFullArgsWithWhereScalar,
2248+
# ReplaceGeluWithApproximateGeluPass,
21522249
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ReplaceConvWithIm2RowAndLinear,
3030
ReplaceEmptyTensorsWithFullPass,
3131
ReplaceFunctionallyEquivalentOpTargets,
32+
ReplaceGeluWithApproximateGeluPass,
3233
ReplaceIm2RowWithViewPass,
3334
ReplaceLinearWithFullyConnectedOpPass,
3435
ReplaceMMWithAddMMPass,
@@ -1301,6 +1302,41 @@ def forward(self, cond: torch.Tensor):
13011302
1,
13021303
)
13031304

1305+
def test_replace_aten_gelu_with_approximate_gelu(self):
1306+
class Gelu(torch.nn.Module):
1307+
def forward(self, input):
1308+
return torch.nn.functional.gelu(input)
1309+
1310+
inputs = torch.randn(2, 1, 64)
1311+
1312+
graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
1313+
1314+
p = ReplaceGeluWithApproximateGeluPass()
1315+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1316+
1317+
# Assert that aten.gelu op was decomposed
1318+
self.assertEqual(
1319+
count_node(
1320+
graph_after_passes,
1321+
exir_ops.edge.aten.gelu.default,
1322+
),
1323+
0,
1324+
)
1325+
1326+
# The decomposition should have one tanh, 2 add and 6 mul
1327+
self.assertEqual(
1328+
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1329+
1,
1330+
)
1331+
self.assertEqual(
1332+
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
1333+
2,
1334+
)
1335+
self.assertEqual(
1336+
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
1337+
6,
1338+
)
1339+
13041340

13051341
class TestReplaceIm2rowWithViewPass(unittest.TestCase):
13061342
def test_no_replacement_for_conv(self):

backends/xnnpack/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,6 @@ create an issue on [github](https://www.github.com/pytorch/executorch/issues).
131131

132132

133133
## See Also
134-
For more information about the XNNPACK Delegate, please check out the following resources:
135-
- [ExecuTorch XNNPACK Delegate](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html)
136-
- [Building and Running ExecuTorch with XNNPACK Backend](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html)
134+
For more information about the XNNPACK Backend, please check out the following resources:
135+
- [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack.html)
136+
- [XNNPACK Backend Internals](https://pytorch.org/executorch/main/backend-delegates-xnnpack-reference.html)

docs/source/getting-started.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ ExecuTorch provides hardware acceleration for a wide variety of hardware. The mo
4343
For mobile use cases, consider using XNNPACK for Android and Core ML or XNNPACK for iOS as a first step. See [Hardware Backends](backends-overview.md) for more information.
4444

4545
### Exporting
46-
Exporting is done using Python APIs. ExecuTorch provides a high degree of customization during the export process, but the typical flow is as follows. This example uses the MobileNet V2 image classification model implementation in torchvision, but the process supports any [export-compliant](https://pytorch.org/docs/stable/export.html) PyTorch model.
46+
Exporting is done using Python APIs. ExecuTorch provides a high degree of customization during the export process, but the typical flow is as follows. This example uses the MobileNet V2 image classification model implementation in torchvision, but the process supports any [export-compliant](https://pytorch.org/docs/stable/export.html) PyTorch model. For users working with Hugging Face models,
47+
you can find a list of supported models in the [*huggingface/optimum-executorch*](https://github.com/huggingface/optimum-executorch) repo.
4748

4849
```python
4950
import torch
@@ -101,6 +102,8 @@ print(torch.allclose(output[0], eager_reference_output, rtol=1e-3, atol=1e-5))
101102

102103
For complete examples of exporting and running the model, please refer to our [examples GitHub repository](https://github.com/pytorch-labs/executorch-examples/tree/main/mv2/python).
103104

105+
Additionally, if you work with Hugging Face models, the [*huggingface/optimum-executorch*](https://github.com/huggingface/optimum-executorch) library simplifies running these models end-to-end with ExecuTorch, using familiar Hugging Face APIs. Visit the repository for specific examples and supported models.
106+
104107
<hr/>
105108

106109
## Running on Device

docs/source/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ ExecuTorch provides support for:
4343
#### Examples
4444
- [Android Demo Apps](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app)
4545
- [iOS Demo Apps](demo-apps-ios.md)
46+
- [Hugging Face Models](https://github.com/huggingface/optimum-executorch/blob/main/README.md)
4647
#### Backends
4748
- [Overview](backends-overview)
4849
- [XNNPACK](backends-xnnpack)

docs/source/using-executorch-android.md

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,20 @@ You can also directly specify an AAR file in the app. We upload pre-built AAR to
5858

5959
### Snapshots from main branch
6060

61-
| Date | AAR | SHASUMS |
62-
| ------- | --- | ------- |
63-
| 2025-02-27 | [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-20250227/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-20250227/executorch.aar.sha256sums) |
61+
Starting from 2025-04-12, you can download nightly `main` branch snapshots:
62+
* `executorch.aar`: `https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-YYYYMMDD/executorch.aar`
63+
* `executorch.aar.sha256sums`: `https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-YYYYMMDD/executorch.aar.sha256sums`
64+
* Replace `YYYYMMDD` with the actual date you want to use.
65+
* The AAR file is generated by [this workflow](https://github.com/pytorch/executorch/blob/c66b37d010c88a113560693b14dc6bd112593c11/.github/workflows/android-release-artifacts.yml#L14-L15).
66+
67+
For example:
68+
69+
```sh
70+
curl -O https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-20250412/executorch.aar
71+
curl -O https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-20250412/executorch.aar.sha256sums
72+
```
73+
74+
We aim to make every daily snapshot available and usable. However, for best stability, please use releases, not snapshots.
6475

6576
## Using AAR file
6677

examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Link your binary with the ExecuTorch runtime and any backends or kernels used by
8585

8686
Note: To access logs, link against the Debug build of the ExecuTorch runtime, i.e., the executorch_debug framework. For optimal performance, always link against the Release version of the deliverables (those without the _debug suffix), which have all logging overhead removed.
8787

88-
For more details integrating and Running ExecuTorch on Apple Platforms, checkout this [link](https://pytorch.org/executorch/using-executorch-ios.html).
88+
For more details on integrating and running ExecuTorch on Apple platforms, check out this [link](https://pytorch.org/executorch/main/using-executorch-ios.html).
8989

9090
<p align="center">
9191
<img src="https://raw.githubusercontent.com/pytorch/executorch/refs/heads/main/docs/source/_static/img/ios_demo_app_swift_pm.png" alt="iOS LLaMA App Swift PM" style="width:600px">

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ inline bool check_dim_order(
133133
size_t gauss_sum = 0;
134134
std::vector<int> count(dim_order.size(), 0);
135135
for (int i = 0; i < dim_order.size(); i++) {
136-
if (dim_order[i] < 0 || dim_order[i] >= sizes.size()) {
136+
if (dim_order[i] >= sizes.size()) {
137137
return false;
138138
}
139139
gauss_sum += static_cast<size_t>(dim_order[i]) + 1;

test/utils/DeathTest.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
#include <gtest/gtest.h>
1717

18+
#ifndef ET_BUILD_MODE_COV
19+
#define ET_BUILD_MODE_COV 0
20+
#endif // ET_BUILD_MODE_COV
21+
1822
#if ET_BUILD_MODE_COV
1923

2024
/**

0 commit comments

Comments
 (0)