Conversation

@nikitaved (Collaborator) commented Dec 9, 2020

Fixes #44378 by providing a wider range of drivers, similar to what SciPy does.

The supported CPU drivers are `gels`, `gelsy`, `gelsd`, `gelss`.
The CUDA interface currently implements only `gels`, and only for overdetermined systems.

The current state of this PR:

- [x] CPU interface
- [x] CUDA interface
- [x] CPU tests
- [x] CUDA tests
- [x] Memory-efficient batch-wise iteration with broadcasting, which fixes #49252
- [x] docs
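For context, a minimal sketch of how driver selection looks from Python (the call matches the merged `torch.linalg.lstsq` API; the tensors are placeholders):

```python
import torch

a = torch.randn(5, 3)   # overdetermined system: m >= n
b = torch.randn(5, 2)

# SVD-based, rank-revealing driver; available on CPU.
x = torch.linalg.lstsq(a, b, driver='gelsd').solution

# On CUDA only the QR-based 'gels' driver is implemented,
# and only for overdetermined (m >= n) systems:
# x = torch.linalg.lstsq(a.cuda(), b.cuda(), driver='gels').solution
```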

@nikitaved nikitaved requested a review from glaringlee as a code owner December 9, 2020 16:07
@nikitaved nikitaved added the `module: linear algebra` label Dec 9, 2020
dr-ci bot commented Dec 9, 2020

💊 CI failures summary and remediations

As of commit 6064f79 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

This comment was automatically generated by Dr. CI.

@nikitaved nikitaved force-pushed the nikved/linalg_lstsq branch from 056961b to 927e084 (December 9, 2020 17:31)
@xiaosu-zhu

Is there any chance of migrating `gels*` from LAPACK to CUDA, without MAGMA? I'm not an expert in C++ or CUDA programming; I'm just curious whether there exist equivalent operations in CUDA that implement `gels*`.

@nikitaved (Collaborator, Author)

@xiaosu-zhu, implementing these things from scratch is no trivial matter and might be beyond the scope of this PR. Maybe it is possible to contact the MAGMA people and ask them whether there is a plan to support these functions in the near future...

@xiaosu-zhu

> @xiaosu-zhu, implementing these things from scratch is no trivial matter and might be beyond the scope of this PR. Maybe it is possible to contact the MAGMA people and ask them whether there is a plan to support these functions in the near future...

Yeah, thank you @nikitaved. I anticipated that it would be hard work; maybe leave it for the future 😃

@nikitaved nikitaved changed the title [WIP] Implements torch.linalg.lstsq Implements torch.linalg.lstsq Jan 3, 2021
@nikitaved (Collaborator, Author)

This PR is ready for review.

@nikitaved nikitaved requested review from mruberry and removed request for glaringlee January 3, 2021 15:33
codecov bot commented Jan 3, 2021

Codecov Report

Merging #49093 (6064f79) into master (1772e26) will increase coverage by 0.02%.
The diff coverage is 92.93%.

```
@@            Coverage Diff             @@
##           master   #49093      +/-   ##
==========================================
+ Coverage   77.30%   77.32%   +0.02%
==========================================
  Files        1888     1888
  Lines      183589   183858     +269
==========================================
+ Hits       141925   142176     +251
- Misses      41664    41682      +18
```

@mruberry mruberry requested a review from IvanYashchuk January 4, 2021 16:21
@ezyang ezyang added the `triaged` label Jan 5, 2021
@IvanYashchuk (Collaborator) left a comment

The implementation looks clean to me. I left a few questions and suggestions inline.
The idea of implementing batch_iterator_with_broadcasting is great. It will certainly be useful for other "find x s.t. A x = b" functions.

Since LAPACK 3.7.0 (released in Dec 2016) there is also a GETSLS driver available (MKL docs, Netlib release notes). Something to consider adding in follow-up work.
A note for the future: while there are no specialized drivers in cuBLAS, cuSOLVER, or MAGMA for the m < n case, we can implement it ourselves via the SVD A = U S Vᴴ, i.e. x = V S⁻¹ Uᴴ b.
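For the record, a minimal unbatched sketch of that SVD route (hedged: `lstsq_svd` and the `rcond` cutoff are illustrative, not part of this PR; uses today's `torch.linalg` API):

```python
import torch

def lstsq_svd(a, b, rcond=None):
    # Minimum-norm least-squares solution x = V @ diag(1/s) @ U^H @ b,
    # which also covers the underdetermined m < n case.
    u, s, vh = torch.linalg.svd(a, full_matrices=False)
    if rcond is None:
        rcond = max(a.shape[-2:]) * torch.finfo(s.dtype).eps
    # Zero out reciprocal singular values below the cutoff (rank-deficient case).
    s_inv = torch.where(s > rcond * s.max(), s.reciprocal(), torch.zeros_like(s))
    return vh.mH @ (s_inv.unsqueeze(-1) * (u.mH @ b))
```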

```rst
``'gelsy'`` is the fastest among the rank-revealing algorithms that also handles rank-deficient inputs.
Returns:
    (Tensor, Tensor, Tensor): a namedtuple (x, rank, s) containing:
```
(Collaborator)

Would it be clearer to have a namedtuple (solution, rank, singular_values)?
cc: @mruberry

@nikitaved (Collaborator, Author)

I think the same way, but the naming was taken from SciPy.

(Collaborator)

scipy.linalg functions don't return a namedtuple, so that shouldn't matter here. I'd say go for the clearer names.
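For reference, this is how the clearer names read at the call site (a sketch reusing `a` and `b` from above; these are the field names the API eventually shipped with):

```python
out = torch.linalg.lstsq(a, b, driver='gelsd')
print(out.solution.shape)
print(out.rank)             # populated by the rank-revealing drivers
print(out.singular_values)  # populated by the SVD-based drivers
```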

@mruberry (Collaborator)

@nikitaved @IvanYashchuk ping me when you're both happy with this PR.

@nikitaved (Collaborator, Author)

I am basically done. I am not sure whether it is actually a good idea to use the pseudoinverse for the m < n case (see the sketch after this list), for two reasons:

  1. It is slow.
  2. lstsq is not differentiable, while pseudoinverse is.
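For concreteness, the pseudoinverse route in question amounts to something like this (a sketch; `torch.linalg.pinv` stands in for the SVD-based computation):

```python
import torch

a = torch.randn(3, 5)   # underdetermined system: m < n
b = torch.randn(3, 2)

# Minimum-norm least-squares solution via the (differentiable) pseudoinverse.
x = torch.linalg.pinv(a) @ b
```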

Comment on lines 8914 to 9031

```yaml
- func: _lstsq_helper(Tensor a, Tensor b, float cond, str? driver_name) -> (Tensor, Tensor, Tensor)
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU: _lstsq_helper_cpu
    CUDA: _lstsq_helper_cuda
```
@nikitaved (Collaborator, Author)

Maybe it makes sense to remove this from native functions and use the declared/defined dispatch instead? CC @mruberry

(Collaborator)

Yep, that's a great idea.

@IvanYashchuk (Collaborator) left a comment

Everything looks very good! I'm happy with this PR.

@nikitaved nikitaved mentioned this pull request Jan 13, 2021
@IvanYashchuk (Collaborator) left a comment

Everything looks very good to me.
There is a minor suggestion for the tests.

One thing I've noticed that we didn't discuss is broadcasting for the case when either a or b is batched while the other is not. Currently this PR doesn't allow it:

```
RuntimeError: torch.linalg.lstsq: self.dim() must be greater or equal to b.dim() and (self.dim() - b.dim()) <= 1
```

NumPy's lstsq does not support batched input at all, but numpy.linalg.solve, for example, does, and allows the following:

```python
b = torch.randn(5, 1)
a = torch.randn(2, 1, 3, 5, 5)
np.linalg.solve(a, b)
```

I think this should be discussed separately and then the same behaviour applied to all A x = b functions.
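For illustration, the missing broadcasting could be emulated by expanding the batch dimensions by hand before the call (a sketch; `lstsq_broadcast` is a hypothetical helper mimicking np.linalg.solve semantics):

```python
import torch

def lstsq_broadcast(a, b):
    # a: (*batch_a, m, n), b: (*batch_b, m, k) -> broadcast the batch dims.
    batch = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a = a.expand(*batch, *a.shape[-2:])
    b = b.expand(*batch, *b.shape[-2:])
    # Materialize the expanded (stride-0) batches before solving.
    return torch.linalg.lstsq(a.contiguous(), b.contiguous())
```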

@nikitaved (Collaborator, Author) commented Feb 24, 2021

@IvanYashchuk, we already have linalg.solve, right? How does it handle broadcasting? Your example is fully broadcastable, and since we mention broadcasting in the documentation, it makes sense to support the broadcasting semantics in full. So I agree, we need to fix that.

@nikitaved (Collaborator, Author) commented Feb 26, 2021

@mruberry, @IvanYashchuk and I decided to create an issue to discuss the broadcasting situation for the solve-like methods. Once we agree on the interface, we can adapt it for linalg.lstsq in a follow-up PR.

@mruberry (Collaborator) commented Feb 26, 2021

> I think this PR is unblocked, and indeed, the todos could be done in separate PRs.

Awesome!

> Regarding the value for cond, I think we could keep the current behavior, so that the user can use lstsq without worrying about its value. I could update the doc to indicate that the default value is subject to change, so it makes sense to specify cond explicitly to guarantee non-BC-breaking behavior.

Sounds great. Do you just want to make this tweak by adding a sentence to the description of the cond arg in the doc and resolving the conflict with torch/overrides.py, and we'll merge this?

Edit: actually, the ROCm test failures might be real.

```
15:07:16 ERROR [9.679s]: test_fn_grad_linalg_solve_cuda_complex128 (__main__.TestGradientsCUDA)
```

Let's see if rebasing fixes them, or try to identify whether there is a real issue.
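On the cond default discussed above: pinning the cutoff explicitly looks like this (a sketch; note the merged API spells the argument `rcond`, and the value here is a placeholder):

```python
x = torch.linalg.lstsq(a, b, rcond=1e-15, driver='gelsd').solution
```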

@mruberry (Collaborator) commented Mar 2, 2021

FYI, I expect this will land today; internal tooling needed some help with it.

@facebook-github-bot (Contributor)

@mruberry merged this pull request in 3ac9013.

@mruberry (Collaborator) commented Mar 3, 2021

FYI, I'm going to revert this due to an internal build issue. No action needed; I just need to resolve it internally.

Update: an internal project had a build issue because they're consuming LAPACK functions from multiple dependencies. They are working on resolving the issue now and will validate that their fix allows us to land this change without breaking them. There's no action on our part at this time.

@mruberry mruberry reopened this Mar 3, 2021
@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry (Collaborator)

Arg. This hit a logical merge conflict with the update from "supports_tensor_out" -> "supports_out". I updated it, but the lint build is going to fail because we now check for trailing whitespace, and there's a ton of trailing whitespace in this PR's base. @nikitaved, would you please rebase this? The internal build appears to be OK now, so after a rebase we should be able to land.

@nikitaved nikitaved force-pushed the nikved/linalg_lstsq branch 3 times, most recently from b0ca309 to 4684f6d (March 12, 2021 09:39)
@nikitaved nikitaved force-pushed the nikved/linalg_lstsq branch from 4684f6d to 6064f79 (March 12, 2021 09:43)
@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
Summary:
Fixes pytorch#44378 by providing a wider range of drivers similar to what SciPy is doing.

The supported CPU drivers are `gels, gelsy, gelsd, gelss`.
The CUDA interface has only `gels` implemented but only for overdetermined systems.

The current state of this PR:
- [x] CPU interface
- [x] CUDA interface
- [x] CPU tests
- [x] CUDA tests
- [x] Memory-efficient batch-wise iteration with broadcasting which fixes pytorch#49252
- [x] docs

Pull Request resolved: pytorch#49093

Reviewed By: H-Huang

Differential Revision: D26723384

Pulled By: mruberry

fbshipit-source-id: c9866a95f14091955cf42de22f4ac9e2da009713
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021 (same summary and metadata as the commit above).
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021 (same summary as above).

Reviewed By: albanD

Differential Revision: D26991788

Pulled By: mruberry

fbshipit-source-id: 8af9ada979240b255402f55210c0af1cba6a0a3c

```cpp
check_if_copy_needed_for_a(a_curr_linear_batch_idx);

auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
```
(Collaborator)

Is it guaranteed that a_3d is contiguous? Otherwise, just getting data_ptr might not be safe.

In the comments it says that the input is expected to be "almost contiguous", but I can't find where that is enforced.

Btw, it might be marginally faster to create a TensorAccessor once instead of doing repeated .select() (https://pytorch.org/cppdocs/notes/tensor_basics.html#efficient-access-to-tensor-elements).

@nikitaved (Collaborator, Author) commented May 27, 2021

Have a look at the long comment right above: it mentions that a and b are expected to be "contiguous" (with respect to the batch dimensions) and in column-major order with respect to the last two dimensions, i.e. the output of a LAPACK routine is sufficient. No enforcement is done.
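As a side note, that "batched column-major" assumption can be checked from Python with a one-liner (a sketch; `is_batched_column_major` is illustrative):

```python
def is_batched_column_major(t):
    # True when the last two dims are in Fortran (column-major) order and
    # the batches are laid out contiguously, as LAPACK wrappers produce.
    return t.transpose(-2, -1).is_contiguous()
```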

@nikitaved (Collaborator, Author) commented May 27, 2021

A TensorAccessor will not work for CUDA tensors, will it? We would need to write a separate kernel for that, right?


Labels

cla signed · Merged · module: linear algebra · open source · triaged

Development

Successfully merging this pull request may close these issues:

Can't solve torch.lstsq() with specific values