Press "Enter" to skip to content

【RotNet 自监督学习】预测图像旋转角度

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

 

RotNet:预测图像旋转

 

论文导读

 

RotNet 通过预测图像旋转进行自监督学习

 

这是2018年ICLR发表的一篇论文,被引用超过1100次。论文的想法来源于:如果某人不了解图像中描绘的对象的概念,则他无法识别应用于图像的旋转。

 

在这篇文章中,我们回顾了巴黎科技大学(University Paris-Est)通过预测图像旋转进行的无监督表示学习。使用RotNet
通过训练ConvNets
来学习图像特征,以识别应用于作为输入的图像的2d旋转。通过这种方法,无监督的预训练AlexNet模型达到了54.4%的mAP,仅比有监督的AlexNet低2.4点。

 

图像旋转预测框架

 

 

 

给定四种可能的几何变换,即0、90、180和270度旋转,卷积网络模型F(:)被训练来识别输入的图像应用了哪个旋转。

 

Fy(Xy)
是模型F(:)
预测的旋转变换 y 的概率,它的输入是一个已经被旋转变换的图像,输出图片的旋转角度。

 

为了成功地预测图像的旋转,ConvNet
模型必须学习定位图像中的显着目标,识别它们的方向和对象类型,然后将对象方向与原始图像进行关联。

 

由经过训练的 AlexNet 模型生成的注意力图(a)识别对象(监督)和(b)识别图像旋转(自监督)。

 

上述注意力图是根据卷积层的每个空间单元的激活幅度计算的,本质上反映了网络将大部分焦点放在何处以对输入图像进行分类。

 

途中可以看到,监督模型和自监督模型似乎都关注大致相同的图像区域。

 

旋转拖动验证码解决方案

 

曾几何时,你是否被一个旋转验证码而困扰,没错今日主题——旋转验证码

 

当进行模拟登录时,图片验证码是一大难点。

 

不过有了RotNet
,这一问题便迎刃而解旋转拖动验证码解决方案

 

两种思路

 

图像旋转考虑两种思路:回归与分类

回归
:预测数值结果范围是0-360°.
分类
:预测360个类别,模型预测输出哪个类别的概率最大.

定义卷积神经网络训练旋转图片集,进行预测图片旋转的角度。

 

大数据应用赛

 

大数据应用赛:计算机视觉在众多的AI中应用广泛,比如自动驾驶、视觉导航、目标检测、目标识别等等,无一不关系到计算机视觉,而图像技术往往能帮助计算机视觉得到提升,比如随机剪裁、随机旋转、图像模糊等等图像手段。图像技术对计算机视觉的重要性则不言而喻,故本次大数据应用赛的赛题为图像扶正挑战。

 

 

 

卷积神经网络

 

分类代码:

 

# number of convolutional filters to use
nb_filters = 64
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)
# number of classes
nb_classes = 360
# model definition
input = Input(shape=(img_rows, img_cols, img_channels))
x = Conv2D(nb_filters, kernel_size, activation='relu')(input)
x = Conv2D(nb_filters, kernel_size, activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.25)(x)
x = Dense(nb_classes, activation='softmax')(x)
model = Model(inputs=input, outputs=x)
model.summary()

 

模型编译

 

# model compilation
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=[angle_error])

 

训练参数

 

# training parameters
batch_size = 128
nb_epoch = 50

 

# callbacks
checkpointer = ModelCheckpoint(
    filepath=os.path.join(output_folder, model_name + '.hdf5'),
    save_best_only=True
)
early_stopping = EarlyStopping(patience=2)
tensorboard = TensorBoard()

 

模型训练

 

# training loop
model.fit_generator(
    RotNetDataGenerator(
        X_train,
        batch_size=batch_size,
        preprocess_func=binarize_images,
        shuffle=True
    ),
    steps_per_epoch=nb_train_samples / batch_size,
    epochs=nb_epoch,
    validation_data=RotNetDataGenerator(
        X_test,
        batch_size=batch_size,
        preprocess_func=binarize_images
    ),
    validation_steps=nb_test_samples / batch_size,
    verbose=1,
    callbacks=[checkpointer, early_stopping, tensorboard]
)

 

完整代码

 

