Like the previous two attempts, saliency detection doesn't seem able to solve my problem, so I'm planning to tackle it from a different direction. Even though I didn't reach my own goal, the results here are actually decent, and anyone who needs this can tune it further. Repository used: github.com/lartpang/MI…
[Figures: original image, ground-truth label, prediction result.] Evaluation results:
```
acc:            0.9055214352077908
acc_cls:        0.8682510382904347
iou:            [0.88870665 0.61525859]
miou:           0.7519826202767053
fwavacc:        0.8376228680494308
class_accuracy: 0.7143424443012731
class_recall:   0.7998325458213021
accuracy:       0.9007926079195722
f1_score:       0.7546741156614227
```
Note that these results were obtained with the default parameters; the IoU is above 0.6 right away, which looks pretty good, though training is somewhat slow.
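As a quick sanity check on these numbers: miou is simply the mean of the two per-class IoUs, (0.88870665 + 0.61525859) / 2 ≈ 0.7519826, which matches the reported value.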
1. Data preparation. Data preparation is simple: just the usual storage layout, a top-level directory with second-level directories (see the sketch below). The folder names are best kept the same as mine, since the code assembles paths from these names; also, each image and its label just need to share the same file name.

2. Data loading. What needs to change is the data loading at test time; the original training-time loading also covers testing and validation, and I removed the validation step during training.
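For concreteness, this is the layout the code expects, pieced together from the dataset roots in config.py (e.g. ./Dataset/BUILD/Train) and the Image/Mask folder names used in dataloader.py; the file names are just placeholders:

```
Dataset/
└── BUILD/
    ├── Train/
    │   ├── Image/   # 0001.jpg, 0002.jpg, ...
    │   └── Mask/    # 0001.png, 0002.png, ...
    └── Test/
        └── Image/   # masks are optional at test time with the modified loader
```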
```python
# -*- coding: utf-8 -*-
# @Time    : 2020/7/22
# @Author  : Lart Pang
# @Email   : [email protected]
# @File    : dataloader.py
# @Project : code
# @GitHub  : https://github.com/lartpang
import os
import random
from functools import partial

import torch
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

from config import arg_config
from utils.joint_transforms import Compose, JointResize, RandomHorizontallyFlip, RandomRotate
from utils.misc import construct_print


def _get_suffix(path_list):
    ext_list = list(set([os.path.splitext(p)[1] for p in path_list]))
    if len(ext_list) != 1:
        if ".png" in ext_list:
            ext = ".png"
        elif ".jpg" in ext_list:
            ext = ".jpg"
        elif ".bmp" in ext_list:
            ext = ".bmp"
        else:
            raise NotImplementedError
        construct_print(f"The data folder contains multiple extensions; only {ext} will be used.")
    else:
        ext = ext_list[0]
    return ext


def _make_dataset(root):
    img_path = os.path.join(root, "Image")
    mask_path = os.path.join(root, "Mask")

    img_list = os.listdir(img_path)
    mask_list = os.listdir(mask_path)

    img_suffix = _get_suffix(img_list)
    mask_suffix = _get_suffix(mask_list)

    img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
    return [
        (
            os.path.join(img_path, img_name + img_suffix),
            os.path.join(mask_path, img_name + mask_suffix),
        )
        for img_name in img_list
    ]


def _make_dataset2(root):
    img_path = os.path.join(root, "Image")
    # mask_path = os.path.join(root, "Mask")

    img_list = os.listdir(img_path)
    # mask_list = os.listdir(mask_path)

    img_suffix = _get_suffix(img_list)
    # mask_suffix = _get_suffix(mask_list)

    # img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
    return [
        (
            os.path.join(img_path, img_name),
            # os.path.join(mask_path, img_name + mask_suffix),
        )
        for img_name in img_list
    ]


def _read_list_from_file(list_filepath):
    img_list = []
    with open(list_filepath, mode="r", encoding="utf-8") as openedfile:
        line = openedfile.readline()
        while line:
            img_list.append(line.split()[0])
            line = openedfile.readline()
    return img_list


def _make_dataset_from_list(list_filepath, prefix=(".png", ".png")):
    img_list = _read_list_from_file(list_filepath)
    return [
        (
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),  # where the path is assembled
                os.path.basename(img_path) + prefix[0],
            ),
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Mask"),  # where the path is assembled
                os.path.basename(img_path) + prefix[1],
            ),
        )
        for img_path in img_list
    ]


def _make_dataset_from_list2(list_filepath, prefix=(".png", ".png")):
    # Used for test-time loading; no labels are needed, since labels are often unavailable at test time.
    img_list = _read_list_from_file(list_filepath)
    return [
        (
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),  # where the path is assembled
                os.path.basename(img_path) + prefix[0],
            ),
            # os.path.join(
            #     os.path.join(os.path.dirname(img_path), "Mask"),
            #     os.path.basename(img_path) + prefix[1],
            # ),
        )
        for img_path in img_list
    ]


class ImageFolder(Dataset):
    def __init__(self, root, in_size, training, prefix, use_bigt=False):
        self.training = training
        self.use_bigt = use_bigt

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder, we will test on it.")
            self.imgs = _make_dataset(root)
        elif os.path.isfile(root):
            construct_print(
                f"{root} is a list of images, we will use these paths to read the "
                f"corresponding image"
            )
            self.imgs = _make_dataset_from_list(root, prefix=prefix)
        else:
            raise NotImplementedError

        if self.training:
            self.joint_transform = Compose(
                [JointResize(in_size), RandomHorizontallyFlip(), RandomRotate(10)]
            )
            img_transform = [transforms.ColorJitter(0.1, 0.1, 0.1)]
            self.mask_transform = transforms.ToTensor()
        else:
            # If the input is a tuple, resize to that size; if it is a single number,
            # scale proportionally so that the short side equals that value.
            img_transform = [
                transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
            ]
        self.img_transform = transforms.Compose(
            [
                *img_transform,
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                # transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
            ]
        )

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        img = Image.open(img_path).convert("RGB")
        if self.training:
            mask = Image.open(mask_path).convert("L")
            img, mask = self.joint_transform(img, mask)
            img = self.img_transform(img)
            mask = self.mask_transform(mask)
            if self.use_bigt:
                mask = mask.ge(0.5).float()  # binarize
            return img, mask, img_name
        else:
            # todo: When evaluating, the mask path may not exist. But our code defaults to its existence,
            #  which makes it impossible to use dataloader to generate a prediction without a mask path.
            img = self.img_transform(img)
            # img = img / 255.0
            return img, mask_path, img_name

    def __len__(self):
        return len(self.imgs)


class ImageFolder2(Dataset):  # the added test-time dataset
    def __init__(self, root, in_size, training, prefix, use_bigt=False):
        self.training = training
        self.use_bigt = use_bigt

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder, we will test on it.")
            self.imgs = _make_dataset2(root)
        elif os.path.isfile(root):
            construct_print(
                f"{root} is a list of images, we will use these paths to read the "
                f"corresponding image"
            )
            self.imgs = _make_dataset_from_list2(root, prefix=prefix)
        else:
            raise NotImplementedError

        # If the input is a tuple, resize to that size; if it is a single number,
        # scale proportionally so that the short side equals that value.
        img_transform = [
            transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
        ]
        self.img_transform = transforms.Compose(
            [
                *img_transform,
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                # transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
            ]
        )

    def __getitem__(self, index):
        # print(self.imgs[index][0])
        img_path = self.imgs[index][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        img = Image.open(img_path).convert("RGB")
        img = self.img_transform(img)
        return img, img_name

    def __len__(self):
        return len(self.imgs)


class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super(DataLoaderX, self).__iter__())


def _collate_fn(batch, size_list):
    size = random.choice(size_list)
    img, mask, image_name = [list(item) for item in zip(*batch)]
    img = torch.stack(img, dim=0)
    img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False)
    mask = torch.stack(mask, dim=0)
    mask = interpolate(mask, size=(size, size), mode="nearest")
    return img, mask, image_name


def _mask_loader(dataset, shuffle, drop_last, size_list):
    assert float(torch.__version__[:3]) >= 1.2, (
        "If you want to use the pytorch < 1.2, you need to "
        "comment out the line `collate_fn=...` when you set the `size_list` to `None`."
    )
    return DataLoaderX(
        dataset=dataset,
        collate_fn=partial(_collate_fn, size_list=size_list) if size_list else None,
        batch_size=arg_config["batch_size"],
        num_workers=arg_config["num_workers"],
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=True,
    )


def create_loader(data_path, training, size_list=None, prefix=(".jpg", ".png"), get_length=False):
    if training:
        construct_print(f"Training on: {data_path}")
        imageset = ImageFolder(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            use_bigt=arg_config["use_bigt"],
            training=True,
        )
        loader = _mask_loader(imageset, shuffle=True, drop_last=True, size_list=size_list)
    else:
        construct_print(f"Testing on: {data_path}")
        imageset = ImageFolder2(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            training=False,
        )
        loader = _mask_loader(imageset, shuffle=False, drop_last=False, size_list=None)

    if get_length:
        length_of_dataset = len(imageset)
        return loader, length_of_dataset
    else:
        return loader


if __name__ == "__main__":
    loader = create_loader(
        data_path=arg_config["rgb_data"]["tr_data_path"],
        training=True,
        get_length=False,
        size_list=arg_config["size_list"],
    )
    for idx, train_data in enumerate(loader):
        train_inputs, train_masks, *train_other_data = train_data
        print(f"batch: {idx} ", train_inputs.size(), train_masks.size())
```
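As a quick check of the modified test-time path, here is a minimal sketch of driving it directly; the dataset path is hypothetical and should point at a folder containing an Image/ subfolder, as in the layout above:

```python
from config import arg_config
from utils.dataloader import create_loader

# training=False goes through ImageFolder2, so each batch is (images, image_names)
# and no Mask/ folder is required.
te_loader = create_loader(
    data_path="./Dataset/BUILD/Test",  # hypothetical path, matching config.py below
    training=False,
    prefix=arg_config["prefix"],
    get_length=False,
)
for batch_imgs, batch_names in te_loader:
    print(batch_imgs.shape, batch_names)
    break
```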
3. Training. This codebase is mainly driven by the configuration file, so let's go through the config file first: config.py
```python
import os
from collections import OrderedDict

__all__ = ["proj_root", "arg_config"]

proj_root = os.path.dirname(__file__)
datasets_root = "./Dataset/"

# Original author's paths
# ecssd_path = os.path.join(datasets_root, "Saliency/RGBSOD", "ECSSD")
# dutomron_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUT-OMRON")
# hkuis_path = os.path.join(datasets_root, "Saliency/RGBSOD", "HKU-IS")
# pascals_path = os.path.join(datasets_root, "Saliency/RGBSOD", "PASCAL-S")
# soc_path = os.path.join(datasets_root, "Saliency/RGBSOD", "SOC/Test")
# dutstr_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Train")
# dutste_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Test")

# Paths used in my own tests
# dutstr_path = os.path.join(datasets_root, "ECSSD/Train")
ecssdte_path = os.path.join(datasets_root, "ECSSD/Test")
modelte_path = os.path.join(datasets_root, "TEST")
rivertr_path = os.path.join(datasets_root, "RIVER/Train")
riverte_path = os.path.join(datasets_root, "RIVER/Test")
buildtr_path = os.path.join(datasets_root, "BUILD/Train")
buildte_path = os.path.join(datasets_root, "BUILD/Test")

arg_config = {
    "model": "MINet_VGG16",  # the model actually used; it must be imported in `network/__init__.py`
    "info": "",  # extra info about this experiment, appended to exp_name; if empty, nothing is appended
    "use_amp": False,  # whether to use amp to accelerate training
    "resume_mode": "inference",  # the mode for resume parameters: ['train', 'test', 'inference', '']
    # Note: because of my changes, use '' for training and 'inference' for testing.
    "use_aux_loss": False,  # whether to use auxiliary losses; multiple loss functions can be added
    # to the self.loss_funcs list in solver.py
    "save_pre": True,  # whether to keep the final predictions
    "epoch_num": 60,  # number of training epochs, 0: directly test model
    "lr": 0.001,  # reduce by 100x when fine-tuning
    "xlsx_name": "result.xlsx",  # the name of the record file
    # Dataset settings
    "rgb_data": {
        "tr_data_path": buildtr_path,  # training path
        "te_data_list": OrderedDict(
            {
                # "pascal-s": pascals_path,
                # "ecssd": ecssdte_path,
                # "hku-is": hkuis_path,
                # "duts": dutste_path,
                # "dut-omron": dutomron_path,
                # "soc": soc_path,
                # "river": riverte_path,
                "modelte": buildte_path,  # test path
            },
        ),
    },
    # Monitoring during training
    "tb_update": 50,  # >0 uses tensorboard
    "print_freq": 50,  # >0, log information during iterations
    # img_prefix, gt_prefix: the extensions used when a list file is used
    "prefix": (".jpg", ".png"),
    # if you dont use the multi-scale training, you can set 'size_list': None
    # "size_list": [224, 256, 288, 320, 352],
    "size_list": None,  # no multi-scale training
    "reduction": "mean",  # how losses are reduced, "mean" or "sum"
    # Optimizer and learning-rate decay
    "optim": "adam",  # learning rate for the custom part
    "weight_decay": 5e-4,  # set to 0.0001 when fine-tuning
    "momentum": 0.9,
    "nesterov": False,
    "sche_usebatch": False,
    "lr_type": "poly",
    "warmup_epoch": 1,  # depends on the special lr_type; only used when lr_type has 'warmup'; 1 means no warmup
    "lr_decay": 0.9,  # poly
    "use_bigt": True,  # whether to binarize the ground truth during training (threshold 0.5)
    "batch_size": 4,  # when resuming training, it is best to keep the same batch size
    "num_workers": 0,  # don't set it too large, or data loading slows down when several jobs train at once
    "input_size": 512,  # image size; inputs are resized to this, so it need not match the original size
}
```
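Not something from the original repo, just how I read the comments above: when fine-tuning from an existing checkpoint, the config comments suggest cutting lr by 100x and lowering weight_decay, and solver.py resumes the full training state from final_full_net when resume_mode is "train". A hedged sketch of those overrides:

```python
# Hypothetical fine-tuning overrides implied by the comments in arg_config above.
arg_config.update(
    {
        "resume_mode": "train",  # resume the full training state from final_full_net
        "lr": 0.001 / 100,       # "reduce by 100x when fine-tuning"
        "weight_decay": 1e-4,    # "set to 0.0001 when fine-tuning"
        "batch_size": 4,         # best kept the same as the original run when resuming
    }
)
```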
main.py. I added the inference option to this file, matching the corresponding setting in the config file.
```python
import shutil
from datetime import datetime

from config import arg_config, proj_root
from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed
from utils.solver import Solver

construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")

init_start = datetime.now()
exp_name = construct_exp_name(arg_config)
path_config = construct_path(
    proj_root=proj_root,
    exp_name=exp_name,
    xlsx_name=arg_config["xlsx_name"],
)
pre_mkdir(path_config)
set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] != None)

solver = Solver(exp_name, arg_config, path_config)
construct_print(f"Total initialization time:{datetime.now() - init_start}")

shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"])
shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])

construct_print(f"{datetime.now()}: Start...")
if arg_config["resume_mode"] == "test":
    solver.test()
elif arg_config["resume_mode"] == "inference":  # added here
    solver.inference()
else:
    solver.train()
construct_print(f"{datetime.now()}: End...")
```
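With the dispatch above, the whole workflow amounts to running the same entry point twice, changing only resume_mode in config.py between runs (commands assume you launch from the project root):

```
python main.py   # with "resume_mode": ""           -> solver.train()
python main.py   # with "resume_mode": "inference"  -> solver.inference()
```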
solver.py
```python
import os
from pprint import pprint

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import network as network_lib
from loss.CEL import CEL
from loss.focal_loss import FocalLoss  # the loss functions below were added by me; I will package them together later
from loss.dice_loss import DiceLoss
from loss.iou_loss import IoULoss
from utils.dataloader import create_loader
from utils.metric import cal_maxf, cal_pr_mae_meanf
from utils.misc import (
    AvgMeter,
    construct_print,
    write_data_to_file,
)
from utils.pipeline_ops import (
    get_total_loss,
    make_optimizer,
    make_scheduler,
    resume_checkpoint,
    save_checkpoint,
)
from utils.recorder import TBRecorder, Timer, XLSXRecoder


class Solver:
    def __init__(self, exp_name: str, arg_dict: dict, path_dict: dict):
        super(Solver, self).__init__()
        self.exp_name = exp_name
        self.arg_dict = arg_dict
        self.path_dict = path_dict

        self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.to_pil = transforms.ToPILImage()

        self.tr_data_path = self.arg_dict["rgb_data"]["tr_data_path"]
        self.te_data_list = self.arg_dict["rgb_data"]["te_data_list"]

        self.save_path = self.path_dict["save"]
        self.save_pre = self.arg_dict["save_pre"]

        if self.arg_dict["tb_update"] > 0:
            self.tb_recorder = TBRecorder(tb_path=self.path_dict["tb"])
        if self.arg_dict["xlsx_name"]:
            self.xlsx_recorder = XLSXRecoder(xlsx_path=self.path_dict["xlsx"])

        # Attributes that depend on the ones above
        self.tr_loader = create_loader(
            data_path=self.tr_data_path,
            training=True,
            size_list=self.arg_dict["size_list"],
            prefix=self.arg_dict["prefix"],
            get_length=False,
        )
        self.end_epoch = self.arg_dict["epoch_num"]
        self.iter_num = self.end_epoch * len(self.tr_loader)

        if hasattr(network_lib, self.arg_dict["model"]):
            self.net = getattr(network_lib, self.arg_dict["model"])().to(self.dev)
        else:
            raise AttributeError
        pprint(self.arg_dict)

        if self.arg_dict["resume_mode"] == "test":
            # resume model only to test model.
            # self.start_epoch is useless
            resume_checkpoint(
                model=self.net,
                load_path=self.path_dict["final_state_net"],
                mode="onlynet",
            )
            return

        # Because inference was newly added, the corresponding branch was added here as well
        if self.arg_dict["resume_mode"] == "inference":
            # resume model only to test model.
            # self.start_epoch is useless
            resume_checkpoint(
                model=self.net,
                load_path=self.path_dict["final_state_net"],
                mode="onlynet",
            )
            return

        # Multiple losses are possible; remember to set the corresponding option in config.py to True
        self.loss_funcs = [
            # torch.nn.BCEWithLogitsLoss(reduction=self.arg_dict["reduction"]).to(self.dev)
            # FocalLoss()
            IoULoss()
        ]
        if self.arg_dict["use_aux_loss"]:
            self.loss_funcs.append(CEL().to(self.dev))

        self.opti = make_optimizer(
            model=self.net,
            optimizer_type=self.arg_dict["optim"],
            optimizer_info=dict(
                lr=self.arg_dict["lr"],
                momentum=self.arg_dict["momentum"],
                weight_decay=self.arg_dict["weight_decay"],
                nesterov=self.arg_dict["nesterov"],
            ),
        )
        self.sche = make_scheduler(
            optimizer=self.opti,
            total_num=self.iter_num if self.arg_dict["sche_usebatch"] else self.end_epoch,
            scheduler_type=self.arg_dict["lr_type"],
            scheduler_info=dict(
                lr_decay=self.arg_dict["lr_decay"], warmup_epoch=self.arg_dict["warmup_epoch"]
            ),
        )

        # AMP
        if self.arg_dict["use_amp"]:
            construct_print("Now, we will use the amp to accelerate training!")
            from apex import amp

            self.amp = amp
            self.net, self.opti = self.amp.initialize(self.net, self.opti, opt_level="O1")
        else:
            self.amp = None

        if self.arg_dict["resume_mode"] == "train":
            # resume model to train the model
            self.start_epoch = resume_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                load_path=self.path_dict["final_full_net"],
                mode="all",
            )
        else:
            # only train a new model.
            self.start_epoch = 0

    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            self._train_per_epoch(curr_epoch, train_loss_record)

            # Adjust the learning rate per epoch
            if not self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # Save a checkpoint every epoch; the saved parameters correspond to epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path_dict["final_full_net"],
                state_net_path=self.path_dict["final_state_net"],
            )  # save parameters

        # I commented this out; to use it, replace ImageFolder2 with ImageFolder
        # in the create_loader function of dataloader.py
        # if self.arg_dict["use_amp"]:
        #     # https://github.com/NVIDIA/apex/issues/567
        #     with self.amp.disable_casts():
        #         construct_print("When evaluating, we wish to evaluate in pure fp32.")
        #         self.test()
        # else:
        #     self.test()

    @Timer
    def _train_per_epoch(self, curr_epoch, train_loss_record):
        for curr_iter_in_epoch, train_data in enumerate(self.tr_loader):
            num_iter_per_epoch = len(self.tr_loader)
            curr_iter = curr_epoch * num_iter_per_epoch + curr_iter_in_epoch

            self.opti.zero_grad()

            train_inputs, train_masks, _ = train_data
            train_inputs = train_inputs.to(self.dev, non_blocking=True)
            train_masks = train_masks.to(self.dev, non_blocking=True)
            train_preds = self.net(train_inputs)

            train_loss, loss_item_list = get_total_loss(train_preds, train_masks, self.loss_funcs)
            if self.amp:
                with self.amp.scale_loss(train_loss, self.opti) as scaled_loss:
                    scaled_loss.backward()
            else:
                train_loss.backward()
            self.opti.step()

            if self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # Only use item() to fetch the value when accumulating
            train_iter_loss = train_loss.item()
            train_batch_size = train_inputs.size(0)
            train_loss_record.update(train_iter_loss, train_batch_size)

            # Tensorboard logging
            if (
                self.arg_dict["tb_update"] > 0
                and (curr_iter + 1) % self.arg_dict["tb_update"] == 0
            ):
                self.tb_recorder.record_curve("trloss_avg", train_loss_record.avg, curr_iter)
                self.tb_recorder.record_curve("trloss_iter", train_iter_loss, curr_iter)
                self.tb_recorder.record_curve("lr", self.opti.param_groups, curr_iter)
                self.tb_recorder.record_image("trmasks", train_masks, curr_iter)
                self.tb_recorder.record_image("trsodout", train_preds.sigmoid(), curr_iter)
                self.tb_recorder.record_image("trsodin", train_inputs, curr_iter)
            # Log data for each iteration
            if (
                self.arg_dict["print_freq"] > 0
                and (curr_iter + 1) % self.arg_dict["print_freq"] == 0
            ):
                lr_str = ",".join(
                    [f"{param_groups['lr']:.7f}" for param_groups in self.opti.param_groups]
                )
                log = (
                    f"{curr_iter_in_epoch}:{num_iter_per_epoch}/"
                    f"{curr_iter}:{self.iter_num}/"
                    f"{curr_epoch}:{self.end_epoch} "
                    f"{self.exp_name} "
                    f"Lr:{lr_str} "
                    f"M:{train_loss_record.avg:.5f} C:{train_iter_loss:.5f} "
                    f"{loss_item_list}"
                )
                print(log)
                write_data_to_file(log, self.path_dict["tr_log"])

    def test(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                training=False,
                prefix=self.arg_dict["prefix"],
                get_length=False,
            )
            self.save_path = os.path.join(self.path_dict["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(f"{self.save_path} do not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self._test_process(save_pre=self.save_pre)
            msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            construct_print(msg)
            write_data_to_file(msg, self.path_dict["te_log"])

            total_results[data_name] = results
        self.net.train()

        if self.arg_dict["xlsx_name"]:
            # save result into xlsx file.
            self.xlsx_recorder.write_xlsx(self.exp_name, total_results)

    def _test_process(self, save_pre):
        loader = self.te_loader

        pres = [AvgMeter() for _ in range(256)]
        recs = [AvgMeter() for _ in range(256)]
        meanfs = AvgMeter()
        maes = AvgMeter()

        tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
        for test_batch_id, test_data in tqdm_iter:
            tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
            with torch.no_grad():
                in_imgs, in_mask_paths, in_names = test_data
                in_imgs = in_imgs.to(self.dev, non_blocking=True)
                outputs = self.net(in_imgs)

            outputs_np = outputs.sigmoid().cpu().detach()

            for item_id, out_item in enumerate(outputs_np):
                gimg_path = os.path.join(in_mask_paths[item_id])
                gt_img = Image.open(gimg_path).convert("L")
                out_img = self.to_pil(out_item).resize(gt_img.size, resample=Image.NEAREST)

                if save_pre:
                    oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
                    out_img.save(oimg_path)

                gt_img = np.array(gt_img)
                out_img = np.array(out_img)
                ps, rs, mae, meanf = cal_pr_mae_meanf(out_img, gt_img)
                for pidx, pdata in enumerate(zip(ps, rs)):
                    p, r = pdata
                    pres[pidx].update(p)
                    recs[pidx].update(r)
                maes.update(mae)
                meanfs.update(meanf)
        maxf = cal_maxf([pre.avg for pre in pres], [rec.avg for rec in recs])
        results = {"MAXF": maxf, "MEANF": meanfs.avg, "MAE": maes.avg}
        return results

    # This part was added by me
    def inference(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                training=False,
                prefix=self.arg_dict["prefix"],
                get_length=False,
            )
            self.save_path = os.path.join(self.path_dict["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(f"{self.save_path} do not exist. Let's create it.")
                os.makedirs(self.save_path)
            self._inference_process(save_pre=self.save_pre)
            # msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            # construct_print(msg)
            # write_data_to_file(msg, self.path_dict["te_log"])
            # total_results[data_name] = results
        # self.net.train()

        # if self.arg_dict["xlsx_name"]:
        #     # save result into xlsx file.
        #     self.xlsx_recorder.write_xlsx(self.exp_name, total_results)

    def _inference_process(self, save_pre):
        loader = self.te_loader

        tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
        for test_batch_id, test_data in tqdm_iter:
            tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
            with torch.no_grad():
                in_imgs, in_names = test_data
                # print(in_imgs.shape)
                in_imgs = in_imgs.to(self.dev, non_blocking=True)
                outputs = self.net(in_imgs)

            outputs_np = outputs.sigmoid().cpu().detach()

            for item_id, out_item in enumerate(outputs_np):
                out_img = self.to_pil(out_item).resize((256, 256), resample=Image.NEAREST)

                if save_pre:
                    oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
                    out_img.save(oimg_path)
```
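One thing to be aware of in _inference_process above: the prediction is resized to a hard-coded (256, 256), whereas _test_process resizes each prediction back to the ground-truth size. If you want predictions at each image's original resolution, here is a minimal sketch of the idea, assuming you also keep the source image paths around (ImageFolder2 as written only returns the image tensor and its name, so the paths argument here is hypothetical):

```python
import os

import torch
from PIL import Image
from torchvision import transforms


def save_predictions_at_source_size(outputs: torch.Tensor, img_paths, img_names, save_dir):
    """Hypothetical helper: save each sigmoid prediction resized to its source image's
    own size instead of the fixed (256, 256) used in _inference_process."""
    to_pil = transforms.ToPILImage()
    outputs_np = outputs.sigmoid().cpu().detach()
    for out_item, img_path, img_name in zip(outputs_np, img_paths, img_names):
        src_size = Image.open(img_path).size  # PIL size is (width, height)
        out_img = to_pil(out_item).resize(src_size, resample=Image.NEAREST)
        out_img.save(os.path.join(save_dir, img_name + ".png"))
```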
pipeline_ops.py. Here I changed get_total_loss, the function that aggregates the losses; my own loss raised an error, and with this change it works.
```python
import os

import torch
import torch.nn as nn
import torch.optim.optimizer as optim
import torch.optim.lr_scheduler as sche
import numpy as np
from torch.optim import Adam, SGD

from utils.misc import construct_print


def get_total_loss(
    train_preds: torch.Tensor, train_masks: torch.Tensor, loss_funcs: list
) -> (float, list):
    """
    return the sum of the list of loss functions with train_preds and train_masks

    Args:
        train_preds (torch.Tensor): predictions
        train_masks (torch.Tensor): masks
        loss_funcs (list): the list of loss functions

    Returns: the sum of all losses and the list of result strings
    """
    loss_list = []
    loss_item_list = []

    assert len(loss_funcs) != 0, "Please specify the loss functions `loss_funcs`"
    for loss in loss_funcs:
        loss_out = loss(train_preds, train_masks)
        try:
            loss_list.append(loss_out)
            loss_item_list.append(f"{loss_out.item():.5f}")
        except:
            loss_list.append(loss_out)
            loss_item_list.append(f"{loss_out:.5f}")

    train_loss = sum(loss_list)
    return train_loss, loss_item_list


def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    Save the full model (large) and the state-dict-only model (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        current_epoch (int): in the epoch, model **will** be trained
        full_net_path (str): the path for saving the full model parameters
        state_net_path (str): the path for saving the state dict.
    """
    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)


def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    Resume the model from a saved checkpoint.

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        load_path (str): checkpoint path
        mode (str): which resume mode to use:
            - 'all': restore the full model, including training-time parameters;
            - 'onlynet': restore the model weights only

    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            if exp_name and exp_name != checkpoint["arch"]:
                # If exp_name is given, it must match checkpoint["arch"]; otherwise no requirement.
                raise Exception(f"We can not match {exp_name} with {load_path}.")
            start_epoch = checkpoint["epoch"]
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint["net_state"])
            else:
                model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if checkpoint.get("amp_state", None):
                if amp:
                    amp.load_state_dict(checkpoint["amp_state"])
                else:
                    construct_print("You are not using amp.")
            else:
                construct_print("The state_dict of amp is None.")
            construct_print(
                f"Loaded '{load_path}' " f"(will train at epoch" f" {checkpoint['epoch']})"
            )
            return start_epoch
        elif mode == "onlynet":
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
            construct_print(
                f"Loaded checkpoint '{load_path}' " f"(only has the model's weight params)"
            )
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path} is not a valid path, please check it")


def make_scheduler(
    optimizer: optim.Optimizer, total_num: int, scheduler_type: str, scheduler_info: dict
) -> sche._LRScheduler:
    def get_lr_coefficient(curr_epoch):
        nonlocal total_num
        # curr_epoch start from 0
        # total_num = iter_num if args["sche_usebatch"] else end_epoch
        if scheduler_type == "poly":
            coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
        elif scheduler_type == "poly_warmup":
            turning_epoch = scheduler_info["warmup_epoch"]
            if curr_epoch < turning_epoch:
                # 0,1,2,...,turning_epoch-1
                coefficient = 1 / turning_epoch * (1 + curr_epoch)
            else:
                # turning_epoch,...,end_epoch
                curr_epoch -= turning_epoch - 1
                total_num -= turning_epoch - 1
                coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
        elif scheduler_type == "cosine_warmup":
            turning_epoch = scheduler_info["warmup_epoch"]
            if curr_epoch < turning_epoch:
                # 0,1,2,...,turning_epoch-1
                coefficient = 1 / turning_epoch * (1 + curr_epoch)
            else:
                # turning_epoch,...,end_epoch
                curr_epoch -= turning_epoch - 1
                total_num -= turning_epoch - 1
                coefficient = (1 + np.cos(np.pi * curr_epoch / total_num)) / 2
        elif scheduler_type == "f3_sche":
            coefficient = 1 - abs((curr_epoch + 1) / (total_num + 1) * 2 - 1)
        else:
            raise NotImplementedError
        return coefficient

    scheduler = sche.LambdaLR(optimizer, lr_lambda=get_lr_coefficient)
    return scheduler


def make_optimizer(model: nn.Module, optimizer_type: str, optimizer_info: dict) -> optim.Optimizer:
    if optimizer_type == "sgd_trick":
        # https://github.com/implus/PytorchInsight/blob/master/classification/imagenet_tricks.py
        params = [
            {
                "params": [
                    p for name, p in model.named_parameters() if ("bias" in name or "bn" in name)
                ],
                "weight_decay": 0,
            },
            {
                "params": [
                    p
                    for name, p in model.named_parameters()
                    if ("bias" not in name and "bn" not in name)
                ]
            },
        ]
        optimizer = SGD(
            params,
            lr=optimizer_info["lr"],
            momentum=optimizer_info["momentum"],
            weight_decay=optimizer_info["weight_decay"],
            nesterov=optimizer_info["nesterov"],
        )
    elif optimizer_type == "sgd_r3":
        params = [
            # Do not apply weight decay to bias parameters. Weight decay mainly reduces overfitting
            # by constraining the layer parameters (both weight and bias); the L2 regularization
            # makes the layer parameters smoother.
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] == "bias"
                ],
                "lr": 2 * optimizer_info["lr"],
            },
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] != "bias"
                ],
                "lr": optimizer_info["lr"],
                "weight_decay": optimizer_info["weight_decay"],
            },
        ]
        optimizer = SGD(params, momentum=optimizer_info["momentum"])
    elif optimizer_type == "sgd_all":
        optimizer = SGD(
            model.parameters(),
            lr=optimizer_info["lr"],
            weight_decay=optimizer_info["weight_decay"],
            momentum=optimizer_info["momentum"],
        )
    elif optimizer_type == "adam":
        optimizer = Adam(
            model.parameters(),
            lr=optimizer_info["lr"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=optimizer_info["weight_decay"],
        )
    elif optimizer_type == "f3_trick":
        backbone, head = [], []
        for name, params_tensor in model.named_parameters():
            if name.startswith("div_2"):
                pass
            elif name.startswith("div"):
                backbone.append(params_tensor)
            else:
                head.append(params_tensor)
        params = [
            {"params": backbone, "lr": 0.1 * optimizer_info["lr"]},
            {"params": head, "lr": optimizer_info["lr"]},
        ]
        optimizer = SGD(
            params=params,
            momentum=optimizer_info["momentum"],
            weight_decay=optimizer_info["weight_decay"],
            nesterov=optimizer_info["nesterov"],
        )
    else:
        raise NotImplementedError

    print("optimizer = ", optimizer)
    return optimizer


if __name__ == "__main__":
    a = torch.rand((3, 3)).bool()
    print(isinstance(a, torch.FloatTensor), a.type())
```
4. Prediction. After training, an output folder is generated automatically; once config.py is fully set up, it auto-generates and configures quite a few things. Remember to set "resume_mode": "inference" for testing; the predictions are stored in the pre folder inside output. Below is the reference package I uploaded to Baidu Netdisk; the data was provided in my earlier posts and is not included here. Link: pan.baidu.com/s/1n1gfGEIm… Extraction code: 7477