diff --git a/demo/visualization_demo/bev_vis_multi_frame_demo.py b/demo/visualization_demo/bev_vis_multi_frame_demo.py new file mode 100644 index 00000000..69ec8bf8 --- /dev/null +++ b/demo/visualization_demo/bev_vis_multi_frame_demo.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import numpy as np + +from paddle3d.apis.infer import Infer +from paddle3d.apis.config import Config +from paddle3d.slim import get_qat_config +from paddle3d.utils.checkpoint import load_pretrained_model + + +def parse_args(): + """ + """ + parser = argparse.ArgumentParser(description='Model evaluation') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--model', + dest='model', + help='pretrained parameters of the model', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=2) + parser.add_argument( + '--quant_config', + dest='quant_config', + help='Config for quant model.', + default=None, + type=str) + + return parser.parse_args() + + +def worker_init_fn(worker_id): + np.random.seed(1024) + + +def main(args): + """ + """ + if args.cfg is None: + raise RuntimeError("No configuration file specified!") + + if not os.path.exists(args.cfg): + raise RuntimeError("Config file `{}` does not exist!".format(args.cfg)) + + cfg = Config(path=args.cfg, batch_size=args.batch_size) + + if cfg.val_dataset is None: + raise RuntimeError( + 'The validation dataset is not specified in the configuration file!' + ) + elif len(cfg.val_dataset) == 0: + raise ValueError( + 'The length of validation dataset is 0. Please check if your dataset is valid!' + ) + + dic = cfg.to_dict() + batch_size = dic.pop('batch_size') + dic.update({ + 'dataloader_fn': { + 'batch_size': batch_size, + 'num_workers': args.num_workers, + 'worker_init_fn': worker_init_fn + } + }) + + if args.quant_config: + quant_config = get_qat_config(args.quant_config) + cfg.model.build_slim_model(quant_config['quant_config']) + + if args.model is not None: + load_pretrained_model(cfg.model, args.model) + dic['checkpoint'] = None + dic['resume'] = False + else: + dic['resume'] = True + + infer = Infer(**dic) + infer.infer('bev') + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/demo/visualization_demo/bev_vis_single_frame_demo.py b/demo/visualization_demo/bev_vis_single_frame_demo.py new file mode 100644 index 00000000..f28aa04d --- /dev/null +++ b/demo/visualization_demo/bev_vis_single_frame_demo.py @@ -0,0 +1,188 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import numpy as np
+import paddle
+from paddle.inference import Config, create_predictor
+# Imported for its side effect: it loads the custom iou3d_nms operator
+# required by the exported detection model.
+from paddle3d.ops.iou3d_nms_cuda import nms_gpu
+from demo.visualization_demo.vis_utils import preprocess, Calibration, show_bev_with_boxes
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_file",
+        type=str,
+        help="Model filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        "--params_file",
+        type=str,
+        help="Parameter filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        '--lidar_file', type=str, help='The lidar file path.', required=True)
+    parser.add_argument(
+        '--calib_file',
+        type=str,
+        help='The calibration file path.',
+        required=True)
+    parser.add_argument(
+        "--num_point_dim",
+        type=int,
+        default=4,
+        help="Dimension of a point in the lidar file.")
+    parser.add_argument(
+        "--point_cloud_range",
+        dest='point_cloud_range',
+        nargs='+',
+        help="Range of point cloud for voxelize operation.",
+        type=float,
+        default=None)
+    parser.add_argument(
+        "--voxel_size",
+        dest='voxel_size',
+        nargs='+',
+        help="Size of voxels for voxelize operation.",
+        type=float,
+        default=None)
+    parser.add_argument(
+        "--max_points_in_voxel",
+        type=int,
+        default=100,
+        help="Maximum number of points in a voxel.")
+    parser.add_argument(
+        "--max_voxel_num",
+        type=int,
+        default=12000,
+        help="Maximum number of voxels.")
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.")
+    parser.add_argument(
+        "--use_trt",
+        type=int,
+        default=0,
+        help="Whether to use TensorRT to accelerate inference on GPU.")
+    parser.add_argument(
+        "--trt_precision",
+        type=int,
+        default=0,
+        help="Precision type of TensorRT, 0: kFloat32, 1: kHalf.")
+    parser.add_argument(
+        "--trt_use_static",
+        type=int,
+        default=0,
+        help="Whether to load the TensorRT graph optimization from a disk path."
+ ) + parser.add_argument( + "--trt_static_dir", + type=str, + help="Path of a tensorrt graph optimization directory.") + parser.add_argument( + "--collect_shape_info", + type=int, + default=0, + help="Whether to collect dynamic shape before using tensorrt.") + parser.add_argument( + "--dynamic_shape_file", + type=str, + default="", + help="Path of a dynamic shape file for tensorrt.") + + return parser.parse_args() + + +def init_predictor(model_file, + params_file, + gpu_id=0, + use_trt=False, + trt_precision=0, + trt_use_static=False, + trt_static_dir=None, + collect_shape_info=False, + dynamic_shape_file=None): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, gpu_id) + if use_trt: + precision_mode = paddle.inference.PrecisionType.Float32 + if trt_precision == 1: + precision_mode = paddle.inference.PrecisionType.Half + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=1, + min_subgraph_size=10, + precision_mode=precision_mode, + use_static=trt_use_static, + use_calib_mode=False) + if collect_shape_info: + config.collect_shape_range_info(dynamic_shape_file) + else: + config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True) + if trt_use_static: + config.set_optim_cache_dir(trt_static_dir) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, voxels, coords, num_points_per_voxel): + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + if name == "voxels": + input_tensor.reshape(voxels.shape) + input_tensor.copy_from_cpu(voxels.copy()) + elif name == "coords": + input_tensor.reshape(coords.shape) + input_tensor.copy_from_cpu(coords.copy()) + elif name == "num_points_per_voxel": + input_tensor.reshape(num_points_per_voxel.shape) + input_tensor.copy_from_cpu(num_points_per_voxel.copy()) + + # do the inference + predictor.run() + + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + if i == 0: + box3d_lidar = output_tensor.copy_to_cpu() + elif i == 1: + label_preds = output_tensor.copy_to_cpu() + elif i == 2: + scores = output_tensor.copy_to_cpu() + return box3d_lidar, label_preds, scores + + +if __name__ == '__main__': + args = parse_args() + + predictor = init_predictor(args.model_file, args.params_file, args.gpu_id, + args.use_trt, args.trt_precision, + args.trt_use_static, args.trt_static_dir, + args.collect_shape_info, args.dynamic_shape_file) + voxels, coords, num_points_per_voxel = preprocess( + args.lidar_file, args.num_point_dim, args.point_cloud_range, + args.voxel_size, args.max_points_in_voxel, args.max_voxel_num) + box3d_lidar, label_preds, scores = run(predictor, voxels, coords, + num_points_per_voxel) + + scan = np.fromfile(args.lidar_file, dtype=np.float32) + pc_velo = scan.reshape((-1, 4)) + + # Obtain calibration information about Kitti + calib = Calibration(args.calib_file) + + # Plot box in lidar cloud + show_bev_with_boxes(pc_velo, box3d_lidar, scores, calib) diff --git a/demo/visualization_demo/dataset_vis_demo.py b/demo/visualization_demo/dataset_vis_demo.py new file mode 100644 index 00000000..2079590a --- /dev/null +++ b/demo/visualization_demo/dataset_vis_demo.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import cv2
+
+from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object
+
+from demo.visualization_demo.vis_utils import Calibration, show_lidar_with_boxes, \
+    make_imgpts_list, draw_mono_3d, show_bev_with_boxes
+
+pth = '../datasets/KITTI/training'  # KITTI dataset root
+
+files = os.listdir(os.path.join(pth, 'image_2'))
+files = sorted(files)
+
+# Visualization mode: 'bev', 'image' or 'pcd'.
+mode = 'bev'
+
+assert mode in ['bev', 'image', 'pcd'], \
+    "mode must be 'bev', 'image' or 'pcd', but got '{}'".format(mode)
+
+for img in files:
+    frame_id = img[:-4]
+    label_file = os.path.join(pth, 'label_2', f'{frame_id}.txt')
+    calib_file = os.path.join(pth, 'calib', f'{frame_id}.txt')
+    img_file = os.path.join(pth, 'image_2', f'{frame_id}.png')
+    pcd_file = os.path.join(pth, 'velodyne', f'{frame_id}.bin')
+
+    label_lines = open(label_file).readlines()
+    kitti_records_list = [line.strip().split(' ') for line in label_lines]
+
+    if mode == 'pcd':
+        box3d_list = []
+        for itm in kitti_records_list:
+            itm = [float(i) for i in itm[8:]]
+            # camera frame -> lidar frame: [z, -x, -y, w, l, h, ry]
+            box3d_list.append(
+                [itm[5], -itm[3], -itm[4], itm[1], itm[2], itm[0], itm[6]])
+        box3d = np.asarray(box3d_list)
+        scan = np.fromfile(pcd_file, dtype=np.float32)
+        pc_velo = scan.reshape((-1, 4))
+        # Load the KITTI calibration information
+        calib = Calibration(calib_file)
+        # Plot the ground-truth boxes in the lidar point cloud
+        show_lidar_with_boxes(pc_velo, box3d, np.ones(box3d.shape[0]), calib)
+
+    if mode == 'image':
+        kitti_records = np.array(kitti_records_list)
+        bboxes_2d, bboxes_3d, labels = camera_record_to_object(kitti_records)
+        # read the original image
+        img_origin = cv2.imread(img_file)
+        # build the camera intrinsics K from the calib file
+        # (the third line holds P2; drop the leading 'P2: ' prefix)
+        itms = open(calib_file).readlines()[2]
+        P2 = itms[4:].strip().split(' ')
+        K = np.asarray([float(i) for i in P2]).reshape(3, 4)[:, :3]
+        # project the 3D boxes to 8 corner points on the image
+        imgpts_list = make_imgpts_list(bboxes_3d, K)
+        # draw the 3D boxes on the image
+        draw_mono_3d(img_origin, imgpts_list)
+
+    if mode == 'bev':
+        box3d_list = []
+        for itm in kitti_records_list:
+            itm = [float(i) for i in itm[8:]]
+            # camera frame -> lidar frame: [z, -x, -y, w, l, h, ry]
+            box3d_list.append(
+                [itm[5], -itm[3], -itm[4], itm[1], itm[2], itm[0], itm[6]])
+        box3d = np.asarray(box3d_list)
+        scan = np.fromfile(pcd_file, dtype=np.float32)
+        pc_velo = scan.reshape((-1, 4))
+        # Load the KITTI calibration information
+        calib = Calibration(calib_file)
+        # Plot the ground-truth boxes in the lidar point cloud (BEV)
+        show_bev_with_boxes(pc_velo, box3d, np.ones(box3d.shape[0]), calib)
diff --git a/demo/visualization_demo/img/bev.png b/demo/visualization_demo/img/bev.png
new file mode 100644
index 00000000..88a6bde0
Binary files /dev/null and b/demo/visualization_demo/img/bev.png differ
diff --git a/demo/visualization_demo/img/mono.jpg b/demo/visualization_demo/img/mono.jpg
new file mode 100644
index 00000000..c2eb7348
Binary files /dev/null and b/demo/visualization_demo/img/mono.jpg differ
diff --git a/demo/visualization_demo/img/pc.png b/demo/visualization_demo/img/pc.png
new file mode
100644 index 00000000..2b101bac Binary files /dev/null and b/demo/visualization_demo/img/pc.png differ diff --git a/demo/visualization_demo/mono_vis_multi_frame_demo.py b/demo/visualization_demo/mono_vis_multi_frame_demo.py new file mode 100644 index 00000000..6f78e3b7 --- /dev/null +++ b/demo/visualization_demo/mono_vis_multi_frame_demo.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import cv2 +import numpy as np + +from demo.visualization_demo.vis_utils import make_imgpts_list, draw_mono_3d, total_imgpred_by_conf_to_kitti_records + +from paddle3d.apis.infer import Infer +from paddle3d.apis.config import Config +from paddle3d.slim import get_qat_config +from paddle3d.utils.checkpoint import load_pretrained_model +from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object + + +def parse_args(): + """ + """ + parser = argparse.ArgumentParser(description='Model evaluation') + # params of training + parser.add_argument( + "--config", dest="cfg", help="The config file.", default=None, type=str) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=None) + parser.add_argument( + '--model', + dest='model', + help='pretrained parameters of the model', + type=str, + default=None) + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=2) + parser.add_argument( + '--quant_config', + dest='quant_config', + help='Config for quant model.', + default=None, + type=str) + + return parser.parse_args() + + +def worker_init_fn(worker_id): + np.random.seed(1024) + + +def main(args): + """ + """ + if args.cfg is None: + raise RuntimeError("No configuration file specified!") + + if not os.path.exists(args.cfg): + raise RuntimeError("Config file `{}` does not exist!".format(args.cfg)) + + cfg = Config(path=args.cfg, batch_size=args.batch_size) + + if cfg.val_dataset is None: + raise RuntimeError( + 'The validation dataset is not specified in the configuration file!' + ) + elif len(cfg.val_dataset) == 0: + raise ValueError( + 'The length of validation dataset is 0. Please check if your dataset is valid!' 
+        )
+
+    dic = cfg.to_dict()
+    batch_size = dic.pop('batch_size')
+    dic.update({
+        'dataloader_fn': {
+            'batch_size': batch_size,
+            'num_workers': args.num_workers,
+            'worker_init_fn': worker_init_fn
+        }
+    })
+
+    if args.quant_config:
+        quant_config = get_qat_config(args.quant_config)
+        cfg.model.build_slim_model(quant_config['quant_config'])
+
+    if args.model is not None:
+        load_pretrained_model(cfg.model, args.model)
+        dic['checkpoint'] = None
+        dic['resume'] = False
+    else:
+        dic['resume'] = True
+
+    infer = Infer(**dic)
+    infer.infer('image')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
diff --git a/demo/visualization_demo/mono_vis_single_frame_demo.py b/demo/visualization_demo/mono_vis_single_frame_demo.py
new file mode 100644
index 00000000..55b11a97
--- /dev/null
+++ b/demo/visualization_demo/mono_vis_single_frame_demo.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import cv2
+import numpy as np
+
+from paddle.inference import Config, PrecisionType, create_predictor
+from paddle3d.datasets.kitti.kitti_utils import camera_record_to_object
+from demo.visualization_demo.vis_utils import get_img, get_ratio, total_pred_by_conf_to_kitti_records, make_imgpts_list, draw_mono_3d
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_file",
+        type=str,
+        help="Model filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        "--params_file",
+        type=str,
+        help="Parameter filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        '--image', dest='image', help='The input image path.', type=str, required=True)
+    parser.add_argument(
+        "--use_gpu", action='store_true', help="Whether to use gpu.")
+    parser.add_argument(
+        "--use_trt", action='store_true', help="Whether to use TensorRT.")
+    parser.add_argument(
+        "--collect_dynamic_shape_info",
+        action='store_true',
+        help="Whether to collect dynamic shape before using tensorrt.")
+    parser.add_argument(
+        "--dynamic_shape_file",
+        dest='dynamic_shape_file',
+        help='Path of a dynamic shape file for tensorrt.',
+        type=str,
+        default="dynamic_shape_info.txt")
+    return parser.parse_args()
+
+
+def init_predictor(args):
+    config = Config(args.model_file, args.params_file)
+    config.enable_memory_optim()
+    if args.use_gpu:
+        config.enable_use_gpu(1000, 0)
+    else:
+        # Set the number of CPU math library (BLAS) threads; it should not
+        # be greater than the number of physical CPU cores.
+ config.set_cpu_math_library_num_threads(4) + config.enable_mkldnn() + + if args.collect_dynamic_shape_info: + config.collect_shape_range_info(args.dynamic_shape_file) + elif args.use_trt: + allow_build_at_runtime = True + config.enable_tuned_tensorrt_dynamic_shape(args.dynamic_shape_file, + allow_build_at_runtime) + + config.enable_tensorrt_engine( + workspace_size=1 << 20, + max_batch_size=1, + min_subgraph_size=3, + precision_mode=PrecisionType.Float32) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, image, K, down_ratio): + # copy img data to input tensor + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + if name == "images": + input_tensor.reshape(image.shape) + input_tensor.copy_from_cpu(image.copy()) + elif name == "trans_cam_to_img": + input_tensor.reshape(K.shape) + input_tensor.copy_from_cpu(K.copy()) + elif name == "down_ratios": + input_tensor.reshape(down_ratio.shape) + input_tensor.copy_from_cpu(down_ratio.copy()) + + # do the inference + predictor.run() + + results = [] + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + output_data = output_tensor.copy_to_cpu() + results.append(output_data) + + return results + + +if __name__ == '__main__': + args = parse_args() + pred = init_predictor(args) + # Listed below are camera intrinsic parameter of the kitti dataset + # If the model is trained on other datasets, please replace the relevant data + K = np.array([[[721.53771973, 0., 609.55932617], + [0., 721.53771973, 172.85400391], [0, 0, 1]]], np.float32) + + img, ori_img_size, output_size = get_img(args.image) + ratio = get_ratio(ori_img_size, output_size) + + results = run(pred, img, K, ratio) + + total_pred = results[0] + print(total_pred) + # convert pred to bboxes_2d, bboxes_3d + kitti_records = total_pred_by_conf_to_kitti_records(total_pred, conf=0.5) + bboxes_2d, bboxes_3d, labels = camera_record_to_object(kitti_records) + # read origin image + img_origin = cv2.imread(args.image) + # to 8 points on image + imgpts_list = make_imgpts_list(bboxes_3d, K[0]) + # draw smoke result to photo + draw_mono_3d(img_origin, imgpts_list) diff --git a/demo/visualization_demo/pcd_vis_multi_frame_demo.py b/demo/visualization_demo/pcd_vis_multi_frame_demo.py new file mode 100644 index 00000000..9fee1af0 --- /dev/null +++ b/demo/visualization_demo/pcd_vis_multi_frame_demo.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
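+"""Multi-frame LiDAR point cloud visualization demo.
+
+Builds a dataloader over the validation dataset declared in the config file
+and renders each frame's predictions in the point cloud via Infer.infer('pcd').
+"""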
+
+import argparse
+import os
+
+import numpy as np
+
+from paddle3d.apis.infer import Infer
+from paddle3d.apis.config import Config
+from paddle3d.slim import get_qat_config
+from paddle3d.utils.checkpoint import load_pretrained_model
+
+
+def parse_args():
+    """Parse the command-line arguments."""
+    parser = argparse.ArgumentParser(description='Model inference and visualization')
+    parser.add_argument(
+        "--config", dest="cfg", help="The config file.", default=None, type=str)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size of one gpu or cpu',
+        type=int,
+        default=None)
+    parser.add_argument(
+        '--model',
+        dest='model',
+        help='Pretrained parameters of the model',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--num_workers',
+        dest='num_workers',
+        help='Number of workers for the data loader',
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--quant_config',
+        dest='quant_config',
+        help='Config for the quantized model.',
+        default=None,
+        type=str)
+
+    return parser.parse_args()
+
+
+def worker_init_fn(worker_id):
+    # Fix the dataloader seed so results are reproducible.
+    np.random.seed(1024)
+
+
+def main(args):
+    """Run inference over the validation dataset and visualize each frame."""
+    if args.cfg is None:
+        raise RuntimeError("No configuration file specified!")
+
+    if not os.path.exists(args.cfg):
+        raise RuntimeError("Config file `{}` does not exist!".format(args.cfg))
+
+    cfg = Config(path=args.cfg, batch_size=args.batch_size)
+
+    if cfg.val_dataset is None:
+        raise RuntimeError(
+            'The validation dataset is not specified in the configuration file!'
+        )
+    elif len(cfg.val_dataset) == 0:
+        raise ValueError(
+            'The length of validation dataset is 0. Please check if your dataset is valid!'
+        )
+
+    dic = cfg.to_dict()
+    batch_size = dic.pop('batch_size')
+    dic.update({
+        'dataloader_fn': {
+            'batch_size': batch_size,
+            'num_workers': args.num_workers,
+            'worker_init_fn': worker_init_fn
+        }
+    })
+
+    if args.quant_config:
+        quant_config = get_qat_config(args.quant_config)
+        cfg.model.build_slim_model(quant_config['quant_config'])
+
+    if args.model is not None:
+        load_pretrained_model(cfg.model, args.model)
+        dic['checkpoint'] = None
+        dic['resume'] = False
+    else:
+        dic['resume'] = True
+
+    infer = Infer(**dic)
+    infer.infer('pcd')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
diff --git a/demo/visualization_demo/pcd_vis_single_frame_demo.py b/demo/visualization_demo/pcd_vis_single_frame_demo.py
new file mode 100644
index 00000000..501061ba
--- /dev/null
+++ b/demo/visualization_demo/pcd_vis_single_frame_demo.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
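+"""Single-frame LiDAR point cloud visualization demo.
+
+Voxelizes one LiDAR frame, runs an exported Paddle inference model
+(e.g. PointPillars) through the Paddle Inference API, and draws the
+predicted 3D boxes in the point cloud with mayavi.
+"""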
+
+import argparse
+import numpy as np
+import paddle
+from paddle.inference import Config, create_predictor
+# Imported for its side effect: it loads the custom iou3d_nms operator
+# required by the exported detection model.
+from paddle3d.ops.iou3d_nms_cuda import nms_gpu
+from demo.visualization_demo.vis_utils import preprocess, Calibration, show_lidar_with_boxes
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_file",
+        type=str,
+        help="Model filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        "--params_file",
+        type=str,
+        help="Parameter filename. Specify this when your model is a combined model.",
+        required=True)
+    parser.add_argument(
+        '--lidar_file', type=str, help='The lidar file path.', required=True)
+    parser.add_argument(
+        '--calib_file',
+        type=str,
+        help='The calibration file path.',
+        required=True)
+    parser.add_argument(
+        "--num_point_dim",
+        type=int,
+        default=4,
+        help="Dimension of a point in the lidar file.")
+    parser.add_argument(
+        "--point_cloud_range",
+        dest='point_cloud_range',
+        nargs='+',
+        help="Range of point cloud for voxelize operation.",
+        type=float,
+        default=None)
+    parser.add_argument(
+        "--voxel_size",
+        dest='voxel_size',
+        nargs='+',
+        help="Size of voxels for voxelize operation.",
+        type=float,
+        default=None)
+    parser.add_argument(
+        "--max_points_in_voxel",
+        type=int,
+        default=100,
+        help="Maximum number of points in a voxel.")
+    parser.add_argument(
+        "--max_voxel_num",
+        type=int,
+        default=12000,
+        help="Maximum number of voxels.")
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.")
+    parser.add_argument(
+        "--use_trt",
+        type=int,
+        default=0,
+        help="Whether to use TensorRT to accelerate inference on GPU.")
+    parser.add_argument(
+        "--trt_precision",
+        type=int,
+        default=0,
+        help="Precision type of TensorRT, 0: kFloat32, 1: kHalf.")
+    parser.add_argument(
+        "--trt_use_static",
+        type=int,
+        default=0,
+        help="Whether to load the TensorRT graph optimization from a disk path."
+    )
+    parser.add_argument(
+        "--trt_static_dir",
+        type=str,
+        help="Path of a TensorRT graph optimization directory.")
+    parser.add_argument(
+        "--collect_shape_info",
+        type=int,
+        default=0,
+        help="Whether to collect dynamic shape before using tensorrt.")
+    parser.add_argument(
+        "--dynamic_shape_file",
+        type=str,
+        default="",
+        help="Path of a dynamic shape file for tensorrt.")
+
+    return parser.parse_args()
+
+
+def init_predictor(model_file,
+                   params_file,
+                   gpu_id=0,
+                   use_trt=False,
+                   trt_precision=0,
+                   trt_use_static=False,
+                   trt_static_dir=None,
+                   collect_shape_info=False,
+                   dynamic_shape_file=None):
+    config = Config(model_file, params_file)
+    config.enable_memory_optim()
+    config.enable_use_gpu(1000, gpu_id)
+    if use_trt:
+        precision_mode = paddle.inference.PrecisionType.Float32
+        if trt_precision == 1:
+            precision_mode = paddle.inference.PrecisionType.Half
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 30,
+            max_batch_size=1,
+            min_subgraph_size=10,
+            precision_mode=precision_mode,
+            use_static=trt_use_static,
+            use_calib_mode=False)
+        if collect_shape_info:
+            config.collect_shape_range_info(dynamic_shape_file)
+        else:
+            config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True)
+        if trt_use_static:
+            config.set_optim_cache_dir(trt_static_dir)
+
+    predictor = create_predictor(config)
+    return predictor
+
+
+def run(predictor, voxels, coords, num_points_per_voxel):
+    # copy the input data to the input tensors
+    input_names = predictor.get_input_names()
+    for i, name in enumerate(input_names):
+        input_tensor = predictor.get_input_handle(name)
+        if name == "voxels":
+            input_tensor.reshape(voxels.shape)
+            input_tensor.copy_from_cpu(voxels.copy())
+        elif name == "coords":
+            input_tensor.reshape(coords.shape)
+            input_tensor.copy_from_cpu(coords.copy())
+        elif name == "num_points_per_voxel":
+            input_tensor.reshape(num_points_per_voxel.shape)
+            input_tensor.copy_from_cpu(num_points_per_voxel.copy())
+
+    # do the inference
+    predictor.run()
+
+    # get the output data from the output tensors
+    output_names = predictor.get_output_names()
+    for i, name in enumerate(output_names):
+        output_tensor = predictor.get_output_handle(name)
+        if i == 0:
+            box3d_lidar = output_tensor.copy_to_cpu()
+        elif i == 1:
+            label_preds = output_tensor.copy_to_cpu()
+        elif i == 2:
+            scores = output_tensor.copy_to_cpu()
+    return box3d_lidar, label_preds, scores
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    predictor = init_predictor(args.model_file, args.params_file, args.gpu_id,
+                               args.use_trt, args.trt_precision,
+                               args.trt_use_static, args.trt_static_dir,
+                               args.collect_shape_info, args.dynamic_shape_file)
+    voxels, coords, num_points_per_voxel = preprocess(
+        args.lidar_file, args.num_point_dim, args.point_cloud_range,
+        args.voxel_size, args.max_points_in_voxel, args.max_voxel_num)
+    box3d_lidar, label_preds, scores = run(predictor, voxels, coords,
+                                           num_points_per_voxel)
+
+    scan = np.fromfile(args.lidar_file, dtype=np.float32)
+    pc_velo = scan.reshape((-1, 4))
+
+    # Load the KITTI calibration information
+    calib = Calibration(args.calib_file)
+
+    # Plot the predicted boxes in the lidar point cloud
+    show_lidar_with_boxes(pc_velo, box3d_lidar, scores, calib)
diff --git a/demo/visualization_demo/readme.md b/demo/visualization_demo/readme.md
new file mode 100644
index 00000000..85bd098c
--- /dev/null
+++ b/demo/visualization_demo/readme.md
@@ -0,0 +1,159 @@
+## 3D Visualization Demos for LiDAR Point Clouds, BEV and Camera Images
+### Environment Setup
+Install the Paddle3D dependencies following the [official documentation](https://github.com/PaddlePaddle/Paddle3D/blob/develop/docs/installation.md), then install `mayavi` for LiDAR point cloud visualization:
+```
+pip install vtk==8.1.2
+pip install mayavi==4.7.4
+pip install PyQt5
+```
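+
+Before running the demos, a minimal `mayavi` smoke test such as the one below can confirm that the visualization stack works. This snippet is not part of the demos and assumes a working display; the Qt `xcb` error referenced at the end of this document usually surfaces here:
+
+```python
+# Render 100 random points; a mayavi window should open without Qt errors.
+import numpy as np
+import mayavi.mlab as mlab
+
+pts = np.random.rand(100, 3)  # stand-in for one lidar scan
+mlab.points3d(pts[:, 0], pts[:, 1], pts[:, 2], mode='point')
+mlab.show()
+```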
+### 3D Visualization Demo for Camera Images
+The 3D visualization scripts for camera images are stored under `demo/visualization_demo/`, which provides a single-frame demo, `mono_vis_single_frame_demo.py`, and a multi-frame demo, `mono_vis_multi_frame_demo.py`. Both use the same visualization interface, whose code is in `paddle3d.apis.infer`.
+
+The two demos are implemented differently so as to showcase more visualization approaches: `mono_vis_single_frame_demo.py` performs visualization via Paddle inference deployment, while `mono_vis_multi_frame_demo.py` builds a `dataloader` over the images to read and infer frame by frame.
+
+`mono_vis_single_frame_demo.py` is used as follows:
+```
+cd demo/visualization_demo
+python mono_vis_single_frame_demo.py \
+    --model_file model/smoke.pdmodel \
+    --params_file model/smoke.pdiparams \
+    --image data/image_2/000008.png
+```
+`--model_file` and `--params_file` are the paths of the model and parameter files to use.
+
+`--image` is the path of the input image.
+
+`mono_vis_multi_frame_demo.py` is used as follows:
+
+```
+python mono_vis_multi_frame_demo.py \
+    --config configs/smoke/smoke_dla34_no_dcn_kitti.yml \
+    --model demo/smoke.pdparams \
+    --batch_size 1
+```
+
+`--config` is the path of the model config file.
+
+`--model` is the path of the model parameter file to use.
+
+`--batch_size` is the batch size used for inference.
+
+
+The final monocular visualization output is shown below:
+
+![](img/mono.jpg)
+### 3D Visualization Demo for LiDAR Point Clouds
+The 3D visualization scripts for LiDAR point clouds are stored under `demo/visualization_demo/`, which provides a single-frame demo, `pcd_vis_single_frame_demo.py`, and a multi-frame demo, `pcd_vis_multi_frame_demo.py`. Both use the same visualization interface, whose code is in `paddle3d.apis.infer`.
+
+The two demos are implemented differently so as to showcase more visualization approaches: `pcd_vis_single_frame_demo.py` performs visualization via Paddle inference deployment, while `pcd_vis_multi_frame_demo.py` builds a `dataloader` over the dataset to read and infer frame by frame.
+
+`pcd_vis_single_frame_demo.py` is used as follows:
+
+```
+cd demo/visualization_demo
+python pcd_vis_single_frame_demo.py \
+    --model_file model/pointpillars.pdmodel \
+    --params_file model/pointpillars.pdiparams \
+    --lidar_file data/velodyne/000008.bin \
+    --calib_file data/calib/000008.txt \
+    --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \
+    --voxel_size .16 .16 4 \
+    --max_points_in_voxel 32 \
+    --max_voxel_num 40000
+```
+
+`--model_file` and `--params_file` are the paths of the model and parameter files to use.
+
+`--lidar_file` and `--calib_file` are the paths of the LiDAR point cloud and its calibration file.
+
+`--point_cloud_range` is the `(x, y, z)` range of the LiDAR point cloud.
+
+`--voxel_size` is the voxel size used during voxelization.
+
+`--max_points_in_voxel` is the maximum number of LiDAR points per voxel.
+
+`--max_voxel_num` is the maximum number of voxels.
+
+`pcd_vis_multi_frame_demo.py` is used as follows:
+
+```
+python pcd_vis_multi_frame_demo.py \
+    --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \
+    --model demo/pointpillars.pdparams \
+    --batch_size 1
+```
+
+`--config` is the path of the model config file.
+
+`--model` is the path of the model parameter file to use.
+
+`--batch_size` is the batch size used for inference.
+
+The final LiDAR point cloud visualization output is shown below:
+
+![](img/pc.png)
+### 3D Visualization Demo for LiDAR BEV
+The 3D visualization scripts for LiDAR BEV are stored under `demo/visualization_demo/`, which provides a single-frame demo, `bev_vis_single_frame_demo.py`, and a multi-frame demo, `bev_vis_multi_frame_demo.py`. Both use the same visualization interface, whose code is in `paddle3d.apis.infer`.
+
+The two demos are implemented differently so as to showcase more visualization approaches: `bev_vis_single_frame_demo.py` performs visualization via Paddle inference deployment, while `bev_vis_multi_frame_demo.py` builds a `dataloader` over the dataset to read and infer frame by frame.
+
+`bev_vis_single_frame_demo.py` is used as follows (its core flow is sketched after the parameter list):
+
+```
+cd demo/visualization_demo
+python bev_vis_single_frame_demo.py \
+    --model_file model/pointpillars.pdmodel \
+    --params_file model/pointpillars.pdiparams \
+    --lidar_file data/velodyne/000008.bin \
+    --calib_file data/calib/000008.txt \
+    --point_cloud_range 0 -39.68 -3 69.12 39.68 1 \
+    --voxel_size .16 .16 4 \
+    --max_points_in_voxel 32 \
+    --max_voxel_num 40000
+```
+`--model_file` and `--params_file` are the paths of the model and parameter files to use.
+
+`--lidar_file` and `--calib_file` are the paths of the LiDAR point cloud and its calibration file.
+
+`--point_cloud_range` is the `(x, y, z)` range of the LiDAR point cloud.
+
+`--voxel_size` is the voxel size used during voxelization.
+
+`--max_points_in_voxel` is the maximum number of LiDAR points per voxel.
+
+`--max_voxel_num` is the maximum number of voxels.
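+
+Condensed, the single-frame BEV demo boils down to the following steps (a sketch based on `bev_vis_single_frame_demo.py`; the file paths are placeholders, and the voxelization settings must match the exported model):
+
+```python
+import numpy as np
+
+from demo.visualization_demo.bev_vis_single_frame_demo import init_predictor, run
+from demo.visualization_demo.vis_utils import preprocess, Calibration, show_bev_with_boxes
+
+# Build the Paddle Inference predictor from the exported model files.
+predictor = init_predictor('model/pointpillars.pdmodel',
+                           'model/pointpillars.pdiparams')
+
+# Voxelize a single LiDAR frame with the same settings used at export time.
+voxels, coords, num_points = preprocess(
+    'data/velodyne/000008.bin', 4, [0, -39.68, -3, 69.12, 39.68, 1],
+    [0.16, 0.16, 4], 32, 40000)
+
+# Run the model, then draw the predicted boxes on a BEV of the raw scan.
+boxes, labels, scores = run(predictor, voxels, coords, num_points)
+pc_velo = np.fromfile('data/velodyne/000008.bin', np.float32).reshape(-1, 4)
+show_bev_with_boxes(pc_velo, boxes, scores, Calibration('data/calib/000008.txt'))
+```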
+
+`bev_vis_multi_frame_demo.py` is used as follows:
+
+```
+python bev_vis_multi_frame_demo.py \
+    --config configs/pointpillars/pointpillars_xyres16_kitti_car.yml \
+    --model demo/pointpillars.pdparams \
+    --batch_size 1
+```
+
+`--config` is the path of the model config file.
+
+`--model` is the path of the model parameter file to use.
+
+`--batch_size` is the batch size used for inference.
+
+The final LiDAR BEV visualization output is shown below:
+
+![](img/bev.png)
+
+### Visualization Interface for Datasets and LOG Files
+The code of the visualization interface is in `paddle3d.apis.infer`; an example invocation is provided:
+
+```
+cd demo/visualization_demo
+python dataset_vis_demo.py
+```
+
+---
+If you encounter the following problem, the solutions in Ref1 and Ref2 may help:
+
+`qt.qpa.plugin: Could not load the Qt Platform plugin 'xcb' in ..`
+
+[Ref1](https://blog.csdn.net/qq_39938666/article/details/120452028?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-2-120452028-blog-112303826.pc_relevant_3mothn_strategy_recovery&utm_relevant_index=3)
+& [Ref2](https://blog.csdn.net/weixin_41794514/article/details/128578166?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EYuanLiJiHua%7EPosition-3-128578166-blog-119480436.pc_relevant_landingrelevant)
diff --git a/demo/visualization_demo/vis_utils.py b/demo/visualization_demo/vis_utils.py
new file mode 100644
index 00000000..99d1019d
--- /dev/null
+++ b/demo/visualization_demo/vis_utils.py
@@ -0,0 +1,719 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import numba
+import numpy as np
+import mayavi.mlab as mlab
+
+from paddle3d.transforms.target_generator import encode_label
+
+
+class Calibration(object):
+    ''' Calibration matrices and utils
+        3d XYZ in