-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathdataset.py
78 lines (69 loc) · 3.06 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import glob
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from skimage import io, transform
from PIL import Image
# Per-channel ImageNet statistics; ``denorm`` below inverts this mapping.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Default per-image pipeline: random 256x256 crop of a PIL image, conversion
# to a tensor, then channel-wise normalization.
trans = transforms.Compose(
    [
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        normalize,
    ]
)
def denorm(tensor, device=None,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert channel-wise normalization and clamp the result to ``[0, 1]``.

    Computes ``clamp(tensor * std + mean, 0, 1)``. ``mean`` and ``std`` are
    reshaped to ``(C, 1, 1)``, so they broadcast against the last three
    dimensions of ``tensor``. That covers a single ``(C, H, W)`` image and a
    batch ``(N, C, H, W)`` alike.

    Args:
        tensor: Normalized tensor whose channel axis is third from the end.
        device: Device on which the statistics are created. Defaults to
            ``tensor.device`` when omitted.
        mean: Per-channel means that were subtracted. The default is the
            same ImageNet values used by the module-level ``normalize``.
        std: Per-channel standard deviations that were divided out.

    Returns:
        The de-normalized tensor, clamped to ``[0, 1]``.
    """
    if device is None:
        device = tensor.device
    # Build the statistics directly on the target device, shaped (C, 1, 1)
    # so they broadcast over the spatial dimensions.
    std_t = torch.as_tensor(std, device=device).reshape(-1, 1, 1)
    mean_t = torch.as_tensor(mean, device=device).reshape(-1, 1, 1)
    return torch.clamp(tensor * std_t + mean_t, 0, 1)
class PreprocessDataset(Dataset):
    """Dataset of randomly paired (content, style) images.

    On construction, each of ``content_dir`` and ``style_dir`` gets a sibling
    directory ``<dir>_resized``. If that directory is missing, it is created
    and filled with resized copies of the source images: only 3-channel
    images are kept, and the shorter side is scaled to 512 px. An existing
    ``_resized`` directory is reused as-is.

    The content and style file lists are shuffled independently and zipped
    into pairs. The dataset length is therefore the smaller of the two image
    counts.

    Args:
        content_dir: Directory containing content images.
        style_dir: Directory containing style images.
        transforms: Callable applied to each PIL image in ``__getitem__``.
            Defaults to the module-level ``trans``. A falsy value returns the
            PIL images unchanged.
    """

    def __init__(self, content_dir, style_dir, transforms=trans):
        content_dir_resized = self._prepare_resized_dir(content_dir)
        style_dir_resized = self._prepare_resized_dir(style_dir)

        content_images = glob.glob(os.path.join(content_dir_resized, '*'))
        np.random.shuffle(content_images)
        style_images = glob.glob(os.path.join(style_dir_resized, '*'))
        np.random.shuffle(style_images)

        self.images_pairs = list(zip(content_images, style_images))
        self.transforms = transforms

    @classmethod
    def _prepare_resized_dir(cls, source_dir):
        """Return ``<source_dir>_resized``, creating and filling it if missing."""
        # normpath drops a trailing separator, so 'content/' maps to
        # 'content_resized' rather than 'content/_resized' (which would land
        # inside the source directory).
        target_dir = os.path.normpath(source_dir) + '_resized'
        # Each directory is handled independently. Creating both whenever
        # either was missing raised FileExistsError for the one that existed.
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
            cls._resize(source_dir, target_dir)
        return target_dir

    @staticmethod
    def _resize(source_dir, target_dir, size=512):
        """Write resized copies of the 3-channel images in ``source_dir``.

        The shorter side is scaled to ``size`` pixels while keeping the
        aspect ratio. Results are saved as uint8 under the same file names in
        ``target_dir``. Files that are not 3-channel images, or that fail to
        read, resize or save, are skipped, and a count is printed at the end.
        """
        print(f'Start resizing {source_dir} ')
        skipped = 0
        for filename in tqdm(os.listdir(source_dir)):
            try:
                image = io.imread(os.path.join(source_dir, filename))
                if len(image.shape) != 3 or image.shape[-1] != 3:
                    skipped += 1
                    continue
                height, width, _ = image.shape
                if height < width:
                    new_h, new_w = size, int(width / height * size)
                else:
                    new_h, new_w = int(height / width * size), size
                resized = transform.resize(image, (new_h, new_w), mode='reflect', anti_aliasing=True)
                # resize returns floats; with the default preserve_range=False
                # integer inputs come back scaled to [0, 1]. Convert to uint8
                # explicitly so imsave does not need to do a (warned or
                # rejected) lossy float conversion itself.
                resized = (np.clip(resized, 0.0, 1.0) * 255 + 0.5).astype(np.uint8)
                io.imsave(os.path.join(target_dir, filename), resized)
            except Exception:
                # Best-effort: skip unreadable or unsupported files but keep
                # going. Unlike a bare except, this lets KeyboardInterrupt
                # propagate.
                skipped += 1
        if skipped:
            print(f'Skipped {skipped} file(s) in {source_dir} that could not be resized')

    def __len__(self):
        return len(self.images_pairs)

    def __getitem__(self, index):
        """Return the (content, style) pair at ``index``, transformed if set."""
        content_path, style_path = self.images_pairs[index]
        # PIL (rather than skimage.io) is used because RandomCrop does not
        # work with skimage.io arrays. The ``with`` block plus convert()
        # forces a full decode, so the file handles are closed even when no
        # transform is applied.
        with Image.open(content_path) as img:
            content_image = img.convert('RGB')
        with Image.open(style_path) as img:
            style_image = img.convert('RGB')
        if self.transforms:
            content_image = self.transforms(content_image)
            style_image = self.transforms(style_image)
        return content_image, style_image