Press "Enter" to skip to content

基于深度学习的手部21类关键点检测

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

模型推理

 

运行命令:
python inference.py

 

#-*-coding:utf-8-*-

 

import os

 

import argparse

 

import torch

 

import torch.nn as nn

 

import numpy as np

 

import time

 

import datetime

 

import os

 

import math

 

from datetime import datetime

 

import cv2

 

import torch.nn.functional as F

 

from models.resnet import ,resnet34,resnet50,resnet101

 

from models.squeezenet import squeezenet1_1,squeezenet1_0

 

from models.shufflenetv2 import ShuffleNetV2

 

from models.shufflenet import ShuffleNet

 

from models.mobilenetv2 import MobileNetV2

 

from torchvision.models import shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0

 

from models.rexnetv1 import ReXNetV1

 

from utils.common_utils import *

 

import copy

 

from hand_data_iter.datasets import draw_bd_handpose

 

if __name__ == “__main__”:

 

parser = argparse.ArgumentParser(description=’ Project Hand Pose Inference’)

 

parser.add_argument(‘–model_path’, type=str, default = ‘./w/resnet50_2021-418.pth’,

 

help = ‘model_path’) #
模型路径

 

parser.add_argument(‘–model’, type=str, default = ‘resnet_50′,

 

help = ”’model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2

 

shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0,ReXNetV1”’) #
模型类型

 

parser.add_argument(‘–num_classes’, type=int , default = 42,

 

help = ‘num_classes’) #

21

(x,y)*2 = 42

 

parser.add_argument(‘–GPUS’, type=str, default = ‘0’,

 

help = ‘GPUS’) # GPU
选择

 

parser.add_argument(‘–test_path’, type=str, default = ‘./image/’,

 

help = ‘test_path’) #
测试图片路径

 

parser.add_argument(‘–img_size’, type=tuple , default = (256,256),

 

help = ‘img_size’) #
输入模型图片尺寸

 

parser.add_argument(‘–vis’, type=bool , default = True,

 

help = ‘vis’) #
是否可视化图片

 

print(‘\n/******************* {} ******************/\n’.format(parser.description))

 

#————————————————————————–

 

ops = parser.parse_args()#
解析添加参数

 

#————————————————————————–

 

print(‘———————————-‘)

 

unparsed = vars(ops) # parse_args()
方法的返回值为
namespace
,用
vars()
内建函数化为字典

 

for key in unparsed.keys():

 

print(‘{} : {}’.format(key,unparsed[key]))

 

#—————————————————————————

 

os.environ[‘CUDA_VISIBLE_DEVICES’] = ops.GPUS

 

test_path =  ops.test_path # 货币代码

测试图片文件夹路径

 

#—————————————————————-
构建模型

 

print(‘use model : %s’%(ops.model))

 

if ops.model == ‘resnet_50’:

 

model_ = resnet50(num_classes = ops.num_classes,img_size=ops.img_size[0])

 

elif ops.model == ‘resnet_18’:

 

model_ = resnet18(num_classes = ops.num_classes,img_size=ops.img_size[0])

 

elif ops.model == ‘resnet_34’:

 

model_ = resnet34(num_classes = ops.num_classes,img_size=ops.img_size[0])

 

elif ops.model == ‘resnet_101’:

 

model_ = resnet101(num_classes = ops.num_classes,img_size=ops.img_size[0])

 

elif ops.model == “squeezenet1_0”:

 

model_ = squeezenet1_0(num_classes=ops.num_classes)

 

elif ops.model == “squeezenet1_1”:

 

model_ = squeezenet1_1(num_classes=ops.num_classes)

 

elif ops.model == “shufflenetv2”:

 

model_ = ShuffleNetV2(ratio=1., num_classes=ops.num_classes)

 

elif ops.model == “shufflenet_v2_x1_5”:

 

model_ = shufflenet_v2_x1_5(pretrained=False,num_classes=ops.num_classes)

 

elif ops.model == “shufflenet_v2_x1_0”:

 

model_ = shufflenet_v2_x1_0(pretrained=False,num_classes=ops.num_classes)

 

