使用PaddleDetection实现竹签计数功能,实现效果如下:

一、数据集准备与格式转换

数据集来源:

https://ieee-dataport.org/documents/mobilenetv3-cbam-bamboo-stick-counting

该数据集格式为tfrecords,需要转换为COCO或者VOC

转换过程的核心代码如下:

raw_image_dataset = tf.data.TFRecordDataset(filenames)

for raw_record in raw_image_dataset:
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())

    filename = str(example.features.feature['image/filename'].bytes_list.value[0], encoding='utf-8')
    

使用TFRecordDataset 来包裹tfrecords文件,并使用Example来读取每一个数据对象,最后使用feature名称来获取对应的属性值。

对于该数据集,feature属性名称如下:

image_feature_description = {
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    "image/filename": tf.io.FixedLenFeature([], tf.string),
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/object/bbox/xmax': tf.io.FixedLenFeature([], tf.float32),
    'image/object/bbox/xmin': tf.io.FixedLenFeature([], tf.float32),
    'image/object/bbox/ymax': tf.io.FixedLenFeature([], tf.float32),
    'image/object/bbox/ymin': tf.io.FixedLenFeature([], tf.float32),
    'image/object/class/label': tf.io.FixedLenFeature([], tf.int64),
}

其中,图片字节流feature名为image/encoded

遍历完每个数据对象后,就可以提取图片对象。

图片对象提取并转换为PIL格式的过程如下:

image_bytes = example.features.feature['image/encoded'].bytes_list.value[0]
img = Image.open(io.BytesIO(image_bytes))

最后,构造COCO的格式或者VOC格式的dict,并转换为json或者xml标签。

在该项目中,我使用VOC格式数据集,将dict生成xml标签的工具来自:

https://github.com/canerkaraguler/Json2PascalVOC

转换完成后,查看效果。

PS:该数据集中部分标签的xy坐标反了,需要自己手动处理。

二、方案选择

对数据集进行分析。

图片的宽高差距较大,且竹签目标框属于小目标,因此最好切图训练。这里推荐使用sahi

https://github.com/obss/sahi

sahi不支持VOC格式的直接转换,因此可以使用PaddleDetection自带的tools/x2coco.py,将VOC格式的数据集转换为COCO格式的数据集。然后,也可以在数据集转换阶段就一步转换为COCO格式。

python tools/x2coco.py \
        --dataset_type voc \
        --voc_anno_dir dataset/zhuzi/annotations \
        --voc_anno_list dataset/zhuzi/ImageSets/Main/val.txt \
        --voc_label_list dataset/zhuzi/label_list.txt \
        --voc_out_name dataset/zhuzi/annotations/coco_val.json

之后,使用tools/slice_image.py,将COCO格式的数据集切图。

子图大小为320*320,重叠率为0.25。

python tools/slice_image.py --image_dir dataset/zhuzi/images/ --json_path dataset/zhuzi/annotations/coco_val.json --output_dir dataset/zhuzi_sliced --slice_size 320 --overlap_ratio 0.25

而在模型的选择方面,由于该业务较为简单,可以只考虑实际部署平台。

为了使该模型可以部署到微信小程序,我使用的是超轻量识别模型Picodet,搭配Paddle.js

https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/picodet

若是服务端部署,可以使用PPYoloE

该文使用的是PPYoloE

注意:建议使用PPYoloE,使用Paddle.jsPicodet部署至小程序端的方案不可行。但是模型大同小异,只需要注意参数的配置。

三、参数设置

1. 新建数据集参数

/PaddleDetection/configs/datasets/coco_detection_zhuzi.yml

metric: COCO
num_classes: 1 # 无须像yolo一样,有多少种写多少种即可

TrainDataset:
  !COCODataSet
    dataset_dir: dataset/zhuzi_sliced
    image_dir: coco_train_images_320_025
    anno_path: coco_train_320_025.json
    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']

EvalDataset:
  !COCODataSet
    dataset_dir: dataset/zhuzi_sliced
    image_dir: coco_val_images_320_025
    anno_path: coco_val_320_025.json
    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']

TestDataset:
  !ImageFolder
    anno_path: dataset/zhuzi/label_list.txt

2.修改主配置文件

/PaddleDetection/configs/smalldet/ppyoloe_zhuzi.yml

_BASE_: [
  '../datasets/coco_detection_zhuzi.yml',
  '../runtime.yml',
  '../ppyoloe/_base_/optimizer_300e.yml',
  '../ppyoloe/_base_/ppyoloe_crn.yml',
  '../ppyoloe/_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_80e_sliced_visdrone_640_025/model_final

pretrain_weights: pretrained_model/ppyoloe_crn_l_80e_sliced_visdrone_640_025.pdparams
depth_mult: 1.0
width_mult: 1.0


TrainReader:
  batch_size: 8 #默认为64,8台gpu。由于只有1台gpu,因此除以8

EvalReader:
  batch_size: 1

TestReader:
  batch_size: 1
  fuse_normalize: True


epoch: 80
LearningRate:
  base_lr: 0.0125
  schedulers:
    - !CosineDecay
      max_epochs: 96
    - !LinearWarmup
      start_factor: 0.
      epochs: 1

PPYOLOEHead:
  static_assigner_epoch: -1
  nms:
    name: MultiClassNMS
    nms_top_k: 10000
    keep_top_k: 500 # 默认为100,若一把竹签数量超过100,需要修改该部分
    score_threshold: 0.01
    nms_threshold: 0.5

3.修改reader配置文件

/PaddleDetection/configs/ppyoloe/_base_/ppyoloe_reader.yml

worker_num: 4
eval_height: &eval_height 2048 # 根据业务调整
eval_width: &eval_width 2048 # 根据业务调整,和高度相同
eval_size: &eval_size [*eval_height, *eval_width]

TrainReader:
  sample_transforms:
    - Decode: {}
    - RandomDistort: {}
    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
    - RandomCrop: {}
    - RandomFlip: {}
  batch_transforms:
    - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - Permute: {}
    - PadGT: {}
  batch_size: 8
  shuffle: true
  drop_last: true
  use_shared_memory: true
  collate_batch: true

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - Permute: {}
  batch_size: 2

TestReader:
  inputs_def:
    image_shape: [3, *eval_height, *eval_width]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - Permute: {}
  batch_size: 1

四、训练

export CUDA_VISIBLE_DEVICES=0
python tools/train.py -c configs/smalldet/ppyoloe_zhuzi.yml --use_vdl=True --vdl_log_dir=./sliced_visdrone/ --eval

五、预测

python tools/infer.py -c configs/smalldet/ppyoloe_zhuzi.yml -o weights=output/ppyoloe_zhuzi/model_final.pdparams --infer_img=test_img/IMG_8849.JPG  --output_dir='output/eval'

六、部署

为了灵活性,我使用PaddlePaddle和flask。

可参考该项目:

https://github.com/JackDance/paddle-flask-deploy

更换其中的模型与配置文件即可。

由于前端需要显示识别结果,这里调用七牛云,保存识别后的图片。

上传图片的示例如下,key为文件名称,path为图片路径。

def upload_img(key,path):
    q = qiniu.Auth(QINIU_AK, QINIU_SK)
    token = q.upload_token(QINIU_BUCKET, key, 3600)
    ret,info = qiniu.put_file(token, key ,path)

七、总结

本项目使用PaddleDetection实现了竹签计数,难度不大,是个新手项目。

唯一需要注意就是数据集格式需要满足要求,本项目大部分时间花在了数据集的格式处理上。