torch._C._LinAlgError #4

Open
fatginger1024 opened this issue Jul 19, 2023 · 2 comments

Comments

@fatginger1024
  • problem description
    torch._C._LinAlgError is raised after a few iterations
  • proposed solution
    replace torch.linalg.solve with torch.linalg.lstsq
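
For readers who want to reproduce the behaviour in isolation, here is a minimal sketch. It assumes the error comes from a singular system matrix, which the issue does not confirm; the 2x2 matrix below is a made-up example and is not taken from the repository.

```python
import torch

# Hypothetical singular matrix (second row and column are all zeros).
A = torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.float64)
B = torch.tensor([[1.0], [0.0]], dtype=torch.float64)

try:
    torch.linalg.solve(A, B)
    solve_failed = False
except RuntimeError as err:  # the LinAlgError raised here derives from RuntimeError
    solve_failed = True
    print(f"solve failed with {type(err).__name__}")

# lstsq returns a named tuple; the least-squares solution is in .solution
x = torch.linalg.lstsq(A, B).solution
residual = torch.linalg.norm(A @ x - B)
print(f"lstsq residual: {residual.item()}")
```

The point of the sketch is only that lstsq still returns a solution where solve gives up; whether that solution is appropriate for the model is a separate question.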
@cnexah
Owner

cnexah commented Jul 19, 2023

I haven't encountered this problem before. I think your solution is fine, but it might be slower.
Alternatively, you can try slightly increasing the value of jitter.
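
A rough sketch of the jitter suggestion, again on a made-up singular matrix: adding a small multiple of the identity to the diagonal makes the system solvable with torch.linalg.solve. The value 1e-6 is an arbitrary choice for illustration, not a recommended setting for this repository.

```python
import torch

# Hypothetical singular matrix; torch.linalg.solve(A, B) alone would fail on it.
A = torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.float64)
B = torch.tensor([[1.0], [0.0]], dtype=torch.float64)

jitter_scale = 1e-6  # assumed value, tune for your data and dtype
jitter = torch.eye(A.shape[-1], dtype=A.dtype) * jitter_scale

# The regularised matrix A + jitter is invertible, so solve succeeds.
x = torch.linalg.solve(A + jitter, B)
print(x)
```

A larger jitter makes the solve more robust but moves the result further from the solution of the original system, so it is worth increasing it gradually.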

@MarioBgle

MarioBgle commented Feb 21, 2024

Hi, I found a solution to this problem after spending some time on it. Let me know if I made a mistake.
torch.linalg.solve works in recent PyTorch versions, but you need to switch the arguments.
PyTorch removed the old torch.solve and torch.lstsq, which took the right-hand side as the first argument, because their behaviour differed from NumPy's; see this Stack Overflow post.
If you replicate the example from that Stack Overflow post with torch.linalg, you get the same result as in NumPy.

So, for newer PyTorch versions, this is the correct call:
x = torch.linalg.solve(ATA + jitter, ATB) # NOT torch.linalg.solve(ATB, ATA + jitter)
x = x.reshape(n, self.gr, h, w)
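
To check the argument order independently of the model, something like the following can be run. It is a sketch on random, well-conditioned batched matrices whose shapes are picked arbitrarily; torch.linalg.solve(A, B) returns X such that A @ X equals B.

```python
import torch

torch.manual_seed(0)

# Batch of 3 symmetric positive definite 4x4 matrices (made-up test data).
M = torch.randn(3, 4, 4, dtype=torch.float64)
A = M @ M.transpose(-1, -2) + torch.eye(4, dtype=torch.float64)
B = torch.randn(3, 4, 1, dtype=torch.float64)

# Coefficient matrix first, right-hand side second.
X = torch.linalg.solve(A, B)

print(X.shape)                    # same shape as B
print(torch.allclose(A @ X, B))   # checks that A @ X reproduces B
```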

I came to the same conclusion as @cnexah: torch.linalg.lstsq is simply slower.

  jitter = torch.eye(n=h * w, dtype=x.dtype, device=x.device).unsqueeze(0) * 1e-12  # Torch:(1, 1672, 1672)

  start_time_solve_linalg = time.time()
  x1 = torch.linalg.solve(ATA + jitter, ATB)
  x1 = x1.reshape(n, self.gr, h, w)  # Torch: (32, 1672, 1)
  end_time_solve_linalg = time.time()

  start_time_lstsq_linalg = time.time()
  x = torch.linalg.lstsq(ATA + jitter, ATB)
  x = x.solution.reshape(n, self.gr, h, w)  # x.solution is 32,1,1672
  end_time_lstsq_linalg = time.time()

  difference = torch.norm(x1 - x)
  print(f"Time taken for linalg.solve: {end_time_solve_linalg - start_time_solve_linalg}, time taken"
        f" for linalg.lstsq: {end_time_lstsq_linalg - start_time_lstsq_linalg}"
        f"the difference between both tensors is {difference}")

An example output is: Time taken for linalg.solve: 0.40329957008361816, time taken for linalg.lstsq: 0.8831644058227539the difference between both tensors is 0.05662015080451965
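
For anyone repeating this comparison outside the model, here is a self-contained sketch. The helper `timed` and the matrix sizes are made up for illustration. It synchronises CUDA around the timed call when a GPU is available, since GPU kernels run asynchronously and wall-clock timings can otherwise be misleading. The tensors here stay on the CPU, so move them to the GPU if you want to time that path.

```python
import time
import torch

def timed(fn, *args):
    """Hypothetical helper: run fn(*args) and return (result, elapsed seconds)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.perf_counter() - start

torch.manual_seed(0)
n, m = 8, 64  # arbitrary batch size and system size
A = torch.randn(n, m, m, dtype=torch.float64)
ATA = A.transpose(-1, -2) @ A + torch.eye(m, dtype=torch.float64)  # well conditioned
ATB = torch.randn(n, m, 1, dtype=torch.float64)

x_solve, t_solve = timed(torch.linalg.solve, ATA, ATB)
lstsq_out, t_lstsq = timed(torch.linalg.lstsq, ATA, ATB)
x_lstsq = lstsq_out.solution

rel_diff = torch.linalg.norm(x_solve - x_lstsq) / torch.linalg.norm(x_solve)
print(f"solve: {t_solve:.6f} s, lstsq: {t_lstsq:.6f} s, relative difference: {rel_diff.item():.3e}")
```

Timings from a single call are noisy, so averaging over several runs gives a fairer comparison.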

Hope this helps future people :)