elif ops.model == “shufflenet_v2_x2_0”:

 

model_ = shufflenet_v2_x2_0(pretrained=False,num_classes=ops.num_classes)

 

elif ops.model == “shufflenet”:

 

model_ = ShuffleNet(num_blocks = [2,4,2], num_classes=ops.num_classes, groups=3)

 

elif ops.model == “mobilenetv2”:

 

model_ = MobileNetV2(num_classes=ops.num_classes)

 

elif ops.model == “ReXNetV1”:

 

model_ = ReXNetV1( width_mult=1.0, depth_mult=1.0, num_classes=ops.num_classes)

 

use_cuda = torch.cuda.is_available()

 

device = torch.device(“cuda:0” if use_cuda else “cpu”)

 

model_ = model_.to(device)

 

model_.eval() #
设置为前向推断模式

 

# print(model_)#
打印模型结构

 

#
加载测试模型

 

if os.access(ops.model_path,os.F_OK):# checkpoint

 

chkpt = torch.load(ops.model_path, map_location=device)

 

model_.load_state_dict(chkpt)

 

print(‘load test model : {}’.format(ops.model_path))

 

#—————————————————————-
预测图片

 

”’
建议 检测手
bbox
后,
crop
手图片的预处理方式:

 

# img
为原图

 

x_min,y_min,x_max,y_max,score = bbox

 

w_ = max(abs(x_max-x_min),abs(y_max-y_min))

 

w_ = w_*1.1

 

x_mid = (x_max+x_min)/2

 

y_mid = (y_max+y_min)/2

 

x1,y1,x2,y2 = int(x_mid-w_/2),int(y_mid-w_/2),int(x_mid+w_/2),int(y_mid+w_/2)

 

x1 = np.clip(x1,0,img.shape[1]-1)

 

x2 = np.clip(x2,0,img.shape[1]-1)

 

y1 = np.clip(y1,0,img.shape[0]-1)

 

y2 = np.clip(y2,0,img.shape[0]-1)

 

”’

 

with torch.no_grad():

 

idx = 0

 

for file in os.listdir(ops.test_path):

 

if ‘.jpg’ not in file:

 

continue

 

idx += 1

 

print(‘{}) image : {}’.format(idx,file))

 

img = cv2.imread(ops.test_path + file)

 

img_width = img.shape[1]

 

img_height = img.shape[0]

 

#
输入图片预处理

 

img_ = cv2.resize(img, (ops.img_size[1],ops.img_size[0]), interpolation = cv2.INTER_CUBIC)

 

img_ = img_.astype(np.float32)

 

img_ = (img_-128.)/256.

 

img_ = img_.transpose(2, 0, 1)

 

img_ = torch.from_numpy(img_)

 

img_ = img_.unsqueeze_(0)

 

if use_cuda:

 

img_ = img_.cuda()  # (bs, 3, h, w)

 

pre_ = model_(img_.float()) #
模型推理

 

output = pre_.cpu().detach().numpy()

 

output = np.squeeze(output)

 

pts_hand = {} #
构建关键点连线可视化结构

 

for i in range(int(output.shape[0]/2)):

 

x = (output[i*2+0]*float(img_width))

 

y = (output[i*2+1]*float(img_height))

 

pts_hand[str(i)] = {}

 

pts_hand[str(i)] = {

 

“x”:x,

 

“y”:y,

 

}

 

draw_bd_handpose(img,pts_hand,0,0) #
绘制关键点连线

 

#————-
绘制关键点

 

for i in range(int(output.shape[0]/2)):

 

x = (output[i*2+0]*float(img_width))

 

y = (output[i*2+1]*float(img_height))

 

cv2.circle(img, (int(x),int(y)), 3, (255,50,60),-1)

 

cv2.circle(img, (int(x),int(y)), 1, (255,150,180),-1)

 

if ops.vis:

 

cv2.namedWindow(‘image’,0)

 

cv2.imshow(‘image’,img)

 

cv2.imwrite(file, img)

 

if cv2.waitKey(600) == 27 :

 

break

 

cv2.destroyAllWindows()

 

print(‘well done ‘)

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注