Commit

fix style checks

f4str committed Nov 9, 2023
1 parent ddef573 commit baeec58
Showing 3 changed files with 12 additions and 8 deletions.
@@ -23,6 +23,8 @@
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 
+import numpy as np
+
 from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector
 
 if TYPE_CHECKING:
@@ -161,7 +163,7 @@ def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> List[Any
 
         return labels_translated
 
-    def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List[Dict[str, "torch.Tensor"]]:
+    def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List[Dict[str, np.ndarray]]:
         """
         Translate object detection predictions from the model format (DETR) to ART format (torchvision) and
         convert tensors to numpy arrays.
@@ -181,7 +183,7 @@ def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List
         pred_boxes = predictions["pred_boxes"]
         pred_logits = predictions["pred_logits"]
 
-        predictions_x1y1x2y2 = []
+        predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
 
         for pred_box, pred_logit in zip(pred_boxes, pred_logits):
            boxes = rescale_bboxes(pred_box.detach().cpu(), (height, width)).numpy()
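The return-type change above only tightens an annotation: judging by the docstring and the pred_boxes / pred_logits keys, this first file is the DETR-based detector, and its _translate_predictions already converts every tensor to a NumPy array, so List[Dict[str, np.ndarray]] now matches what the method actually returns. A minimal, hypothetical sketch of the torchvision-style ART prediction format implied by the annotation (the keys and values below are illustrative assumptions, not taken from the diff):

from typing import Dict, List

import numpy as np

# One dict per image; every value is a NumPy array, not a torch.Tensor.
predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = [
    {
        "boxes": np.array([[12.0, 30.5, 140.2, 210.9]], dtype=np.float32),  # (x1, y1, x2, y2)
        "labels": np.array([1], dtype=np.int64),                            # class indices
        "scores": np.array([0.87], dtype=np.float32),                       # confidences
    }
]

assert isinstance(predictions_x1y1x2y2[0]["boxes"], np.ndarray)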
art/estimators/object_detection/pytorch_object_detector.py (8 changes: 4 additions & 4 deletions)

@@ -120,8 +120,8 @@ def __init__(
         self._attack_losses = attack_losses
 
         # Parameters used for subclasses
-        self.weight_dict = None
-        self.criterion = None
+        self.weight_dict: Optional[Dict[str, float]] = None
+        self.criterion: Optional[torch.nn.Module] = None
 
         if self.clip_values is not None:
             if self.clip_values[0] != 0:
@@ -577,6 +577,6 @@ def compute_loss( # type: ignore
         )
 
         if isinstance(x, torch.Tensor):
-            return loss
+            return loss  # type: ignore
 
-        return loss.detach().cpu().numpy()
+        return loss.detach().cpu().numpy()  # type: ignore
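Here the attributes are assigned None in the base class and only populated by subclasses, so the explicit Optional[...] annotations are what the type checker needs: without them mypy infers the attribute type None from the initial assignment and rejects later assignments of real values. The trailing # type: ignore comments on the two return statements likewise just silence mypy on those lines without changing behaviour. A short, generic illustration of the Optional pattern (not ART code; the subclass values are made up):

from typing import Dict, Optional

import torch


class BaseDetector:
    def __init__(self) -> None:
        # Annotate with the eventual types even though the base class stores None.
        self.weight_dict: Optional[Dict[str, float]] = None
        self.criterion: Optional[torch.nn.Module] = None


class SubclassDetector(BaseDetector):
    def __init__(self) -> None:
        super().__init__()
        # Without the Optional annotations above, mypy would infer the attribute
        # type as None and flag these assignments as incompatible.
        self.weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0}
        self.criterion = torch.nn.CrossEntropyLoss()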
art/estimators/object_detection/pytorch_yolo.py (6 changes: 4 additions & 2 deletions)

@@ -23,6 +23,8 @@
 import logging
 from typing import List, Dict, Optional, Tuple, Union, TYPE_CHECKING
 
+import numpy as np
+
 from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector
 
 if TYPE_CHECKING:
@@ -142,7 +144,7 @@ def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> "torch.T
         labels_xcycwh = torch.vstack(labels_xcycwh_list)
         return labels_xcycwh
 
-    def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str, "torch.Tensor"]]:
+    def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str, np.ndarray]]:
         """
         Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and
         convert tensors to numpy arrays.
@@ -159,7 +161,7 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str,
         height = self.input_shape[0]
         width = self.input_shape[1]
 
-        predictions_x1y1x2y2 = []
+        predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
 
         for pred in predictions:
            boxes = torch.vstack(
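The YOLO estimator gets the same treatment: the numpy import plus annotations on _translate_predictions and predictions_x1y1x2y2. The variable name points at the usual conversion from YOLO's centre-based (x_c, y_c, w, h) boxes to corner-based (x1, y1, x2, y2) boxes before converting to NumPy; the torch.vstack(...) call that does this is truncated in the diff, so the helper below is only a hedged sketch of that conversion, not the repository code:

import torch


def xcycwh_to_x1y1x2y2(boxes_xcycwh: torch.Tensor) -> torch.Tensor:
    # Convert (x_center, y_center, width, height) rows to (x1, y1, x2, y2) rows.
    x_c, y_c, w, h = boxes_xcycwh.unbind(dim=-1)
    return torch.stack((x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2), dim=-1)


boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
print(xcycwh_to_x1y1x2y2(boxes))  # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])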
