Skip to content

Commit

Permalink
✨🐍 update train script and adds trained models
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Oct 29, 2023
1 parent 0fcdb76 commit e1f6744
Show file tree
Hide file tree
Showing 11 changed files with 286 additions and 324 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "lenna_birds_plugin"
authors = ["Christian <[email protected]>"]
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "Plugin to classify birds on images."
license = "MIT"
Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@ The file target/release/liblenna_birds_plugin.so can be copied to the plugins fo

[lenna-cli](https://github.com/lenna-project/lenna-cli) and used in the pipeline.

### mobilenet

The plugin can be build with mobilenetv2.

```bash
cargo build --release --features mobilenet
```

### efficientnet

The plugin can be build with efficientnetb2.

```bash
cargo build --release --no-default-features --features efficientnet
```

## wasm and javascript version

The plugin can be compiled to wasm and used on [lenna.app](https://lenna.app).
Expand Down
1 change: 0 additions & 1 deletion assets/Birds-Classifier-EfficientNetB2.onnx

This file was deleted.

Binary file added assets/Birds-Classifier-EfficientNetB2.onnx
Binary file not shown.
Binary file not shown.
Binary file removed assets/birds_efficientnetb2.onnx
Binary file not shown.
24 changes: 4 additions & 20 deletions scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
```sh
virtualenv -p python3 .venv
source .venv/bin/activate
pip install tensorflow
pip install kaggle pandas pillow numpy tf2onnx
pip install tqdm tensorboard onnx
```

## optional prepare amd gpu

```sh
sudo apt install rocm-opencl
pip install tensorflow-rocm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6
```

## Download Data
Expand All @@ -32,21 +30,7 @@ python train.py
# Export Model

```sh
python export_onnx_model.py
cp birds_mobilenetv2.onnx ../assets/
cp checkpoints/Birds-Classifier-EfficientNetB2.onnx ../assets/
cp checkpoints/Birds-Classifier-MobileNetV2.onnx ../assets/
cp birds_labels.txt ../assets/
```

## Download EfficientNetB2

```sh
python download_efficientnet.py
```

# torch train version

```sh
pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install tqdm
python train_torch.py
```
80 changes: 80 additions & 0 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

MODEL_TYPE = 'MobileNetV2'
#MODEL_TYPE = 'EfficientNetB2'

# Define parameters
data_dir = './100-bird-species/'
img_height, img_width = 224, 224
num_labels = 525
checkpoint_dir = './checkpoints/'
checkpoint_file = f'{MODEL_TYPE}-checkpoint.pth'

# Load the validation dataset
data_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(img_height),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.47853944, 0.4732864, 0.47434163])
])

validation_dataset = datasets.ImageFolder(os.path.join(data_dir, 'valid'), data_transform)
validation_loader = DataLoader(validation_dataset, batch_size=5, shuffle=True)

# Define the model and load checkpoint
model = None

if MODEL_TYPE == 'MobileNetV2':
model = models.mobilenet_v2()
elif MODEL_TYPE == 'EfficientNetB2':
model = models.efficientnet_b2()

num_ftrs = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
else:
raise ValueError("Checkpoint file not found")

model.eval()

# Select 5 random images and predict
images, labels = next(iter(validation_loader))
with torch.no_grad():
outputs = model(images)
print(outputs)
_, preds = torch.max(outputs, 1)

# Function to convert image for plotting
def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.47853944, 0.4732864, 0.47434163])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(1.001)

# Plot the images with predictions and save to a file
fig = plt.figure(figsize=(15, 10))

for i in range(5):
ax = fig.add_subplot(1, 5, i + 1, xticks=[], yticks=[])
imshow(images[i])
ax.set_title(f"True: {validation_dataset.classes[labels[i]]}\nPred: {validation_dataset.classes[preds[i]]}", fontsize=10)

# Save the figure
plt.tight_layout()
plt.savefig('evaluation.jpg')
40 changes: 0 additions & 40 deletions scripts/export_onnx_model.py

This file was deleted.

Loading

0 comments on commit e1f6744

Please sign in to comment.