Skip to content

Commit

Permalink
ci: manual backport MPS fix for 1.2.x before release (#3116)
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis authored Dec 31, 2024
1 parent bd2262c commit bc141f3
Show file tree
Hide file tree
Showing 21 changed files with 350 additions and 109 deletions.
77 changes: 77 additions & 0 deletions .github/workflows/test_linux_mps.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
name: test (mps)

on:
push:
branches: [main, "[0-9]+.[0-9]+.x"]
pull_request:
branches: [main, "[0-9]+.[0-9]+.x"]
types: [labeled, synchronize, opened]
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
# if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered
if: >-
(
contains(github.event.pull_request.labels.*.name, 'mps') ||
contains(github.event.pull_request.labels.*.name, 'all tests') ||
contains(github.event_name, 'schedule') ||
contains(github.event_name, 'workflow_dispatch')
)
runs-on: [self-hosted, macOS, X64, MPS]

name: macos_integration

env:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python }}

steps:
#- name: Get the current branch name
# id: vars
# run: echo "BRANCH_NAME=$(echo $GITHUB_REF | awk -F'/' '{print $3}')" >> $GITHUB_ENV

- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"

- name: Create Conda environment and install dependencies
run: |
conda init bash
source ~/.bash_profile
conda activate scvi
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel uv
python -m pip install "scvi-tools[tests]"
python -m pip install jax-metal
python -m pip install coverage
python -m pip install pytest
- name: Run pytest
env:
MPLBACKEND: agg
PLATFORM: ${{ matrix.os }}
DISPLAY: :42
COLUMNS: 120
PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0
PYTORCH_ENABLE_MPS_FALLBACK: 1.0
run: |
coverage run -m pytest -v --color=yes --accelerator mps --devices auto
coverage report
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
3 changes: 3 additions & 0 deletions .github/workflows/test_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ jobs:
run: |
python -m pip install --upgrade pip wheel uv
python -m uv pip install --system "scvi-tools[tests] @ ."
python -m pip install jax-metal
- name: Run pytest
env:
MPLBACKEND: agg
PLATFORM: ${{ matrix.os }}
DISPLAY: :42
COLUMNS: 120
PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0
PYTORCH_ENABLE_MPS_FALLBACK: 1.0
run: |
coverage run -m pytest -v --color=yes
coverage report
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ to [Semantic Versioning]. Full commit history is available in the

## Version 1.2

### 1.2.2 (2024-XX-XX)
### 1.2.2 (2024-12-31)

#### Added

- Add MuData Minification option to {class}`~scvi.model.TOTALVI` {pr}`3061`.
- Add Support for MPS usage in mac {pr}`3100`.
- Add support for torch.compile before train (EXPERIMENTAL) {pr}`2931`.
- Add support for Numpy 2.0 {pr}`2842`.
- Changed scvi-hub ModelCard and add criticism metrics to the card {pr}`3078`.
- MuData support for {class}`~scvi.model.MULTIVI` via the method
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ["hatchling"]

[project]
name = "scvi-tools"
version = "1.2.1"
version = "1.2.2"
description = "Deep probabilistic analysis of single-cell omics data."
readme = "README.md"
requires-python = ">=3.10"
Expand Down Expand Up @@ -62,6 +62,7 @@ editing = ["jupyter", "pre-commit"]
dev = ["scvi-tools[editing,tests]"]
test = ["scvi-tools[tests]"]
cuda = ["torchvision","torchaudio","jax[cuda]"]
metal = ["torchvision","torchaudio","jax-metal"]

docs = [
"docutils>=0.8,!=0.18.*,!=0.19.*", # see https://github.com/scverse/cookiecutter-scverse/pull/205
Expand Down Expand Up @@ -111,7 +112,7 @@ tutorials = [
"squidpy",
]

all = ["scvi-tools[cuda,dev,docs,tutorials]"]
all = ["scvi-tools[dev,docs,tutorials]"]

[tool.hatch.build.targets.wheel]
packages = ['src/scvi']
Expand Down
19 changes: 13 additions & 6 deletions src/scvi/data/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ def poisson_gene_selection(

# Calculate empirical statistics.
sum_0 = np.asarray(data.sum(0)).ravel()
scaled_means = torch.from_numpy(sum_0 / sum_0.sum()).to(device)
total_counts = torch.from_numpy(np.asarray(data.sum(1)).ravel()).to(device)

# in MPS we need to first change to float 32, as the MPS framework doesn't support float64.
# We will thus do it by default for all accelerators
scaled_means = torch.from_numpy(np.float32(sum_0 / sum_0.sum())).to(device)
observed_fraction_zeros = torch.from_numpy(
np.asarray(1.0 - (data > 0).sum(0) / data.shape[0]).ravel()
np.float32(np.asarray(1.0 - (data > 0).sum(0) / data.shape[0]).ravel())
).to(device)

# Calculate probability of zero for a Poisson model.
Expand All @@ -151,8 +152,13 @@ def poisson_gene_selection(
expected_fraction_zeros /= data.shape[0]

# Compute probability of enriched zeros through sampling from Binomial distributions.
observed_zero = torch.distributions.Binomial(probs=observed_fraction_zeros)
expected_zero = torch.distributions.Binomial(probs=expected_fraction_zeros)
# TODO: TORCH MPS FIX - 'aten::binomial' is not currently implemented for the MPS device
if device.type == "mps":
observed_zero = torch.distributions.Binomial(probs=observed_fraction_zeros.to("cpu"))
expected_zero = torch.distributions.Binomial(probs=expected_fraction_zeros.to("cpu"))
else:
observed_zero = torch.distributions.Binomial(probs=observed_fraction_zeros)
expected_zero = torch.distributions.Binomial(probs=expected_fraction_zeros)

extra_zeros = torch.zeros(expected_fraction_zeros.shape, device=device)
for _ in track(
Expand All @@ -161,7 +167,8 @@ def poisson_gene_selection(
disable=silent,
style="tqdm", # do not change
):
extra_zeros += observed_zero.sample() > expected_zero.sample()
obs_exp_bool_mat = observed_zero.sample() > expected_zero.sample()
extra_zeros += obs_exp_bool_mat.to("mps") if device.type == "mps" else obs_exp_bool_mat

prob_zero_enrichment = (extra_zeros / n_samples).cpu().numpy()

Expand Down
Loading

0 comments on commit bc141f3

Please sign in to comment.