Skip to content

Commit

Permalink
Merge pull request #2500 from Trusted-AI/development_issue_2473
Browse files Browse the repository at this point in the history
Apply package.version.parse
  • Loading branch information
beat-buesser authored Oct 1, 2024
2 parents 1207d0a + 934872a commit 8738a5a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import logging
import math
from packaging.version import parse
from typing import Any, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -121,8 +122,8 @@ def __init__(
import torch
import torchvision

torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torch_version = list(parse(torch.__version__.lower()).release)
torchvision_version = list(parse(torchvision.__version__.lower()).release)

Check warning on line 126 in art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

View check run for this annotation

Codecov / codecov/patch

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py#L125-L126

Added lines #L125 - L126 were not covered by tests
assert (
torch_version[0] >= 1 and torch_version[1] >= 7 or (torch_version[0] >= 2)
), "AdversarialPatchPyTorch requires torch>=1.7.0"
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/pixel_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import logging
from itertools import product
from packaging.version import parse
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -42,7 +43,7 @@
import scipy
from scipy._lib._util import check_random_state

scipy_version = list(map(int, scipy.__version__.lower().split(".")))
scipy_version = list(parse(scipy.__version__.lower()).release)
if scipy_version[1] >= 8:
from scipy.optimize._optimize import _status_message
else:
Expand Down
5 changes: 3 additions & 2 deletions art/estimators/object_detection/pytorch_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

import logging
from packaging.version import parse
from typing import Any, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -96,8 +97,8 @@ def __init__(
import torch
import torchvision

torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torch_version = list(parse(torch.__version__.lower()).release)
torchvision_version = list(parse(torchvision.__version__.lower()).release)
assert not (torch_version[0] == 1 and (torch_version[1] == 8 or torch_version[1] == 9)), (
"PyTorchObjectDetector does not support torch==1.8 and torch==1.9 because of "
"https://github.com/pytorch/vision/issues/4153. Support will return for torch==1.10."
Expand Down

0 comments on commit 8738a5a

Please sign in to comment.