-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDSS_Experiment.py
70 lines (57 loc) · 2.21 KB
/
DSS_Experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
'''
Author: Mingxin Zhang [email protected]
Date: 2023-07-04 01:27:58
LastEditors: Mingxin Zhang
LastEditTime: 2023-12-07 16:57:02
Copyright (c) 2023 by Mingxin Zhang, All Rights Reserved.
'''
import sys
import numpy as np
from TactileCAAE import model
import torch
import pickle
import sys
import torchaudio
import Methods
import UserInterface
from PyQt5.QtWidgets import QApplication
# Prefer the GPU, but fall back to the CPU so the experiment can still start
# (more slowly) on machines without CUDA instead of failing on the first
# tensor/model transfer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

# Latent dimensionality: passed to the generator as `feat_dim`, used to size
# the sampled latent vectors, and part of the checkpoint file name.
FEAT_DIM = 128
# NOTE(review): CLASS_NUM and SLIDER_LEN are not referenced in this file;
# presumably kept for other modules/experiments -- confirm before removing.
CLASS_NUM = 14
SLIDER_LEN = 30
def project_onto_column_space(basis, z):
    """Project the vector ``z`` onto the column space of ``basis``.

    ``z`` is mapped to low-dimensional coordinates with the Moore-Penrose
    pseudo-inverse and then lifted back through ``basis``. The result
    therefore lies in the span of ``basis``'s columns. Vectors already in
    that span come back unchanged.

    Args:
        basis: 2-D array. Presumably ``(FEAT_DIM, k)`` as returned by
            ``Methods.getRandomAMatrix``; TODO confirm.
        z: 1-D vector of length ``basis.shape[0]``.

    Returns:
        1-D array of the same length as ``z``.
    """
    # The transposes are no-ops for a 1-D ``z``; they are kept so the
    # computation matches the original script exactly.
    low_z = np.matmul(np.linalg.pinv(basis), z.T).T
    return np.matmul(basis, low_z)


def main():
    """Set up one DSS experiment trial and run the PyQt5 experiment window.

    Steps:
      1. Pick a random target spectrogram from the pickled test set.
      2. Load the pretrained TactileCAAE generator onto ``device``.
      3. Sample a random ``A`` matrix via ``Methods.getRandomAMatrix``.
      4. Build an initial latent restricted to the column space of ``A``.
      5. Hand everything to ``UserInterface.DSS_Experiment``.

    Never returns: the process exits with the Qt event loop's return code.
    """
    # Griffin-Lim reconstruction (power=1.0 -> magnitude spectrogram input).
    # Presumably used by the UI to turn spectrograms into audio; TODO confirm.
    griffinlim = torchaudio.transforms.GriffinLim(
        n_fft=2048, n_iter=50, hop_length=int(2048 * 0.1), power=1.0
    ).to(device)

    # NOTE: pickle.load executes arbitrary code from the file; this is only
    # acceptable because the test set is a trusted local artifact.
    with open('testset_7-class.pickle', 'rb') as file:
        testset = pickle.load(file)
    index = np.random.randint(len(testset['spectrogram']))
    target_spec = testset['spectrogram'][index]
    print(testset['filename'][index])

    # Model initialization and parameter loading.
    model_name = 'TactileCAAE'
    decoder = model.Generator(feat_dim=FEAT_DIM)
    decoder.eval()
    decoder.to(device)
    # Map the checkpoint onto the same device the model lives on (this was
    # hard-coded to 'cuda', which broke whenever `device` was not CUDA).
    state_dict = torch.load(f'{model_name}/generator_{FEAT_DIM}d.pt',
                            map_location=device)
    decoder.load_state_dict(state_dict)

    # Random latent used only to seed the A-matrix sampling. It is built
    # directly as a float32 (1, FEAT_DIM) row: the same values the old
    # numpy -> GPU tensor -> CPU -> numpy round-trip produced.
    target_latent = (np.random.uniform(-2.5, 2.5, FEAT_DIM)
                     .astype(np.float32)
                     .reshape(1, -1))
    # getRandomAMatrix signals failure by returning None; resample until it
    # succeeds. NOTE(review): this never terminates if it always fails.
    while True:
        random_A = Methods.getRandomAMatrix(FEAT_DIM, 6, target_latent, 1)
        if random_A is not None:
            break

    # Initial latent: a uniform random vector restricted to the span of A.
    init_z = np.random.uniform(low=-2.5, high=2.5, size=(FEAT_DIM))
    init_z = project_onto_column_space(random_A, init_z)

    app = QApplication(sys.argv)
    window = UserInterface.DSS_Experiment(griffinlim,
                                          target_spec,
                                          decoder,
                                          init_z)
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()