Implements torch.linalg.lstsq
#49093
Conversation
Force-pushed from 056961b to 927e084.
Is there a chance to migrate …
@xiaosu-zhu, implementing these things from scratch is no trivial matter, and might be beyond the scope of this PR. Maybe it is possible to contact the MAGMA people and ask them whether there is a plan to support these functions in the near future.
Yeah, thank you @nikitaved. I anticipated that it would be hard work; maybe leave it for the future 😃
This PR is ready for review.
Codecov Report

```
@@            Coverage Diff             @@
##           master   #49093      +/-   ##
==========================================
+ Coverage   77.30%   77.32%   +0.02%
==========================================
  Files        1888     1888
  Lines      183589   183858     +269
==========================================
+ Hits       141925   142176     +251
- Misses      41664    41682      +18
```
IvanYashchuk
left a comment
The implementation looks clean to me. I left a few questions and suggestions inline.
The idea of implementing batch_iterator_with_broadcasting is great. It will certainly be useful for other "find x s.t. A x = b" functions.
Since LAPACK 3.7.0 (released in December 2016) there is also the GETSLS driver available (MKL docs, Netlib release notes). Something to consider adding in follow-up work.
A note for the future: while there are no specialized drivers in cuBLAS, cuSOLVER, or MAGMA for the m < n case, we can implement it ourselves via the SVD pseudoinverse, x = v @ s⁻¹ @ uᴴ @ b (where a = u @ s @ vᴴ).
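For illustration, here is a minimal sketch of that idea, assuming a simple rcond-style cutoff on the singular values (this is a sketch, not the PR's implementation):

```python
import torch

def lstsq_via_svd(a, b, rcond=1e-15):
    # Minimum-norm least-squares solution x = v @ diag(1/s) @ u^H @ b,
    # which also covers the underdetermined (m < n) case.
    u, s, vh = torch.linalg.svd(a, full_matrices=False)
    uh = u.transpose(-2, -1).conj()
    v = vh.transpose(-2, -1).conj()
    # Zero out reciprocals of singular values below the cutoff
    # instead of dividing by (near-)zero.
    cutoff = rcond * s.amax(dim=-1, keepdim=True)
    s_inv = torch.where(s > cutoff, 1.0 / s, torch.zeros_like(s))
    return v @ (s_inv.unsqueeze(-1) * (uh @ b))
```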
From torch/linalg/__init__.py (outdated):

```
``'gelsy'`` is the fastest among the rank-revealing algorithms that also handles rank-deficient inputs.

Returns:
    (Tensor, Tensor, Tensor): a namedtuple (x, rank, s) containing:
```
Would it be clearer to have a namedtuple (solution, rank, singular_values)?
cc: @mruberry
I think so too, but the inspiration was taken from SciPy.
scipy.linalg functions don't return a namedtuple, so that shouldn't matter here. I'd say go for the clearer names.
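For example, with the clearer names a call site would read (a hypothetical sketch; the field names are exactly what this thread is deciding):

```python
result = torch.linalg.lstsq(a, b)
x = result.solution            # rather than result.x
r = result.rank
sv = result.singular_values    # rather than result.s
```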
@nikitaved @IvanYashchuk ping me when you're both happy with this PR.
I am basically done. I am not sure whether it is actually a good idea to use the pseudoinverse for cases m < n, for two reasons: …
```yaml
- func: _lstsq_helper(Tensor a, Tensor b, float cond, str? driver_name) -> (Tensor, Tensor, Tensor)
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU: _lstsq_helper_cpu
    CUDA: _lstsq_helper_cuda
```
Maybe it makes sense to remove this from native functions and use the declared/defined dispatch instead? CC @mruberry
Yep, that's a great idea.
IvanYashchuk
left a comment
Everything looks very good! So I'm happy with this PR.
IvanYashchuk
left a comment
Everything looks very good to me.
There is a minor suggestion for the tests.
I've noticed one thing we didn't discuss: broadcasting for the case when either a or b is batched while the other is not. Currently it's not allowed in this PR:
```
RuntimeError: torch.linalg.lstsq: self.dim() must be greater or equal to b.dim() and (self.dim() - b.dim()) <= 1
```

NumPy's lstsq does not support batched input at all. But, for example, numpy.linalg.solve supports it and allows the following:

```python
b = torch.randn(5, 1)
a = torch.randn(2, 1, 3, 5, 5)
np.linalg.solve(a, b)
```

I think this should be discussed separately and then the same behaviour applied to all A x = b functions.
@IvanYashchuk, we have already …
@mruberry, with @IvanYashchuk we decided to create an issue to discuss the broadcasting situation for the solve-like methods. Once we agree on the interface, we could adapt it for …
Awesome!
Sounds great. Do you just want to make this tweak by adding a sentence to the description of the cond arg in the docs and resolving the conflict with torch/overrides.py, and we'll merge this? Edit: actually, the ROCm test failures might be real. Let's see if rebasing fixes them, or try to identify whether there is an issue.
FYI, I expect this will land today; internal tooling needed some help with it.
FYI, going to revert this due to an internal build issue. No action needed; I just need to resolve it internally.

Update: an internal project had a build issue because they're consuming LAPACK functions from multiple dependencies. They are working on resolving the issue now, and will validate that their fix allows us to land this change without breaking them. There's no action on our part at this time.
facebook-github-bot
left a comment
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Arg. This hit a logical merge conflict with the update from "supports_tensor_out" to "supports_out". I updated it, but the lint build is going to fail because we now check for trailing whitespace, and there's a ton of trailing whitespace in this PR's base. @nikitaved, would you please rebase this? The internal build appears to be OK now, so after a rebase we should be OK to land.
Force-pushed from b0ca309 to 4684f6d, then from 4684f6d to 6064f79.
facebook-github-bot
left a comment
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes pytorch#44378 by providing a wider range of drivers, similar to what SciPy is doing. The supported CPU drivers are `gels`, `gelsy`, `gelsd`, `gelss`. The CUDA interface has only `gels` implemented, but only for overdetermined systems. The current state of this PR:
- [x] CPU interface
- [x] CUDA interface
- [x] CPU tests
- [x] CUDA tests
- [x] Memory-efficient batch-wise iteration with broadcasting, which fixes pytorch#49252
- [x] docs

Pull Request resolved: pytorch#49093
Reviewed By: H-Huang
Differential Revision: D26723384
Pulled By: mruberry
fbshipit-source-id: c9866a95f14091955cf42de22f4ac9e2da009713

Summary: Fixes pytorch#44378 by providing a wider range of drivers, similar to what SciPy is doing. The supported CPU drivers are `gels`, `gelsy`, `gelsd`, `gelss`. The CUDA interface has only `gels` implemented, but only for overdetermined systems. The current state of this PR:
- [x] CPU interface
- [x] CUDA interface
- [x] CPU tests
- [x] CUDA tests
- [x] Memory-efficient batch-wise iteration with broadcasting, which fixes pytorch#49252
- [x] docs

Pull Request resolved: pytorch#49093
Reviewed By: albanD
Differential Revision: D26991788
Pulled By: mruberry
fbshipit-source-id: 8af9ada979240b255402f55210c0af1cba6a0a3c
```cpp
check_if_copy_needed_for_a(a_curr_linear_batch_idx);
// ...
auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
```
Is it guaranteed that a_3d is contiguous? Otherwise just getting data_ptr might not be safe.
The comments say that the input is expected to be "almost contiguous", but I can't find where that's enforced.
Btw, it might be marginally faster to create a TensorAccessor once instead of doing repeated .select() calls (https://pytorch.org/cppdocs/notes/tensor_basics.html#efficient-access-to-tensor-elements).
Have a look at the long comment right above: it mentions that a and b are expected to be "contiguous" with respect to the batch dimensions and in column-major order with respect to the last two dimensions, i.e. the output of a LAPACK routine is sufficient. No enforcing is done.
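For intuition, a small Python sketch of what "column-major with respect to the last two dimensions" means in terms of strides (a hypothetical helper, not code from this PR):

```python
def looks_lapack_ready(t):
    # Column-major within each matrix: unit stride down a column, and a
    # column stride equal to the row count, as a LAPACK-style output has.
    return t.stride()[-2] == 1 and t.stride()[-1] == t.shape[-2]
```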
A TensorAccessor will not work for CUDA tensors, or will it? Or would we need to write a separate kernel for that?
Fixes #44378 by providing a wider range of drivers, similar to what SciPy is doing.
The supported CPU drivers are `gels`, `gelsy`, `gelsd`, `gelss`. The CUDA interface has only `gels` implemented, and only for overdetermined systems.
The current state of this PR:
- [x] CPU interface
- [x] CUDA interface
- [x] CPU tests
- [x] CUDA tests
- [x] Memory-efficient batch-wise iteration with broadcasting, which fixes #49252
- [x] docs
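For reference, a minimal usage sketch of the resulting API (driver names as listed above; the exact names of the returned fields follow the review discussion and may differ):

```python
import torch

a = torch.randn(2, 5, 3)   # batch of two overdetermined systems (5 > 3)
b = torch.randn(2, 5, 1)

# 'gelsd' is one of the rank-revealing CPU LAPACK drivers listed above.
result = torch.linalg.lstsq(a, b, driver='gelsd')
x = result.solution        # least-squares solutions, shape (2, 3, 1)
```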