利用卷积模型人像分割
This commit is contained in:
parent
fc5cc89ad5
commit
41e2e8897a
0
__init__.py
Normal file
0
__init__.py
Normal file
BIN
img/meinv.jpg
Normal file
BIN
img/meinv.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 37 KiB |
BIN
img/trimap/meinv_org_trimap.png
Normal file
BIN
img/trimap/meinv_org_trimap.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
BIN
img/trimap/meinv_resize_trimap.png
Normal file
BIN
img/trimap/meinv_resize_trimap.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
0
u_2_net/__init__.py
Normal file
0
u_2_net/__init__.py
Normal file
BIN
u_2_net/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
u_2_net/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
u_2_net/__pycache__/data_loader.cpython-37.pyc
Normal file
BIN
u_2_net/__pycache__/data_loader.cpython-37.pyc
Normal file
Binary file not shown.
259
u_2_net/data_loader.py
Normal file
259
u_2_net/data_loader.py
Normal file
@ -0,0 +1,259 @@
|
||||
# data loader
|
||||
from __future__ import print_function, division
|
||||
import glob
|
||||
import torch
|
||||
from skimage import io, transform, color
|
||||
import numpy as np
|
||||
import random
|
||||
import math
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms, utils
|
||||
from PIL import Image
|
||||
#==========================dataset load==========================
|
||||
class RescaleT(object):
    """Resize the image and label of a sample to a fixed target size.

    Unlike ``Rescale`` this transform does not preserve the aspect ratio:
    an ``int`` output size ``s`` produces an ``s x s`` result, and a
    ``(height, width)`` tuple produces exactly that size.

    Args:
        output_size (int or tuple): target size.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized ``image`` and ``label``.

        ``imidx`` is passed through untouched.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # The constructor accepts both ints and tuples, so build the target
        # shape accordingly (a tuple previously ended up nested inside
        # another tuple and made resize() fail).
        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            target = tuple(self.output_size)

        # NOTE(review): with default arguments skimage's resize is expected to
        # return a float image (range [0, 1] for uint8 input) - confirm
        # against the installed skimage version.
        img = transform.resize(image, target, mode='constant')
        # order=0 (nearest neighbour) + preserve_range keeps label values
        # unblended and in their original range.
        lbl = transform.resize(label, target, mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
|
||||
|
||||
class Rescale(object):
    """Resize a sample, keeping the aspect ratio when given an int size.

    Args:
        output_size (int or tuple): if an int, the shorter image side is
            scaled to this length and the other side follows the aspect
            ratio; if a tuple, it is used directly as ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized ``image`` and ``label``."""
        idx = sample['imidx']
        src_image = sample['image']
        src_label = sample['label']

        height, width = src_image.shape[:2]
        size = self.output_size

        if not isinstance(size, int):
            target_h, target_w = size
        elif height > width:
            # Portrait: width is the short side and becomes `size`.
            target_h, target_w = size * height / width, size
        else:
            # Landscape or square: height becomes `size`.
            target_h, target_w = size, size * width / height

        target = (int(target_h), int(target_w))

        resized_image = transform.resize(src_image, target, mode='constant')
        # Nearest-neighbour interpolation, original value range kept.
        resized_label = transform.resize(
            src_label, target, mode='constant', order=0, preserve_range=True
        )

        return {'imidx': idx, 'image': resized_image, 'label': resized_label}
|
||||
|
||||
class RandomCrop(object):
    """Crop the same random window out of a sample's image and label.

    Args:
        output_size (int or tuple): crop size. An int ``s`` means an
            ``s x s`` crop; a tuple is ``(height, width)``.

    Raises:
        ValueError: from ``np.random.randint`` when the requested crop is
            larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding the cropped image and label."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and allows an image that already has the crop
        # size (offset 0) instead of raising ValueError.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label}
