-
Notifications
You must be signed in to change notification settings - Fork 26k
Error out on in-place (unary) ops on tensors that have internal overlap #17927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Works for CPU and CUDA on the following ops: - abs_, acos_, asin_, atan_, ceil_, cos_, erf_, erfc_, exp_, expm1_ - floor_, log_, log10_, log1p_, log2_, round_, rsqrt_, sigmoid_, - sin_, sqrt_, tan_, tanh_, trunc_ This PR adds a check to see if the out/result tensor has internal overlap. If it does, then we error out because the result **may** be incorrect. This is overly conservative; there are some cases where if the result is the same as the input, the inplace operation is OK (such as floor_, round_, and trunc_). However, the current code isn't organized in such a way that this is easy to check, so enabling those will come in the future. Test Plan: New tests in test/test_torch.py Future: I have a list of other in-place unary ops that need this check.
soumith
approved these changes
Mar 13, 2019
facebook-github-bot
pushed a commit
that referenced
this pull request
Mar 15, 2019
Summary: Pull Request resolved: #17926 ghimport-source-id: 9f7572b Stack: * **#17926 Implement at::has_internal_overlap helper function** * #17927 Error out on in-place (unary) ops on tensors that have internal overlap On the way to #17935. Checks if a tensor's sizes/strides indicate that multiple elements share the same memory location. This problem in general is hard so at::has_internal_overlap implements two heuristics and avoids solving the general problem: if a tensor is contiguous, it cannot have internal overlap if a tensor has any zero strides, it does have internal overlap otherwise, return MemOverlap::kTooHard to indicate that there might be overlap, but we don't know. Reviewed By: ezyang Differential Revision: D14438858 fbshipit-source-id: 607ab31771315921ab6165b2a1f072ac3e75925a
facebook-github-bot
pushed a commit
that referenced
this pull request
Mar 15, 2019
Summary: Pull Request resolved: #17993 ghimport-source-id: 5427773 Stack: * #17926 Implement at::has_internal_overlap helper function * #17927 Error out on in-place (unary) ops on tensors that have internal overlap * **#17993 [easy] Delete dead code in THTensorMoreMath.cpp** We seem to have new implementations already for these in ATen. Reviewed By: ezyang Differential Revision: D14457838 fbshipit-source-id: 8481aad74b2127bd28c0f3e09740889fc0488a31
zdevito
pushed a commit
to zdevito/ATen
that referenced
this pull request
Mar 15, 2019
…ap (#17927) Summary: Pull Request resolved: pytorch/pytorch#17927 ghimport-source-id: 626d321e430b6b5c0ea3aa1eb9df8c1e2d058bf8 Stack: * #17926 Implement at::has_internal_overlap helper function * **#17927 Error out on in-place (unary) ops on tensors that have internal overlap** On the way to #17935. Works for CPU and CUDA on the following ops: - abs_, acos_, asin_, atan_, ceil_, cos_, erf_, erfc_, exp_, expm1_ - floor_, log_, log10_, log1p_, log2_, round_, rsqrt_, - sin_, sqrt_, tan_, tanh_, trunc_ This PR adds a check to see if the out/result tensor has internal overlap. If it does, then we error out because the result **may** be incorrect. This is overly conservative; there are some cases where if the result is the same as the input, the inplace operation is OK (such as floor_, round_, and trunc_). However, the current code isn't organized in such a way that this is easy to check, so enabling those will come in the future. Reviewed By: ezyang Differential Revision: D14438871 fbshipit-source-id: 15e12bf1fdb2ab7f74bb806e22bc74840bd6abd1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Stack:
On the way to #17935.
Works for CPU and CUDA on the following ops:
This PR adds a check to see if the out/result tensor has internal
overlap. If it does, then we error out because the result may be
incorrect.
This is overly conservative; there are some cases where if the result is
the same as the input, the inplace operation is OK (such as floor_,
round_, and trunc_). However, the current code isn't organized in such a
way that this is easy to check, so enabling those will come in the future.
Test Plan:
New tests in test/test_torch.py
Future:
I have a list of other in-place unary ops that need this check.
gh-metadata: pytorch pytorch 17927 gh/zou3519/23/head
Differential Revision: D14438871