Skip to content

Commit

Permalink
Fix missing transfer to device in ProjectedGradientDescentPyTorch
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Jan 16, 2025
1 parent fd63952 commit d794c1f
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,10 @@ def _projection(
if (suboptimal or norm == 2) and norm != np.inf: # Simple rescaling
values_norm = torch.linalg.norm(values_tmp, ord=norm, dim=1, keepdim=True) # (n_samples, 1)
values_tmp = values_tmp * values_norm.where(
values_norm == 0, torch.minimum(torch.ones(1), torch.tensor(eps).to(values_tmp.device) / values_norm)
values_norm == 0,
torch.minimum(
torch.ones(1).to(values_tmp.device), torch.tensor(eps).to(values_tmp.device) / values_norm
),
)
else: # Optimal
if norm == np.inf: # Easy exact case
Expand Down

0 comments on commit d794c1f

Please sign in to comment.