Update Torch profiler (#1156)
* updated pytorch weight profiler

* fix type

* [pre-commit.ci] auto fixes from pre-commit hooks

* update comparison to false

* fixed numerical condition for pytorch models and updates to type hints

* Create test_pytorch_profiler.py

* Update layer processing and add batchnorm testing

* Remove typo

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jan-Frederik Schulte <[email protected]>
3 people authored Jan 16, 2025
1 parent fb07e9c commit fd594e0
Showing 2 changed files with 184 additions and 16 deletions.
115 changes: 99 additions & 16 deletions hls4ml/model/profiling.py
@@ -381,15 +381,87 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):


 def weights_torch(model, fmt='longform', plot='boxplot'):
-    suffix = ['w', 'b']
-    if fmt == 'longform':
-        data = {'x': [], 'layer': [], 'weight': []}
-    elif fmt == 'summary':
-        data = []
-    for layer in model.children():
-        if isinstance(layer, torch.nn.Linear):
+    wt = WeightsTorch(model, fmt, plot)
+    return wt.get_weights()
+
+
+def _torch_batchnorm(layer):
+    weights = list(layer.parameters())
+    epsilon = layer.eps
+
+    gamma = weights[0]
+    beta = weights[1]
+    if layer.track_running_stats:
+        mean = layer.running_mean
+        var = layer.running_var
+    else:
+        mean = torch.tensor(np.ones(20))
+        var = torch.tensor(np.zeros(20))
+
+    scale = gamma / np.sqrt(var + epsilon)
+    bias = beta - gamma * mean / np.sqrt(var + epsilon)
+
+    return [scale, bias], ['s', 'b']
+
+
+def _torch_layer(layer):
+    return list(layer.parameters()), ['w', 'b']
+
+
+def _torch_rnn(layer):
+    return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0']
+
+
+torch_process_layer_map = defaultdict(
+    lambda: _torch_layer,
+    {
+        'BatchNorm1d': _torch_batchnorm,
+        'BatchNorm2d': _torch_batchnorm,
+        'RNN': _torch_rnn,
+        'LSTM': _torch_rnn,
+        'GRU': _torch_rnn,
+    },
+)
+
+
+class WeightsTorch:
+    def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
+        self.model = model
+        self.fmt = fmt
+        self.plot = plot
+        self.registered_layers = list()
+        self._find_layers(self.model, self.model.__class__.__name__)
+
+    def _find_layers(self, model, module_name):
+        for name, module in model.named_children():
+            if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
+                self._find_layers(module, module_name + "." + name)
+            elif isinstance(module, torch.nn.Module) and self._is_parameterized(module):
+                if len(list(module.named_children())) != 0:
+                    # custom nn.Module, continue the search in its children
+                    self._find_layers(module, module_name + "." + name)
+                else:
+                    self._register_layer(module_name + "." + name)
+
+    def _is_registered(self, name: str) -> bool:
+        return name in self.registered_layers
+
+    def _register_layer(self, name: str) -> None:
+        if not self._is_registered(name):
+            self.registered_layers.append(name)
+
+    def _is_parameterized(self, module: torch.nn.Module) -> bool:
+        return any(p.requires_grad for p in module.parameters())
+
+    def _get_weights(self) -> pandas.DataFrame | list[dict]:
+        if self.fmt == 'longform':
+            data = {'x': [], 'layer': [], 'weight': []}
+        elif self.fmt == 'summary':
+            data = []
+        for layer_name in self.registered_layers:
+            layer = self._get_layer(layer_name, self.model)
             name = layer.__class__.__name__
-            weights = list(layer.parameters())
+            weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
             for i, w in enumerate(weights):
                 label = f'{name}/{suffix[i]}'
                 w = weights[i].detach().numpy()
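
The BatchNorm handler above folds the normalization into a single affine pair, scale = gamma / sqrt(var + eps) and bias = beta - gamma * mean / sqrt(var + eps). A minimal sketch verifying that folding against PyTorch's own eval-mode output; the layer size and input are illustrative assumptions, not part of the commit:

import torch

# Fold a BatchNorm1d the same way _torch_batchnorm does (illustrative check).
bn = torch.nn.BatchNorm1d(4)
bn.eval()  # eval mode applies the running statistics that the fold uses

gamma, beta = bn.weight, bn.bias
mean, var = bn.running_mean, bn.running_var
scale = gamma / torch.sqrt(var + bn.eps)
bias = beta - gamma * mean / torch.sqrt(var + bn.eps)

x = torch.randn(8, 4)
assert torch.allclose(bn(x), x * scale + bias, atol=1e-6)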
@@ -399,18 +471,29 @@ def weights_torch(model, fmt='longform', plot='boxplot'):
                 if n == 0:
                     print(f'Weights for {name} are only zeros, ignoring.')
                     break
-                if fmt == 'longform':
+                if self.fmt == 'longform':
                     data['x'].extend(w.tolist())
                     data['layer'].extend([name] * n)
                     data['weight'].extend([label] * n)
-                elif fmt == 'summary':
-                    data.append(array_to_summary(w, fmt=plot))
+                elif self.fmt == 'summary':
+                    data.append(array_to_summary(w, fmt=self.plot))
                     data[-1]['layer'] = name
                     data[-1]['weight'] = label
 
-    if fmt == 'longform':
-        data = pandas.DataFrame(data)
-    return data
+        if self.fmt == 'longform':
+            data = pandas.DataFrame(data)
+        return data
+
+    def get_weights(self) -> pandas.DataFrame | list[dict]:
+        return self._get_weights()
+
+    def get_layers(self) -> list[str]:
+        return self.registered_layers
+
+    def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
+        for name in layer_name.split('.')[1:]:
+            module = getattr(module, name)
+        return module
 
 
 def activations_torch(model, X, fmt='longform', plot='boxplot'):
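
A hedged usage sketch of the refactored entry point; the two-layer model here is an illustrative assumption, not part of the commit:

import torch

from hls4ml.model.profiling import WeightsTorch, weights_torch

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.Linear(8, 4))

df = weights_torch(model, fmt='longform')  # pandas.DataFrame, one row per weight value
summary = weights_torch(model, fmt='summary', plot='boxplot')  # list of per-array stats

wt = WeightsTorch(model)
print(wt.get_layers())  # registered dotted names, e.g. ['Sequential.0', 'Sequential.1']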
@@ -484,11 +567,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
     elif model_present:
         if __tf_profiling_enabled__ and isinstance(model, keras.Model):
             data = weights_keras(model, fmt='summary', plot=plot)
-        elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
+        elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module):
            data = weights_torch(model, fmt='summary', plot=plot)
 
     if data is None:
-        print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled")
+        print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled")
 
     if hls_model_present and os.path.exists(tmp_output_dir):
         shutil.rmtree(tmp_output_dir)
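
With the isinstance check relaxed from torch.nn.Sequential to torch.nn.Module, subclassed models can now be passed to numerical() directly. A sketch under that assumption; TinyNet is illustrative, not from the commit:

import torch

import hls4ml

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(20, 5)

# numerical() returns four figures; the new tests below unpack the
# weight-profile figure as the first of them.
wp, _, _, _ = hls4ml.model.profiling.numerical(model=TinyNet())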
85 changes: 85 additions & 0 deletions test/pytest/test_pytorch_profiler.py
@@ -0,0 +1,85 @@
+import pytest
+
+import hls4ml
+
+try:
+    import torch
+    import torch.nn as nn
+
+    __torch_profiling_enabled__ = True
+except ImportError:
+    __torch_profiling_enabled__ = False
+
+
+class SubClassModel(torch.nn.Module):
+    def __init__(self, layers) -> None:
+        super().__init__()
+        for idx, layer in enumerate(layers):
+            setattr(self, f'layer_{idx}', layer)
+
+
+class ModuleListModel(torch.nn.Module):
+    def __init__(self, layers) -> None:
+        super().__init__()
+        self.layer = torch.nn.ModuleList(layers)
+
+
+class NestedSequentialModel(torch.nn.Module):
+    def __init__(self, layers) -> None:
+        super().__init__()
+        self.model = torch.nn.Sequential(*layers)
+
+
+def count_bars_in_figure(fig):
+    count = 0
+    for ax in fig.get_axes():
+        count += len(ax.patches)
+    return count
+
+
+# Reusable parameter list
+test_layers = [
+    (4, [nn.Linear(10, 20), nn.Linear(20, 5)]),
+    (3, [nn.Linear(10, 20), nn.BatchNorm1d(20)]),
+    (6, [nn.Linear(10, 20), nn.Linear(20, 5), nn.Conv1d(3, 16, kernel_size=3)]),
+    (6, [nn.Linear(15, 30), nn.Linear(30, 15), nn.Conv2d(1, 32, kernel_size=3)]),
+    (6, [nn.RNN(64, 128), nn.Linear(128, 10)]),
+    (6, [nn.LSTM(64, 128), nn.Linear(128, 10)]),
+    (6, [nn.GRU(64, 128), nn.Linear(128, 10)]),
+]
+
+
+@pytest.mark.parametrize("layers", test_layers)
+def test_sequential_model(layers):
+    if __torch_profiling_enabled__:
+        param_count, layers = layers
+        model = torch.nn.Sequential(*layers)
+        wp, _, _, _ = hls4ml.model.profiling.numerical(model)
+        assert count_bars_in_figure(wp) == param_count
+
+
+@pytest.mark.parametrize("layers", test_layers)
+def test_subclass_model(layers):
+    if __torch_profiling_enabled__:
+        param_count, layers = layers
+        model = SubClassModel(layers)
+        wp, _, _, _ = hls4ml.model.profiling.numerical(model)
+        assert count_bars_in_figure(wp) == param_count
+
+
+@pytest.mark.parametrize("layers", test_layers)
+def test_modulelist_model(layers):
+    if __torch_profiling_enabled__:
+        param_count, layers = layers
+        model = ModuleListModel(layers)
+        wp, _, _, _ = hls4ml.model.profiling.numerical(model)
+        assert count_bars_in_figure(wp) == param_count
+
+
+@pytest.mark.parametrize("layers", test_layers)
+def test_nested_model(layers):
+    if __torch_profiling_enabled__:
+        param_count, layers = layers
+        model = NestedSequentialModel(layers)
+        wp, _, _, _ = hls4ml.model.profiling.numerical(model)
+        assert count_bars_in_figure(wp) == param_count
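
The expected bar counts follow from one boxplot patch per nonzero weight array: a Linear contributes 'w' and 'b'; RNN, LSTM, and GRU contribute four arrays each; a freshly constructed BatchNorm1d contributes only its folded scale 's', because its zero-initialized beta and zero running mean make the folded bias all zeros, which the profiler skips. A hedged sanity check of the second parametrization, assuming default initializations:

import torch.nn as nn

import hls4ml

model = nn.Sequential(nn.Linear(10, 20), nn.BatchNorm1d(20))
wp, _, _, _ = hls4ml.model.profiling.numerical(model)
# 2 boxes for the Linear ('w' and 'b') + 1 for the BatchNorm1d ('s') = 3
assert sum(len(ax.patches) for ax in wp.get_axes()) == 3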
