forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinstanceNormalizationPlugin.cu
713 lines (621 loc) · 27.5 KB
/
instanceNormalizationPlugin.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/checkMacrosPlugin.h"
#include "instanceNormalizationPlugin.h"
#include "instanceNormCommon.h"
#include <algorithm>
#include <cstring>
#include <cuda_fp16.h>
#include <memory>
#include <mutex>
#include <stdexcept>
using namespace nvinfer1;
using namespace nvinfer1::plugin;
using namespace nvinfer1::pluginInternal;
using namespace instance_norm_impl;
using nvinfer1::plugin::InstanceNormalizationV3Plugin;
using nvinfer1::plugin::InstanceNormalizationV3PluginCreator;
namespace
{
// Plugin identity strings. Both InstanceNormalizationV3Plugin and InstanceNormalizationV3PluginCreator return these
// from getPluginName()/getPluginVersion(), so the two always report the same name and version.
constexpr char const* gInstancePluginVersion{"3"};
constexpr char const* gInstancePluginName{"InstanceNormalization_TRT"};
} // namespace
// Static storage for the creator's field description. The creator constructor fills mPluginAttributes and points
// mFC at it; getFieldNames() returns &mFC.
PluginFieldCollection InstanceNormalizationV3PluginCreator::mFC{};
std::vector<PluginField> InstanceNormalizationV3PluginCreator::mPluginAttributes;
// Construct from host-side per-channel scale and bias vectors (already FP32).
// The channel count is taken from scale.size(); scale and bias must have the same length.
// Only configuration is stored here; cuDNN/CUDA resources are created later by initializeContext().
InstanceNormalizationV3Plugin::InstanceNormalizationV3Plugin(
float epsilon, std::vector<float> const& scale, std::vector<float> const& bias, int32_t relu, float alpha)
: mEpsilon(epsilon)
, mAlpha(alpha)
, mRelu(relu)
, mNchan(scale.size())
, mHostScale(scale)
, mHostBias(bias)
{
// Every channel needs both a scale and a bias value.
PLUGIN_VALIDATE(scale.size() == bias.size());
}
// Construct from TensorRT Weights holding per-channel scale and bias in FP32 or FP16.
// The values are converted to FP32 and kept on the host. Device copies are made later by initializeContext().
// Throws (via PLUGIN_VALIDATE / PLUGIN_ERROR) when:
//   - the two counts differ;
//   - a count is negative;
//   - a non-empty buffer has no data;
//   - the dtype is neither kFLOAT nor kHALF.
InstanceNormalizationV3Plugin::InstanceNormalizationV3Plugin(
    float epsilon, nvinfer1::Weights const& scale, nvinfer1::Weights const& bias, int32_t relu, float alpha)
    : mEpsilon(epsilon)
    , mAlpha(alpha)
    , mRelu(relu)
    , mNchan(scale.count)
{
    PLUGIN_VALIDATE(scale.count == bias.count);
    // Convert one Weights buffer to host FP32 values.
    auto const copyWeights = [](nvinfer1::Weights const& input, std::vector<float>& output) {
        PLUGIN_VALIDATE(input.count >= 0);
        PLUGIN_VALIDATE(input.count == 0 || input.values != nullptr);
        output.reserve(static_cast<size_t>(input.count));
        if (input.type == nvinfer1::DataType::kFLOAT)
        {
            auto const* values = static_cast<float const*>(input.values);
            output.assign(values, values + input.count);
        }
        else if (input.type == nvinfer1::DataType::kHALF)
        {
            // Raw FP16 bit patterns, converted element by element.
            // The index is int64_t to match Weights::count, avoiding truncation for large buffers.
            auto const* values = static_cast<unsigned short const*>(input.values);
            for (int64_t c = 0; c < input.count; ++c)
            {
                output.push_back(__internal_half2float(values[c]));
            }
        }
        else
        {
            PLUGIN_ERROR("Unsupported scale/bias dtype");
        }
    };
    copyWeights(scale, mHostScale);
    copyWeights(bias, mHostBias);
}
// Releases the cuDNN handle/descriptors and device buffers acquired by initializeContext(), if any.
// exitContext() uses the PLUGIN_CUDNNASSERT / PLUGIN_CUASSERT check macros. NOTE(review): these are presumed to
// throw on failure. An exception escaping a destructor calls std::terminate, so it is reported via caughtError()
// instead.
InstanceNormalizationV3Plugin::~InstanceNormalizationV3Plugin()
{
    try
    {
        exitContext();
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
// The plugin has exactly one output tensor. Its shape and type are copied from the single input
// (see getOutputShapes() / getOutputDataTypes()).
int32_t InstanceNormalizationV3Plugin::getNbOutputs() const noexcept
{
    int32_t constexpr kNB_OUTPUTS{1};
    return kNB_OUTPUTS;
}
// Expose the build, runtime or core capability view of this object.
// Any type other than kBUILD/kRUNTIME is asserted to be kCORE. Exceptions are reported via caughtError(), and
// nullptr is returned in that case.
IPluginCapability* InstanceNormalizationV3Plugin::getCapabilityInterface(PluginCapabilityType type) noexcept
{
    try
    {
        switch (type)
        {
        case PluginCapabilityType::kBUILD: return static_cast<IPluginV3OneBuild*>(this);
        case PluginCapabilityType::kRUNTIME: return static_cast<IPluginV3OneRuntime*>(this);
        default:
            PLUGIN_ASSERT(type == PluginCapabilityType::kCORE);
            return static_cast<IPluginV3OneCore*>(this);
        }
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
int32_t InstanceNormalizationV3Plugin::initializeContext()
{
if (!mInitialized)
{
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnCreate(&mCudnnHandle));
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnCreateTensorDescriptor(&mBDescriptor));
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnCreateTensorDescriptor(&mXDescriptor));
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnCreateTensorDescriptor(&mYDescriptor));
// NDHWC path
// Device info.
int32_t device;
PLUGIN_CUASSERT(cudaGetDevice(&device));
cudaDeviceProp props;
PLUGIN_CUASSERT(cudaGetDeviceProperties(&props, device));
mContext.sm_count = props.multiProcessorCount;
mContext.sm_shared_size = props.sharedMemPerMultiprocessor;
mContext.sm_version = props.major * 100 + props.minor * 10;
PLUGIN_CUASSERT(cudaMalloc(&mDeviceScale, mNchan * sizeof(float)));
PLUGIN_ASSERT(mDeviceScale != nullptr);
PLUGIN_CUASSERT(cudaMalloc(&mDeviceBias, mNchan * sizeof(float)));
PLUGIN_ASSERT(mDeviceBias != nullptr);
PLUGIN_CUASSERT(cudaMemcpy(mDeviceScale, &mHostScale[0], mNchan * sizeof(float), cudaMemcpyHostToDevice));
PLUGIN_CUASSERT(cudaMemcpy(mDeviceBias, &mHostBias[0], mNchan * sizeof(float), cudaMemcpyHostToDevice));
PLUGIN_CUASSERT(cudaDriverGetVersion(&mCudaDriverVersion));
}
mInitialized = true;
return 0;
}
// Release what initializeContext() acquired: the cuDNN tensor descriptors and handle, then the device
// scale/bias buffers. A no-op when nothing is initialized.
void InstanceNormalizationV3Plugin::exitContext()
{
    if (!mInitialized)
    {
        return;
    }
    PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnDestroyTensorDescriptor(mYDescriptor));
    PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnDestroyTensorDescriptor(mXDescriptor));
    PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnDestroyTensorDescriptor(mBDescriptor));
    PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnDestroy(mCudnnHandle));
    PLUGIN_CUASSERT(cudaFree(mDeviceBias));
    PLUGIN_CUASSERT(cudaFree(mDeviceScale));
    mInitialized = false;
}
// Bytes of scratch memory enqueue() needs for the given input/output descriptors.
//  - Linear format (cuDNN path): scale and bias replicated per batch entry, i.e. 2 * N * C floats.
//  - DHWC8 / CDHW32 (5-D vectorized kernels): reduction buffers sized by instanceNormBufferSizesDispatch(), plus
//    four N*C float arrays (running/saved mean and variance), each padded to a multiple of 256 bytes.
// Any other format is a programming error (asserted).
size_t InstanceNormalizationV3Plugin::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
    DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
    nvinfer1::Dims const& inputDims = inputs[0].desc.dims;
    // 3-D (N, C, W) inputs are accepted by supportsFormatCombination() and handled by the cuDNN path of enqueue(),
    // so they must be accepted here too. Previously only 4-D/5-D passed this check.
    PLUGIN_ASSERT(inputDims.nbDims >= 3 && inputDims.nbDims <= 5);
    auto const format = inputs[0].desc.format;
    if (format == nvinfer1::PluginFormat::kLINEAR)
    {
        // Computed in size_t so N * C cannot overflow 32-bit arithmetic.
        size_t const nc = static_cast<size_t>(inputDims.d[0]) * static_cast<size_t>(inputDims.d[1]);
        size_t const scaleSize = nc * sizeof(float);
        size_t const biasSize = nc * sizeof(float);
        return scaleSize + biasSize;
    }
    if (format == nvinfer1::PluginFormat::kDHWC8 || format == nvinfer1::PluginFormat::kCDHW32)
    {
        PLUGIN_ASSERT(inputDims.nbDims == 5);
        int32_t const inputDataType = (inputs[0].desc.type == nvinfer1::DataType::kHALF) ? 1 : 2;
        int32_t const outputDataType = (outputs[0].desc.type == nvinfer1::DataType::kHALF) ? 1 : 2;
        int32_t const n = inputDims.d[0];
        int32_t const c = inputDims.d[1];
        int32_t const d = inputDims.d[2];
        int32_t const h = inputDims.d[3];
        int32_t const w = inputDims.d[4];
        InstanceNormFwdParams params{};
        // Only these parameters are required for workspace computation.
        params.nhw = d * h * w;
        params.c = c;
        params.n = n;
        size_t sizeSums{};
        size_t sizeCounts{};
        size_t sizeRetiredCtas{};
        instanceNormBufferSizesDispatch(
            mContext, params, sizeSums, sizeCounts, sizeRetiredCtas, inputDataType, outputDataType);
        size_t sizeNc = static_cast<size_t>(n) * static_cast<size_t>(c) * sizeof(float);
        sizeNc = ((sizeNc + 256 - 1) / 256) * 256;
        return sizeSums + sizeCounts + sizeRetiredCtas + 4 * sizeNc;
    }
    PLUGIN_ASSERT(0);
    return 0;
}
int32_t InstanceNormalizationV3Plugin::enqueue(PluginTensorDesc const* inputDesc,
PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr && workspace != nullptr);
nvinfer1::Dims input_dims = inputDesc[0].dims;
// early return for empty tensor
if (std::any_of(input_dims.d, input_dims.d + input_dims.nbDims, [](int32_t d) { return d == 0; }))
{
return 0;
}
auto const callRelu = [this, &stream](void* inOut, int32_t count, nvinfer1::DataType type) {
if (mRelu > 0)
{
int32_t constexpr kBLOCK_SZ = 256;
switch (type)
{
case nvinfer1::DataType::kFLOAT:
in3dReluActivation<float, kBLOCK_SZ><<<divUp(count, kBLOCK_SZ), kBLOCK_SZ, 0, stream>>>(
static_cast<float*>(inOut), static_cast<float*>(inOut), mAlpha, count);
break;
case nvinfer1::DataType::kHALF:
in3dReluActivation<__half, kBLOCK_SZ><<<divUp(count, kBLOCK_SZ), kBLOCK_SZ, 0, stream>>>(
static_cast<__half*>(inOut), static_cast<__half*>(inOut), mAlpha, count);
break;
default: PLUGIN_ASSERT(0);
}
}
};
if (input_dims.nbDims <= 4)
{
nvinfer1::Dims input_dims = inputDesc[0].dims;
int32_t n = input_dims.d[0];
int32_t c = input_dims.d[1];
int32_t h = input_dims.d[2];
int32_t w = input_dims.nbDims > 3 ? input_dims.d[3] : 1;
size_t nchan_bytes = c * sizeof(float);
float* _d_array = static_cast<float*>(workspace);
float* d_scale = &_d_array[0];
float* d_bias = &_d_array[n * c];
for (int32_t i = 0; i < n; ++i)
{
PLUGIN_CUASSERT(
cudaMemcpyAsync(d_scale + i * c, mDeviceScale, nchan_bytes, cudaMemcpyDeviceToDevice, stream));
PLUGIN_CUASSERT(
cudaMemcpyAsync(d_bias + i * c, mDeviceBias, nchan_bytes, cudaMemcpyDeviceToDevice, stream));
}
PLUGIN_CUDNNASSERT(
mCudnnWrapper.cudnnSetTensor4dDescriptor(mBDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1));
cudnnDataType_t cudnn_dtype{};
PLUGIN_CUDNNASSERT(convertTrt2cudnnDtype(inputDesc[0].type, &cudnn_dtype));
PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnSetTensor4dDescriptor(mXDescriptor, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w));
PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnSetTensor4dDescriptor(mYDescriptor, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w));
float alpha = 1;
float beta = 0;
void const* x_ptr = inputs[0];
void* y_ptr = outputs[0];
PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnSetStream(mCudnnHandle, stream));
// Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
// overflows (NaNs) for fp32 data in some circumstances. The lower-
// performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
// acceptable.
cudnnBatchNormMode_t cudnnBatchNormMode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
cudaStreamCaptureStatus streamStatus;
PLUGIN_CUASSERT(cudaStreamIsCapturing(stream, &streamStatus));
if (streamStatus != cudaStreamCaptureStatusNone && mCudaDriverVersion < 11000)
{
gLogVerbose << "Using CUDNN_BATCHNORM_SPATIAL as a CUDA graph capture is in progress but the CUDA version "
"may have issues with using CUDNN_BATCHNORM_SPATIAL_PERSISTENT"
<< std::endl;
cudnnBatchNormMode = CUDNN_BATCHNORM_SPATIAL;
}
PLUGIN_CUDNNASSERT(mCudnnWrapper.cudnnBatchNormalizationForwardTraining(mCudnnHandle, cudnnBatchNormMode,
&alpha, &beta, mXDescriptor, x_ptr, mYDescriptor, y_ptr, mBDescriptor, d_scale, d_bias, 1., nullptr,
nullptr, mEpsilon, nullptr, nullptr));
callRelu(y_ptr, n * c * h * w, inputDesc[0].type);
}
else
{
if (inputDesc[0].format == nvinfer1::PluginFormat::kLINEAR)
{
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnSetStream(mCudnnHandle, stream));
nvinfer1::Dims input_dims = inputDesc[0].dims;
int32_t n = input_dims.d[0];
int32_t c = input_dims.d[1];
int32_t d = input_dims.d[2];
int32_t h = input_dims.d[3];
int32_t w = input_dims.d[4];
size_t nchan_bytes = c * sizeof(float);
// Note: We repeat the data for each batch entry so that we can do the full
// computation in a single CUDNN call in enqueue().
float* _d_array = (float*) workspace;
float* d_scale = &_d_array[0];
float* d_bias = &_d_array[n * c];
for (int32_t i = 0; i < n; ++i)
{
PLUGIN_CUASSERT(
cudaMemcpyAsync(d_scale + i * c, mDeviceScale, nchan_bytes, cudaMemcpyDeviceToDevice, stream));
PLUGIN_CUASSERT(
cudaMemcpyAsync(d_bias + i * c, mDeviceBias, nchan_bytes, cudaMemcpyDeviceToDevice, stream));
}
int32_t nc_dimA[] = {1, n * c, 1, 1, 1};
int32_t nc_strideA[] = {nc_dimA[1] * nc_dimA[2] * nc_dimA[3] * nc_dimA[4],
nc_dimA[2] * nc_dimA[3] * nc_dimA[4], nc_dimA[3] * nc_dimA[4], nc_dimA[4], 1};
int32_t img_dimA[] = {1, n * c, d, h, w};
int32_t img_strideA[] = {img_dimA[1] * img_dimA[2] * img_dimA[3] * img_dimA[4],
img_dimA[2] * img_dimA[3] * img_dimA[4], img_dimA[3] * img_dimA[4], img_dimA[4], 1};
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnSetTensorNdDescriptor(mBDescriptor, CUDNN_DATA_FLOAT, 5, nc_dimA, nc_strideA));
cudnnDataType_t cudnn_dtype;
PLUGIN_CHECK_CUDNN(convertTrt2cudnnDtype(inputDesc[0].type, &cudnn_dtype));
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnSetTensorNdDescriptor(mXDescriptor, cudnn_dtype, 5, img_dimA, img_strideA));
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnSetTensorNdDescriptor(mYDescriptor, cudnn_dtype, 5, img_dimA, img_strideA));
float alpha = 1;
float beta = 0;
void const* x_ptr = inputs[0];
void* y_ptr = outputs[0];
// Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
// overflows (NaNs) for fp32 data in some circumstances. The lower-
// performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
// acceptable.
cudnnBatchNormMode_t cudnnBatchNormMode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
cudaStreamCaptureStatus streamStatus;
PLUGIN_CUASSERT(cudaStreamIsCapturing(stream, &streamStatus));
if (streamStatus != cudaStreamCaptureStatusNone && mCudaDriverVersion < 11000)
{
gLogVerbose
<< "Using CUDNN_BATCHNORM_SPATIAL as a CUDA graph capture is in progress but the CUDA version "
"may have issues with using CUDNN_BATCHNORM_SPATIAL_PERSISTENT"
<< std::endl;
cudnnBatchNormMode = CUDNN_BATCHNORM_SPATIAL;
}
PLUGIN_CHECK_CUDNN(mCudnnWrapper.cudnnBatchNormalizationForwardTraining(mCudnnHandle, cudnnBatchNormMode,
&alpha, &beta, mXDescriptor, x_ptr, mYDescriptor, y_ptr, mBDescriptor, d_scale, d_bias, 1., nullptr,
nullptr, mEpsilon, nullptr, nullptr));
callRelu(y_ptr, n * c * d * h * w, inputDesc[0].type);
}
else if (inputDesc[0].format == nvinfer1::PluginFormat::kDHWC8
|| inputDesc[0].format == nvinfer1::PluginFormat::kCDHW32)
{
int32_t input_data_type = (inputDesc[0].type == nvinfer1::DataType::kHALF) ? 1 : 2;
int32_t output_data_type = (outputDesc[0].type == nvinfer1::DataType::kHALF) ? 1 : 2;
nvinfer1::Dims input_dims = inputDesc[0].dims;
int32_t n = input_dims.d[0];
int32_t c = input_dims.d[1];
int32_t d = input_dims.d[2];
int32_t h = input_dims.d[3];
int32_t w = input_dims.d[4];
InstanceNormFwdParams params{};
params.nhw = d * h * w;
params.c = c;
params.n = n;
size_t size_sums, size_counts, size_retired_ctas;
instanceNormBufferSizesDispatch(
mContext, params, size_sums, size_counts, size_retired_ctas, input_data_type, output_data_type);
size_t size_nc = n * c * sizeof(float);
size_nc = ((size_nc + 256 - 1) / 256) * 256;
char* d_buf = static_cast<char*>(workspace);
params.gmem_sums = reinterpret_cast<GMEM_SUMS_TYPE*>(d_buf);
d_buf += size_sums;
params.gmem_counts = reinterpret_cast<int32_t*>(d_buf);
d_buf += size_counts;
params.gmem_retired_ctas = reinterpret_cast<int32_t*>(d_buf);
d_buf += size_retired_ctas;
params.gmem_running_mean = reinterpret_cast<float*>(d_buf);
d_buf += size_nc;
params.gmem_running_var = reinterpret_cast<float*>(d_buf);
d_buf += size_nc;
params.gmem_saved_mean = reinterpret_cast<float*>(d_buf);
d_buf += size_nc;
params.gmem_saved_var = reinterpret_cast<float*>(d_buf);
d_buf += size_nc;
params.gmem_src = inputs[0];
params.gmem_dst = outputs[0];
params.gmem_bias = mDeviceBias;
params.gmem_scale = mDeviceScale;
params.var_eps = mEpsilon;
params.exp_avg_factor = 1.F; //(float)exp_avg_factor;
params.use_relu = mRelu; // use_relu;
params.relu_alpha = mAlpha; // relu_alpha;
params.in_scale = inputDesc[0].scale;
PLUGIN_ASSERT(outputDesc[0].scale != 0.F);
params.out_scale = 1.F / outputDesc[0].scale;
instanceNormFwdDispatch(mContext, params, stream, input_data_type, output_data_type);
}
else
{
PLUGIN_FAIL("Unexpected input format");
}
}
return 0;
}
bool InstanceNormalizationV3Plugin::supportsFormatCombination(
int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
PLUGIN_ASSERT(inOut && pos < (nbInputs + nbOutputs));
PLUGIN_ASSERT(pos == 0 || pos == 1);
// For 4-D or 3-D tensor (nbSpatialDims == 1 or 2), only FP32_Linear and FP16_Linear are supported.
// For 5-D tensor (nbSpatialDims == 3), FP32_Linear, FP16_Linear, FP16_DHWC8, and INT8_CDHW32 are supported.
// This is because we have special InstanceNorm3D kernels for vectorized formats from MLPerf-Inference.
int32_t const nbDims = inOut[pos].desc.dims.nbDims;
PLUGIN_ASSERT(nbDims >= 3);
PLUGIN_ASSERT(nbDims <= 5);
bool const is3DInstanceNorm = (nbDims == 5);
bool const isFP32Linear
= (inOut[pos].desc.type == nvinfer1::DataType::kFLOAT && inOut[pos].desc.format == nvinfer1::PluginFormat::kLINEAR
&& inOut[pos].desc.type == inOut[0].desc.type && inOut[pos].desc.format == inOut[0].desc.format);
bool const isFP16Linear
= (inOut[pos].desc.type == nvinfer1::DataType::kHALF && inOut[pos].desc.format == nvinfer1::PluginFormat::kLINEAR
&& inOut[pos].desc.type == inOut[0].desc.type && inOut[pos].desc.format == inOut[0].desc.format);
bool const isFP16DHWC8
= (inOut[pos].desc.type == nvinfer1::DataType::kHALF && inOut[pos].desc.format == nvinfer1::PluginFormat::kDHWC8
&& inOut[pos].desc.type == inOut[0].desc.type && inOut[pos].desc.format == inOut[0].desc.format);
bool const isINT8CDHW32
= (inOut[pos].desc.type == nvinfer1::DataType::kINT8 && inOut[pos].desc.format == nvinfer1::PluginFormat::kCDHW32
&& inOut[pos].desc.type == inOut[0].desc.type && inOut[pos].desc.format == inOut[0].desc.format);
bool const isFormatOK = isFP32Linear || isFP16Linear || (is3DInstanceNorm && (isFP16DHWC8 || isINT8CDHW32));
// Kernels for vectorized formats only support the case of C % spv == 0.
int32_t spv{1};
switch (inOut[pos].desc.format)
{
case nvinfer1::PluginFormat::kDHWC8: spv = 8; break;
case nvinfer1::PluginFormat::kCDHW32: spv = 32; break;
default: break;
}
int32_t const isAlignmentOK = (inOut[pos].desc.dims.d[1] % spv == 0);
return isFormatOK && isAlignmentOK;
}
// Registered plugin name; the creator reports the same string.
char const* InstanceNormalizationV3Plugin::getPluginName() const noexcept
{
    char const* const pluginName = gInstancePluginName;
    return pluginName;
}
// Registered plugin version; the creator reports the same string.
char const* InstanceNormalizationV3Plugin::getPluginVersion() const noexcept
{
    char const* const pluginVersion = gInstancePluginVersion;
    return pluginVersion;
}
// Namespace most recently set through setPluginNamespace(). The pointer refers to this object's internal string.
char const* InstanceNormalizationV3Plugin::getPluginNamespace() const noexcept
{
    return mPluginNamespace.data();
}
// Create a new plugin with the same epsilon, scale, bias, relu, alpha and namespace.
// Only configuration is copied; attachToContext() calls initializeContext() on the copy to create its resources.
// Returns nullptr (after reporting via caughtError()) if construction throws.
InstanceNormalizationV3Plugin* InstanceNormalizationV3Plugin::clone() noexcept
{
    InstanceNormalizationV3Plugin* copy{nullptr};
    try
    {
        copy = new InstanceNormalizationV3Plugin{mEpsilon, mHostScale, mHostBias, mRelu, mAlpha};
        copy->setPluginNamespace(mPluginNamespace.c_str());
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return copy;
}
// Set plugin namespace
void InstanceNormalizationV3Plugin::setPluginNamespace(char const* pluginNamespace) noexcept
{
try
{
PLUGIN_ASSERT(pluginNamespace != nullptr);
mPluginNamespace = pluginNamespace;
}
catch (std::exception const& e)
{
caughtError(e);
}
}
// The single output has the same data type as the single input.
// Both arrays must be non-null; previously outputTypes was written without a check.
int32_t InstanceNormalizationV3Plugin::getOutputDataTypes(
    DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
    PLUGIN_ASSERT(inputTypes != nullptr);
    PLUGIN_ASSERT(outputTypes != nullptr);
    PLUGIN_ASSERT(nbInputs == 1);
    PLUGIN_ASSERT(nbOutputs == 1);
    outputTypes[0] = inputTypes[0];
    return 0;
}
// The single output has exactly the input's (symbolic) shape.
// Both arrays must be non-null; previously outputs was written without a check.
int32_t InstanceNormalizationV3Plugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
    DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
    IExprBuilder& exprBuilder) noexcept
{
    PLUGIN_ASSERT(inputs != nullptr);
    PLUGIN_ASSERT(outputs != nullptr);
    PLUGIN_ASSERT(nbInputs == 1);
    PLUGIN_ASSERT(nbOutputs == 1);
    outputs[0] = inputs[0];
    return 0;
}
// Attach the plugin object to an execution context and grant the plugin access to some context resources.
// Returns a clone whose per-context resources have been set up by initializeContext(). Returns nullptr
// (after reporting via caughtError()) if cloning fails or initialization fails or throws; the partially built clone
// is freed in that case.
IPluginV3* InstanceNormalizationV3Plugin::attachToContext(IPluginResourceContext* context) noexcept
{
    try
    {
        std::unique_ptr<InstanceNormalizationV3Plugin> obj{clone()};
        PLUGIN_VALIDATE(obj != nullptr);
        // initializeContext() returns 0 on success. PLUGIN_CHECK_CUDNN presumably returns a non-zero cuDNN status
        // early on failure.
        PLUGIN_VALIDATE(obj->initializeContext() == 0);
        return obj.release();
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
// Nothing is cached at configuration time. getWorkspaceSize() and enqueue() read dimensions, types and formats
// directly from the descriptors they receive.
int32_t InstanceNormalizationV3Plugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
    int32_t const status = STATUS_SUCCESS;
    return status;
}
// Validate the runtime tensor descriptors. Requirements:
//   - non-null descriptor arrays;
//   - exactly one input and one output;
//   - a known (non -1) channel dimension.
// No state is updated.
int32_t InstanceNormalizationV3Plugin::onShapeChange(PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
PLUGIN_ASSERT(in != nullptr);
PLUGIN_ASSERT(out != nullptr);
PLUGIN_ASSERT(nbOutputs == 1);
PLUGIN_ASSERT(nbInputs == 1);
// Dynamic shape is not supported in the C dimension (index 1): the scale/bias arrays have a fixed channel count.
PLUGIN_ASSERT(in[0].dims.d[1] != -1);
return STATUS_SUCCESS;
}
// Describe the plugin's state as PluginFields: epsilon, scales, bias, relu, alpha.
// The fields point directly at this object's members (no copies), so the returned collection is only meaningful
// while this object is alive and those members are unchanged. The collection is rebuilt on every call.
PluginFieldCollection const* InstanceNormalizationV3Plugin::getFieldsToSerialize() noexcept
{
    mDataToSerialize.clear();
    auto const scaleCount = static_cast<int32_t>(mHostScale.size());
    auto const biasCount = static_cast<int32_t>(mHostBias.size());
    mDataToSerialize.emplace_back("epsilon", &mEpsilon, PluginFieldType::kFLOAT32, 1);
    mDataToSerialize.emplace_back("scales", mHostScale.data(), PluginFieldType::kFLOAT32, scaleCount);
    mDataToSerialize.emplace_back("bias", mHostBias.data(), PluginFieldType::kFLOAT32, biasCount);
    mDataToSerialize.emplace_back("relu", &mRelu, PluginFieldType::kINT32, 1);
    mDataToSerialize.emplace_back("alpha", &mAlpha, PluginFieldType::kFLOAT32, 1);
    mFCToSerialize.nbFields = static_cast<int32_t>(mDataToSerialize.size());
    mFCToSerialize.fields = mDataToSerialize.data();
    return &mFCToSerialize;
}
// InstanceNormalizationV3PluginCreator methods
// Populate the static field description returned by getFieldNames().
// mPluginAttributes and mFC are static (shared by all creator instances), so their re-initialization is serialized
// with a function-local mutex. The "scales" and "bias" entries are arrays whose real length comes from the
// PluginField passed to createPlugin().
InstanceNormalizationV3PluginCreator::InstanceNormalizationV3PluginCreator()
{
    static std::mutex sMutex;
    std::lock_guard<std::mutex> guard(sMutex);
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back("epsilon", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("scales", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("bias", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("relu", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("alpha", nullptr, PluginFieldType::kFLOAT32, 1);
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
// Name of the plugins this creator builds; identical to InstanceNormalizationV3Plugin::getPluginName().
char const* InstanceNormalizationV3PluginCreator::getPluginName() const noexcept
{
    char const* const pluginName = gInstancePluginName;
    return pluginName;
}
// Version of the plugins this creator builds; identical to InstanceNormalizationV3Plugin::getPluginVersion().
char const* InstanceNormalizationV3PluginCreator::getPluginVersion() const noexcept
{
    char const* const pluginVersion = gInstancePluginVersion;
    return pluginVersion;
}
// Field description populated by the creator constructor (static, shared by all creator instances).
PluginFieldCollection const* InstanceNormalizationV3PluginCreator::getFieldNames() noexcept
{
    PluginFieldCollection const* const fieldNames = &mFC;
    return fieldNames;
}
// Build an InstanceNormalizationV3Plugin from a field collection.
// Recognized fields:
//   - "epsilon", "alpha" (kFLOAT32 scalars);
//   - "relu" (kINT32 scalar);
//   - "scales", "bias" (kFLOAT32 arrays; repeated fields are appended).
// Unknown or unnamed fields are ignored; missing scalars default to zero.
// Returns nullptr (after reporting via caughtError()) on malformed input:
//   - null collection or field data;
//   - wrong field type;
//   - negative array length;
//   - scale/bias size mismatch, rejected by the constructor.
IPluginV3* InstanceNormalizationV3PluginCreator::createPlugin(
    char const* name, nvinfer1::PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
    try
    {
        PLUGIN_VALIDATE(fc != nullptr);
        PLUGIN_VALIDATE(fc->nbFields == 0 || fc->fields != nullptr);

        std::vector<float> scaleValues;
        std::vector<float> biasValues;
        float epsilon{};
        int32_t relu{};
        float alpha{};

        // Read one FP32 scalar after checking the declared type and that data is present.
        auto const readFloatScalar = [](PluginField const& field) {
            PLUGIN_VALIDATE(field.type == PluginFieldType::kFLOAT32);
            PLUGIN_VALIDATE(field.data != nullptr);
            return *static_cast<float const*>(field.data);
        };
        // Append an FP32 array field to `out`. The length must be non-negative, and data must be present unless
        // the array is empty.
        auto const appendFloatArray = [](PluginField const& field, std::vector<float>& out) {
            PLUGIN_VALIDATE(field.type == PluginFieldType::kFLOAT32);
            PLUGIN_VALIDATE(field.length >= 0);
            PLUGIN_VALIDATE(field.length == 0 || field.data != nullptr);
            auto const* values = static_cast<float const*>(field.data);
            out.insert(out.end(), values, values + field.length);
        };

        for (int32_t i = 0; i < fc->nbFields; ++i)
        {
            PluginField const& field = fc->fields[i];
            char const* attrName = field.name;
            if (attrName == nullptr)
            {
                // An unnamed field cannot match any attribute; ignore it like an unknown name.
                continue;
            }
            if (!strcmp(attrName, "epsilon"))
            {
                epsilon = readFloatScalar(field);
            }
            else if (!strcmp(attrName, "scales"))
            {
                appendFloatArray(field, scaleValues);
            }
            else if (!strcmp(attrName, "bias"))
            {
                appendFloatArray(field, biasValues);
            }
            else if (!strcmp(attrName, "relu"))
            {
                PLUGIN_VALIDATE(field.type == PluginFieldType::kINT32);
                PLUGIN_VALIDATE(field.data != nullptr);
                relu = *static_cast<int32_t const*>(field.data);
            }
            else if (!strcmp(attrName, "alpha"))
            {
                alpha = readFloatScalar(field);
            }
        }

        Weights scaleWeights{DataType::kFLOAT, scaleValues.data(), static_cast<int64_t>(scaleValues.size())};
        Weights biasWeights{DataType::kFLOAT, biasValues.data(), static_cast<int64_t>(biasValues.size())};
        auto* obj = new InstanceNormalizationV3Plugin(epsilon, scaleWeights, biasWeights, relu, alpha);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
// Set the namespace applied to every plugin this creator builds. A null namespace is rejected and the error is
// reported through caughtError().
void InstanceNormalizationV3PluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{
    try
    {
        PLUGIN_VALIDATE(libNamespace != nullptr);
        mNamespace.assign(libNamespace);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
// Namespace most recently set on this creator. The pointer refers to the creator's internal string.
char const* InstanceNormalizationV3PluginCreator::getPluginNamespace() const noexcept
{
    return mNamespace.data();
}