|
||||
|
||||
class ToTensor(object):
    """Convert ndarrays in sample to Tensors.

    The image is scaled by its maximum, normalized per channel with fixed
    mean/std constants and returned channel-first with three channels
    (a single-channel image is replicated). The label is scaled by its
    maximum (unless that maximum is below 1e-6) and returned channel-first.
    """

    def __call__(self, sample):
        """Return a dict of torch tensors built from ``sample``.

        ``sample['image']`` and ``sample['label']`` must be 3-D
        (height, width, channels) arrays.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

        # Scale by the maximum; an all-zero image is left untouched instead
        # of being turned into NaNs by a 0/0 division.
        peak = np.max(image)
        if peak != 0:
            image = image / peak

        label_peak = np.max(label)
        if not (label_peak < 1e-6):
            label = label / label_peak

        if image.shape[2] == 1:
            # Grayscale: replicate the single channel, using the first
            # channel's statistics for all three outputs.
            for ch in range(3):
                tmpImg[:, :, ch] = (image[:, :, 0] - 0.485) / 0.229
        else:
            for ch, (mean, std) in enumerate(
                    zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
                tmpImg[:, :, ch] = (image[:, :, ch] - mean) / std

        # height x width x channel -> channel x height x width
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            'imidx': torch.from_numpy(imidx),
            'image': torch.from_numpy(tmpImg),
            'label': torch.from_numpy(tmpLbl),
        }
|
||||
|
||||
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    ``flag`` selects the colour representation of the output image:

    * ``0`` (default): RGB scaled by the image maximum, then normalized
      with fixed per-channel mean/std constants (3 channels).
    * ``1``: Lab (via ``color.rgb2lab``); every channel is min-max scaled
      and then standardized to zero mean / unit std (3 channels).
    * ``2``: RGB and Lab stacked; every channel is min-max scaled and then
      standardized (6 channels).

    The label is divided by its maximum unless that maximum is below 1e-6.
    Image and label are returned channel-first.
    """

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _min_max(channel):
        """Scale ``channel`` to [0, 1]; a constant channel maps to zeros."""
        low = np.min(channel)
        span = np.max(channel) - low
        if span == 0:
            # Avoid 0/0 -> NaN for flat channels.
            return np.zeros(channel.shape)
        return (channel - low) / span

    @staticmethod
    def _standardize(channel):
        """Centre ``channel`` and divide by its std (skipped when std is 0)."""
        centered = channel - np.mean(channel)
        std = np.std(channel)
        if std == 0:
            return centered
        return centered / std

    @staticmethod
    def _as_rgb(image):
        """Return ``image`` with three channels, replicating a single one."""
        if image.shape[2] != 1:
            return image
        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        for ch in range(3):
            rgb[:, :, ch] = image[:, :, 0]
        return rgb

    def __call__(self, sample):
        """Return a dict of torch tensors built from ``sample``.

        ``sample['image']`` and ``sample['label']`` must be 3-D
        (height, width, channels) arrays.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        label_peak = np.max(label)
        if not (label_peak < 1e-6):
            label = label / label_peak

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            rgb = self._as_rgb(image)
            lab = color.rgb2lab(rgb)

            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            # normalize every channel to range [0, 1] ...
            for ch in range(3):
                tmpImg[:, :, ch] = self._min_max(rgb[:, :, ch])
                tmpImg[:, :, ch + 3] = self._min_max(lab[:, :, ch])
            # ... then to zero mean / unit std.
            for ch in range(6):
                tmpImg[:, :, ch] = self._standardize(tmpImg[:, :, ch])

        elif self.flag == 1:  # with Lab color
            tmpImg = color.rgb2lab(self._as_rgb(image))
            for ch in range(3):
                tmpImg[:, :, ch] = self._min_max(tmpImg[:, :, ch])
            for ch in range(3):
                tmpImg[:, :, ch] = self._standardize(tmpImg[:, :, ch])

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            # Scale by the maximum; an all-zero image is left untouched
            # instead of being turned into NaNs by a 0/0 division.
            peak = np.max(image)
            if peak != 0:
                image = image / peak
            if image.shape[2] == 1:
                # Grayscale: replicate the channel with the first channel's
                # statistics.
                for ch in range(3):
                    tmpImg[:, :, ch] = (image[:, :, 0] - 0.485) / 0.229
            else:
                for ch, (mean, std) in enumerate(
                        zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
                    tmpImg[:, :, ch] = (image[:, :, ch] - mean) / std

        # height x width x channel -> channel x height x width
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            'imidx': torch.from_numpy(imidx),
            'image': torch.from_numpy(tmpImg),
            'label': torch.from_numpy(tmpLbl),
        }
|
||||
|
||||
class SalObjDataset(Dataset):
    """Dataset of (image, label) pairs read from lists of file paths.

    Args:
        img_name_list: sequence of image file paths.
        lbl_name_list: sequence of label file paths aligned with
            ``img_name_list``; when empty, an all-zero label is generated
            for every image.
        transform: optional callable applied to each sample dict.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'imidx', 'image', 'label'}``.

        Image and label are given a trailing channel axis where needed so
        that both are 3-D when the image is 2-D or 3-D.
        """
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        if len(self.label_name_list) == 0:
            raw_label = np.zeros(image.shape)
        else:
            raw_label = io.imread(self.label_name_list[idx])

        # Reduce the label to a single 2-D plane.
        rank = len(raw_label.shape)
        if rank == 3:
            label = raw_label[:, :, 0]
        elif rank == 2:
            label = raw_label
        else:
            label = np.zeros(raw_label.shape[0:2])

        # Append the channel axis expected by the transforms.
        if len(label.shape) == 2:
            if len(image.shape) == 3:
                label = label[:, :, np.newaxis]
            elif len(image.shape) == 2:
                image = image[:, :, np.newaxis]
                label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample
|
||||
2
u_2_net/model/__init__.py
Normal file
2
u_2_net/model/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .u2net import U2NET
|
||||
from .u2net import U2NETP
|
||||
BIN
u_2_net/model/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
u_2_net/model/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
BIN
u_2_net/model/__pycache__/u2net.cpython-37.pyc
Normal file
BIN
u_2_net/model/__pycache__/u2net.cpython-37.pyc
Normal file
Binary file not shown.
526
u_2_net/model/u2net.py
Normal file
526
u_2_net/model/u2net.py
Normal file
@ -0,0 +1,526 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
import torch.nn.functional as F
|
||||
|
||||
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalization and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate; the padding is set to the same value, so
            the spatial size of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src,tar):
|
||||
|
||||
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """RSU-7 block: a small U-shaped encoder/decoder with a residual sum.

    The input is first projected to ``out_ch`` channels; that projection is
    fed through five pooled encoder levels, a sixth un-pooled level and a
    dilated bottom layer, decoded back with skip connections, and finally
    added to the projection.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        # Input projection (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes [decoder features, skip features].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, same spatial size as ``x``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        enc6 = self.rebnconv6(self.pool5(enc5))

        bottom = self.rebnconv7(enc6)

        dec = self.rebnconv6d(torch.cat((bottom, enc6), 1))
        decoder_path = (
            (self.rebnconv5d, enc5),
            (self.rebnconv4d, enc4),
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        )
        for stage, skip in decoder_path:
            # Upsample to the skip's resolution, then fuse with it.
            dec = stage(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
||||
|
||||
### RSU-6 ###
|
||||
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """RSU-6 block: like RSU-7 but with four pooled encoder levels.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        # Input projection (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes [decoder features, skip features].
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, same spatial size as ``x``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))

        bottom = self.rebnconv6(enc5)

        dec = self.rebnconv5d(torch.cat((bottom, enc5), 1))
        decoder_path = (
            (self.rebnconv4d, enc4),
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        )
        for stage, skip in decoder_path:
            # Upsample to the skip's resolution, then fuse with it.
            dec = stage(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
||||
|
||||
### RSU-5 ###
|
||||
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """RSU-5 block: like RSU-7 but with three pooled encoder levels.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        # Input projection (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes [decoder features, skip features].
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, same spatial size as ``x``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))

        bottom = self.rebnconv5(enc4)

        dec = self.rebnconv4d(torch.cat((bottom, enc4), 1))
        decoder_path = (
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        )
        for stage, skip in decoder_path:
            # Upsample to the skip's resolution, then fuse with it.
            dec = stage(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
||||
|
||||
### RSU-4 ###
|
||||
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """RSU-4 block: like RSU-7 but with two pooled encoder levels.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        # Input projection (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder; every stage consumes [decoder features, skip features].
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, same spatial size as ``x``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))

        bottom = self.rebnconv4(enc3)

        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))
        for stage, skip in ((self.rebnconv2d, enc2), (self.rebnconv1d, enc1)):
            # Upsample to the skip's resolution, then fuse with it.
            dec = stage(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
||||
|
||||
### RSU-4F ###
|
||||
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """RSU-4F block: RSU-4 layout without pooling or upsampling.

    Increasing dilation rates (1, 2, 4, 8) take the place of the pooling
    steps, so every intermediate feature map keeps the input's spatial
    size.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels used inside the block.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        # Input projection (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the block output, same spatial size as ``x``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)

        dec = self.rebnconv4(enc3)

        decoder_path = (
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        )
        for stage, skip in decoder_path:
            dec = stage(torch.cat((dec, skip), 1))

        return dec + residual
|
||||
|
||||
|
||||
##### U^2-Net ####
|
||||
class U2NET(nn.Module):
    """Full-size U^2-Net built from nested RSU blocks.

    Six encoder stages and five decoder stages (each an RSU block) form an
    outer U; six side outputs are upsampled to the resolution of the first
    one and fused by a 1x1 convolution.

    Args:
        in_ch: number of input image channels.
        out_ch: number of channels of every output map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side outputs, one per decoder stage plus the deepest encoder stage
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # The fusion layer sees the six side outputs concatenated, i.e.
        # 6 * out_ch channels (equal to the former hard-coded 6 when
        # out_ch == 1, so existing checkpoints still load).
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Run the network.

        Returns:
            A tuple of seven sigmoid-activated maps: the fused output
            first, followed by side outputs 1 to 6.
        """
        # -------------------- encoder --------------------
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return tuple(torch.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))
|
||||
|
||||
### U^2-Net small ###
|
||||
class U2NETP(nn.Module):
    """Small U^2-Net variant: same layout as U2NET with 64-channel stages.

    Every RSU stage uses 16 internal channels and 64 output channels.

    Args:
        in_ch: number of input image channels.
        out_ch: number of channels of every output map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # side outputs, one per decoder stage plus the deepest encoder stage
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # The fusion layer sees the six side outputs concatenated, i.e.
        # 6 * out_ch channels (equal to the former hard-coded 6 when
        # out_ch == 1, so existing checkpoints still load).
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Run the network.

        Returns:
            A tuple of seven sigmoid-activated maps: the fused output
            first, followed by side outputs 1 to 6.
        """
        # encoder
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # decoder
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return tuple(torch.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))
|
||||
103
u_2_net/my_u2net_test.py
Normal file
103
u_2_net/my_u2net_test.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torchvision import transforms#, utils
|
||||
# import torch.optim as optim
|
||||
import numpy as np
|
||||
from u_2_net.data_loader import RescaleT
|
||||
from u_2_net.data_loader import ToTensorLab
|
||||
from u_2_net.model import U2NET # full size version 173.6 MB
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# normalize the predicted SOD probability map
|
||||
def normPRED(d):
    """Min-max normalize the prediction tensor ``d`` to the range [0, 1].

    A constant tensor (max == min) is mapped to all zeros instead of the
    NaNs a 0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span
|
||||
|
||||
|
||||
def preprocess(image):
    """Turn a raw image array into a model-ready sample dict.

    An all-zero placeholder label is attached (no ground truth is needed
    for inference), then the sample is resized to 320x320 by ``RescaleT``
    and converted to tensors by ``ToTensorLab`` in RGB mode (``flag=0``).
    """
    rank = len(image.shape)

    # Placeholder label covering the image's first two dimensions.
    label = np.zeros(image.shape[0:2])
    if rank == 3:
        label = label[:, :, np.newaxis]
    elif rank == 2:
        # Grayscale input: give both arrays an explicit channel axis.
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]

    pipeline = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    raw_sample = {
        'imidx': np.array([0]),
        'image': image,
        'label': label
    }

    return pipeline(raw_sample)
|
||||
|
||||
|
||||
def pre_net():
    """Build a U2NET, load its saved weights and return it in eval mode.

    The weights are read from ``saved_models/u2net/u2net.pth`` relative to
    the current working directory. The model is moved to the GPU when CUDA
    is available.
    """
    # Use the u2net model weights.
    model_name = 'u2net'
    model_dir = '/'.join(('saved_models', model_name, model_name + '.pth'))

    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)

    # Map the checkpoint to the CPU so loading also works without a GPU.
    weights = torch.load(model_dir, map_location=torch.device('cpu'))
    net.load_state_dict(weights)

    if torch.cuda.is_available():
        net.cuda()

    net.eval()
    return net
|
||||
|
||||
|
||||
def pre_test_data(img):
    """Convert an image array into a float32 batch of size one.

    The image is run through ``preprocess``, given a leading batch
    dimension and moved to the GPU when CUDA is available.
    """
    torch.cuda.empty_cache()

    sample = preprocess(img)
    inputs_test = sample['image'].unsqueeze(0)
    inputs_test = inputs_test.type(torch.FloatTensor)

    # torch.autograd.Variable is a deprecated no-op wrapper on the torch
    # versions this file already requires (torch.device), so the tensor is
    # returned directly.
    if torch.cuda.is_available():
        inputs_test = inputs_test.cuda()

    return inputs_test
|
||||
|
||||
|
||||
def get_im(pred):
    """Convert a prediction tensor into an RGB PIL image.

    Singleton dimensions are squeezed away and the values are multiplied
    by 255 before the conversion.
    """
    mask = pred.squeeze().cpu().data.numpy()
    return Image.fromarray(mask * 255).convert('RGB')
|
||||
|
||||
|
||||
def test_seg_trimap(org, org_trimap, resize_trimap):
    """Segment the image at ``org`` and save the predicted mask twice.

    Args:
        org: path of the original image.
        org_trimap: output path for the mask at the network's output size.
        resize_trimap: output path for the mask resized to the original
            image's size.
    """
    image = Image.open(org)
    print(image)
    img = np.array(image)

    net = pre_net()
    inputs_test = pre_test_data(img)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        # The first of the seven outputs is the fused map; the side
        # outputs are not needed here.
        fused = net(inputs_test)[0]

    # normalization
    pred = normPRED(fused[:, 0, :, :])

    # Convert the prediction into an image.
    im = get_im(pred)
    im.save(org_trimap)

    # Resize the mask to the size of the original image.
    width, height = image.size
    imo = im.resize((width, height), resample=Image.BILINEAR)
    imo.save(resize_trimap)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os

    # Build the paths with os.path.join so the script is not tied to
    # Windows path separators (the result is unchanged on Windows).
    img_dir = os.path.join("..", "img")
    trimap_dir = os.path.join(img_dir, "trimap")
    test_seg_trimap(
        os.path.join(img_dir, "meinv.jpg"),
        os.path.join(trimap_dir, "meinv_org_trimap.png"),
        os.path.join(trimap_dir, "meinv_resize_trimap.png"),
    )
|
||||
Loading…
x
Reference in New Issue
Block a user