"""
@Author: ZS
@CSDN  : https://zsyll.blog.csdn.net/
@Time  : 2021/11/20 10:48
"""
from __future__ import print_function
import os
import sys
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras.applications.resnet50 import ResNet50
from keras.applications.imagenet_utils import preprocess_input
from keras.models import Model
from keras.layers import Dense, Flatten
from keras.optimizers import SGD
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import angle_error, RotNetDataGenerator
from getImagePath import getPath
data_path = r'./data/image/'
train_filenames, test_filenames = getPath(data_path)
print(len(train_filenames), 'train samples')
print(len(test_filenames), 'test samples')
model_name = 'rotnet_resnet50'
# 分类数量
nb_classes = 360
# input image shape
input_shape = (320, 320, 3)
# 加载基础模型
base_model = ResNet50(weights='imagenet', include_top=False,
                      input_shape=input_shape)
# 添加分类层
x = base_model.output
x = Flatten()(x)
final_output = Dense(nb_classes, activation='softmax', name='fc360')(x)
# 创建新的模型
model = Model(inputs=base_model.input, outputs=final_output)
model.summary()
# 模型编译
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01, momentum=0.9),
              metrics=[angle_error])
# 训练参数
batch_size = 64
nb_epoch = 20
output_folder = 'models'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
# callbacks
monitor = 'val_angle_error'
checkpointer = ModelCheckpoint(
    filepath=os.path.join(output_folder, model_name + '.hdf5'),
    monitor=monitor,
    save_best_only=True
)
reduce_lr = ReduceLROnPlateau(monitor=monitor, patience=3)
early_stopping = EarlyStopping(monitor=monitor, patience=5)
tensorboard = TensorBoard()
# 训练模型
model.fit_generator(
    RotNetDataGenerator(
        train_filenames,
        input_shape=input_shape,
        batch_size=batch_size,
        preprocess_func=preprocess_input,
        crop_center=True,
        crop_largest_rect=True,
        shuffle=True
    ),
    steps_per_epoch=len(train_filenames) / batch_size,
    epochs=nb_epoch,
    validation_data=RotNetDataGenerator(
        test_filenames,
        input_shape=input_shape,
        batch_size=batch_size,
        preprocess_func=preprocess_input,
        crop_center=True,
        crop_largest_rect=True
    ),
    validation_steps=len(test_filenames) / batch_size,
    callbacks=[checkpointer, reduce_lr, early_stopping, tensorboard],
    workers=10
)

 

模型调用

 

# import区域,sys为必须导入,其他根据需求导入
from __future__ import print_function
import os
import sys
import random
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
from mykeras.applications.imagenet_utils import preprocess_input
from mykeras.models import load_model
from utils import display_examples, RotNetDataGenerator, angle_error
import warnings
warnings.filterwarnings("ignore")
from tensorflow.keras import layers
# 代码区,根据需求写
class FileSequence(keras.utils.Sequence):
    def __init__(self,filenames,batch_size,filefunc,fileargs=(),labels=None,labelfunc=None,labelargs=(),shuffle=False):
        if labels: assert len(filenames) == len(labels)
        self.filenames  = filenames
        self.batch_size = batch_size
        self.filefunc   = filefunc
        self.fileargs   = fileargs
        self.labels     = labels
        self.labelfunc  = labelfunc
        self.labelargs  = labelargs  
        if shuffle:
            idx_list = list(range(len(self.filenames)))
            random.shuffle(idx_list)
            self.filenames = [self.filenames[idx] for idx in idx_list]
            if self.labels: self.labels = [self.labels[idx] for idx in idx_list]
    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))
    def __getitem__(self, idx):
        batch_filenames = self.filenames[idx * self.batch_size: (idx+1) * self.batch_size]
        
        files = []
        for filename in batch_filenames:
            # tf.print(filename)
            file = self.filefunc(filename,*self.fileargs)
            files.append(file)
        if self.labels:
            batch_labels = self.labels[idx * self.batch_size: (idx+1) * self.batch_size]
            if self.labelfunc:
                return np.array(files), self.labelfunc(batch_labels,*self.labelargs)
            else:
                return np.array(files), batch_labels
        else:
            return np.array(files)
