Skip to content

Commit

Permalink
fix cuda test file
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Dec 3, 2024
1 parent 5245163 commit 6f58e43
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_linux_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
run: |
python -m pip install --upgrade pip wheel uv
python -m uv pip install --system "scvi-tools[tests] @ ."
python -m pip install jax[cuda]
python -m pip install jax[cuda12]
python -m pip install nvidia-nccl-cu12
- name: Run pytest
Expand Down
4 changes: 2 additions & 2 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def test_scvi_num_workers():
model.get_normalized_expression(n_samples=2)


def test_scvi_train_ddp(save_path: str):
def test_scvi_train_ddp(save_path: str = "."):
training_code = """
import torch
import scvi
Expand All @@ -1312,7 +1312,7 @@ def test_scvi_train_ddp(save_path: str):
model = SCVI(adata)
model.train(
max_epochs=100,
max_epochs=1,
check_val_every_n_epoch=1,
accelerator="gpu",
devices=-1,
Expand Down

0 comments on commit 6f58e43

Please sign in to comment.