Skip to content

Commit

Permalink
fixed ValueError cause of negative strides
Browse files Browse the repository at this point in the history
  • Loading branch information
uwimt committed Dec 8, 2024
1 parent 7b2855d commit 4df273a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/mokka/mapping/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_constellation(self, *args):
"""
# Test bits
B = generate_all_bits(self.m.item()).copy()
bits = torch.from_numpy(B).to(self.weights.device)
bits = torch.from_numpy(B.copy()).to(self.weights.device)
logger.debug("bits device: %s", bits.device)
out = self.forward(bits)
return out
Expand Down Expand Up @@ -304,7 +304,7 @@ def get_constellation(self, *args):
mod_args = torch.tensor(args, dtype=torch.float32)
mod_args = mod_args.repeat(2 ** self.m.item(), 1).split(1, dim=-1)
B = generate_all_bits(self.m.item()).copy()
bits = torch.from_numpy(B).to(self.map1.weight.device)
bits = torch.from_numpy(B.copy()).to(self.map1.weight.device)
logger.debug("bits device: %s", bits.device)
out = self.forward(bits, *mod_args).flatten()
return out
Expand Down Expand Up @@ -413,7 +413,7 @@ def get_constellation(self, *args):
:returns: tensor of constellation points
"""
# Test bits
bits = torch.from_numpy(generate_all_bits(self.m.item())).to(
bits = torch.from_numpy(generate_all_bits(self.m.item()).copy()).to(
self.real_weights.device
)
logger.debug("bits device: %s", bits.device)
Expand Down

0 comments on commit 4df273a

Please sign in to comment.