def fillWhite(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert (h < size) and (w < size)
    fillImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        sh = random.randint(0,size-h)
        sw = random.randint(0,size-w)
        fillImg[sh:sh+h,sw:sw+w,...] = img
    elif mode == "centre" or mode == "center":
        fillImg[(size-h)//2:(size+h)//2,(size-w)//2:(size+w)//2,...] = img
    else:
        fillImg[:h,:w,...] = img
    return fillImg
def cropImg(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert (h >= size) and (w >= size)
    if mode == "random":
        sh = random.randint(0,h-size)
        sw = random.randint(0,w-size)
        cropImg = img[sh:sh+size,sw:sw+size,...]
    elif mode == "centre" or mode == "center":
        cropImg = img[(h-size)//2:(h+size)//2,(w-size)//2:(w+size)//2,...]
    else:
        cropImg = img[:size,:size,...]
    return cropImg
def fillCrop(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    assert ((h >= size) and (w < size)) or ((h < size) and (w >= size))
    fillcropImg = np.zeros(shape=(size,size,c))
    if mode == "random":
        if (h >= size) and (w < size):
            sh = random.randint(0,h-size)
            sw = random.randint(0,size-w)
            fillcropImg[:,sw:sw+w,:] = img[sh:sh+size,...]
        else:
            sh = random.randint(0,size-h)
            sw = random.randint(0,w-size)
            fillcropImg[sh:sh+h,...] = img[:,sw:sw+size,:]
    elif mode == "centre" or mode == "center":
        if (h >= size) and (w < size):
            fillcropImg[:,(size-w)//2:(size+w)//2,:] = img[(h-size)//2:(h+size)//2,...]
        else:
            fillcropImg[(size-h)//2:(size+h)//2,...] = img[:,(w-size)//2:(w+size)//2,:]
    else:
        if (h >= size) and (w < size):
            fillcropImg[:,:size,:] = img[:size,...]
        else:
            fillcropImg[:size,...] = img[:,:size,:]
    return fillcropImg
def resizeImg(img,size,mode=None):
    if len(img.shape) == 2: img = img.reshape(*img.shape,-1)
    assert len(img.shape) == 3
    h, w, c = img.shape
    if (h < size) and (w < size): return fillWhite(img,size,mode)
    elif (h >= size) and (w >= size): return cropImg(img,size,mode)
    else: return fillCrop(img,size,mode)
def filefunc(filename,mode):
    tf.print(filename)
    img = cv2.imread(filename)
    if not isinstance(img,np.ndarray):
        tf.print(filename)
    h, w, c = img.shape
    if (h >=256) or (w >= 256):
        img = resizeImg(img,256,mode)
        img = cv2.resize(img,(64,64))
    elif (h >=128) or (w >= 128):
        img = resizeImg(img,128,mode)
        img = cv2.resize(img,(64,64))
    else:
        img = resizeImg(img,64,mode)
    return img    
# 主函数,格式固定,to_pred_dir为预测所在文件夹,result_save_path为预测结果生成路径
# 以下为示例
def main(to_pred_dir, result_save_path):
    runpyp = os.path.abspath(__file__)
    modeldirp = os.path.dirname(runpyp)
    modelp = os.path.join(modeldirp,"model.hdf5")
    model = load_model(modelp, custom_objects={
 'angle_error': angle_error})  # 自定义对象
    pred_imgs = os.listdir(to_pred_dir)
    pred_imgsp_lines = [os.path.join(to_pred_dir,p) for p in pred_imgs]
    name, label = display_examples(
        model,
        pred_imgsp_lines,
        num_images=len(pred_imgsp_lines),
        size=(224, 224),
        crop_center=True,
        crop_largest_rect=True,
        preprocess_func=preprocess_input,
    )
    
    df = pd.DataFrame({
 "id":name,"label":label})
    df.to_csv(result_save_path,index=None)
# !!!注意:
# 图片赛题给出的参数为to_pred_dir,是一个文件夹,其图片内容为
# to_pred_dir/to_pred_0.png
# to_pred_dir/to_pred_1.png
# to_pred_dir/......
# 所需要生成的csv文件头为id,label,如下
# image_id,label
# to_pred_0,4
# to_pred_1,76
# to_pred_2,...
if __name__ == "__main__":
    to_pred_dir = sys.argv[1]  # 所需预测的文件夹路径
    result_save_path = sys.argv[2]  # 预测结果保存文件路径
    main(to_pred_dir, result_save_path)

 

参考:Link

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注