From fd594e02d1a9f5b97bda9714732624a6a7ed6343 Mon Sep 17 00:00:00 2001 From: jicampos Date: Thu, 16 Jan 2025 14:17:11 -0600 Subject: [PATCH] Update Torch profiler (#1156) * updated pytorch weight profiler * fix type * [pre-commit.ci] auto fixes from pre-commit hooks * update comparison to false * fixed numerical condition for pytorch models and updates to type hints * Create test_pytorch_profiler.py * Update layer processing and add batchnorm testing * Remove typo --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jan-Frederik Schulte --- hls4ml/model/profiling.py | 115 +++++++++++++++++++++++---- test/pytest/test_pytorch_profiler.py | 85 ++++++++++++++++++++ 2 files changed, 184 insertions(+), 16 deletions(-) create mode 100644 test/pytest/test_pytorch_profiler.py diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 84a83de23e..f30088b51d 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -381,15 +381,87 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'): def weights_torch(model, fmt='longform', plot='boxplot'): - suffix = ['w', 'b'] - if fmt == 'longform': - data = {'x': [], 'layer': [], 'weight': []} - elif fmt == 'summary': - data = [] - for layer in model.children(): - if isinstance(layer, torch.nn.Linear): + wt = WeightsTorch(model, fmt, plot) + return wt.get_weights() + + +def _torch_batchnorm(layer): + weights = list(layer.parameters()) + epsilon = layer.eps + + gamma = weights[0] + beta = weights[1] + if layer.track_running_stats: + mean = layer.running_mean + var = layer.running_var + else: + mean = torch.tensor(np.ones(20)) + var = torch.tensor(np.zeros(20)) + + scale = gamma / np.sqrt(var + epsilon) + bias = beta - gamma * mean / np.sqrt(var + epsilon) + + return [scale, bias], ['s', 'b'] + + +def _torch_layer(layer): + return list(layer.parameters()), ['w', 'b'] + + +def _torch_rnn(layer): + return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0'] + + +torch_process_layer_map = defaultdict( + lambda: _torch_layer, + { + 'BatchNorm1d': _torch_batchnorm, + 'BatchNorm2d': _torch_batchnorm, + 'RNN': _torch_rnn, + 'LSTM': _torch_rnn, + 'GRU': _torch_rnn, + }, +) + + +class WeightsTorch: + def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None: + self.model = model + self.fmt = fmt + self.plot = plot + self.registered_layers = list() + self._find_layers(self.model, self.model.__class__.__name__) + + def _find_layers(self, model, module_name): + for name, module in model.named_children(): + if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)): + self._find_layers(module, module_name + "." + name) + elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): + if len(list(module.named_children())) != 0: + # custom nn.Module, continue search + self._find_layers(module, module_name + "." + name) + else: + self._register_layer(module_name + "." + name) + + def _is_registered(self, name: str) -> bool: + return name in self.registered_layers + + def _register_layer(self, name: str) -> None: + if self._is_registered(name) is False: + self.registered_layers.append(name) + + def _is_parameterized(self, module: torch.nn.Module) -> bool: + return any(p.requires_grad for p in module.parameters()) + + def _get_weights(self) -> pandas.DataFrame | list[dict]: + if self.fmt == 'longform': + data = {'x': [], 'layer': [], 'weight': []} + elif self.fmt == 'summary': + data = [] + for layer_name in self.registered_layers: + layer = self._get_layer(layer_name, self.model) name = layer.__class__.__name__ - weights = list(layer.parameters()) + weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer) for i, w in enumerate(weights): label = f'{name}/{suffix[i]}' w = weights[i].detach().numpy() @@ -399,18 +471,29 @@ def weights_torch(model, fmt='longform', plot='boxplot'): if n == 0: print(f'Weights for {name} are only zeros, ignoring.') break - if fmt == 'longform': + if self.fmt == 'longform': data['x'].extend(w.tolist()) data['layer'].extend([name] * n) data['weight'].extend([label] * n) - elif fmt == 'summary': - data.append(array_to_summary(w, fmt=plot)) + elif self.fmt == 'summary': + data.append(array_to_summary(w, fmt=self.plot)) data[-1]['layer'] = name data[-1]['weight'] = label - if fmt == 'longform': - data = pandas.DataFrame(data) - return data + if self.fmt == 'longform': + data = pandas.DataFrame(data) + return data + + def get_weights(self) -> pandas.DataFrame | list[dict]: + return self._get_weights() + + def get_layers(self) -> list[str]: + return self.registered_layers + + def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module: + for name in layer_name.split('.')[1:]: + module = getattr(module, name) + return module def activations_torch(model, X, fmt='longform', plot='boxplot'): @@ -484,11 +567,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): elif model_present: if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) - elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential): + elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module): data = weights_torch(model, fmt='summary', plot=plot) if data is None: - print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled") + print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled") if hls_model_present and os.path.exists(tmp_output_dir): shutil.rmtree(tmp_output_dir) diff --git a/test/pytest/test_pytorch_profiler.py b/test/pytest/test_pytorch_profiler.py new file mode 100644 index 0000000000..746bfc9455 --- /dev/null +++ b/test/pytest/test_pytorch_profiler.py @@ -0,0 +1,85 @@ +import pytest + +import hls4ml + +try: + import torch + import torch.nn as nn + + __torch_profiling_enabled__ = True +except ImportError: + __torch_profiling_enabled__ = False + + +class SubClassModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + for idx, layer in enumerate(layers): + setattr(self, f'layer_{idx}', layer) + + +class ModuleListModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + self.layer = torch.nn.ModuleList(layers) + + +class NestedSequentialModel(torch.nn.Module): + def __init__(self, layers) -> None: + super().__init__() + self.model = torch.nn.Sequential(*layers) + + +def count_bars_in_figure(fig): + count = 0 + for ax in fig.get_axes(): + count += len(ax.patches) + return count + + +# Reusable parameter list +test_layers = [ + (4, [nn.Linear(10, 20), nn.Linear(20, 5)]), + (3, [nn.Linear(10, 20), nn.BatchNorm1d(20)]), + (6, [nn.Linear(10, 20), nn.Linear(20, 5), nn.Conv1d(3, 16, kernel_size=3)]), + (6, [nn.Linear(15, 30), nn.Linear(30, 15), nn.Conv2d(1, 32, kernel_size=3)]), + (6, [nn.RNN(64, 128), nn.Linear(128, 10)]), + (6, [nn.LSTM(64, 128), nn.Linear(128, 10)]), + (6, [nn.GRU(64, 128), nn.Linear(128, 10)]), +] + + +@pytest.mark.parametrize("layers", test_layers) +def test_sequential_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = torch.nn.Sequential(*layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_subclass_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = SubClassModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_modulelist_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = ModuleListModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count + + +@pytest.mark.parametrize("layers", test_layers) +def test_nested_model(layers): + if __torch_profiling_enabled__: + param_count, layers = layers + model = NestedSequentialModel(layers) + wp, _, _, _ = hls4ml.model.profiling.numerical(model) + assert count_bars_in_figure(wp) == param_count