diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/img/meinv.jpg b/img/meinv.jpg new file mode 100644 index 0000000..901c080 Binary files /dev/null and b/img/meinv.jpg differ diff --git a/img/trimap/meinv_org_trimap.png b/img/trimap/meinv_org_trimap.png new file mode 100644 index 0000000..bc0dc63 Binary files /dev/null and b/img/trimap/meinv_org_trimap.png differ diff --git a/img/trimap/meinv_resize_trimap.png b/img/trimap/meinv_resize_trimap.png new file mode 100644 index 0000000..b80bf23 Binary files /dev/null and b/img/trimap/meinv_resize_trimap.png differ diff --git a/u_2_net/__init__.py b/u_2_net/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/u_2_net/__pycache__/__init__.cpython-37.pyc b/u_2_net/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..713831e Binary files /dev/null and b/u_2_net/__pycache__/__init__.cpython-37.pyc differ diff --git a/u_2_net/__pycache__/data_loader.cpython-37.pyc b/u_2_net/__pycache__/data_loader.cpython-37.pyc new file mode 100644 index 0000000..7956b61 Binary files /dev/null and b/u_2_net/__pycache__/data_loader.cpython-37.pyc differ diff --git a/u_2_net/data_loader.py b/u_2_net/data_loader.py new file mode 100644 index 0000000..5a54b89 --- /dev/null +++ b/u_2_net/data_loader.py @@ -0,0 +1,259 @@ +# data loader +from __future__ import print_function, division +import glob +import torch +from skimage import io, transform, color +import numpy as np +import random +import math +import matplotlib.pyplot as plt +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +from PIL import Image +#==========================dataset load========================== +class RescaleT(object): + + def __init__(self,output_size): + assert isinstance(output_size,(int,tuple)) + self.output_size = output_size + + def __call__(self,sample): + imidx, image, label = sample['imidx'], sample['image'],sample['label'] + + h, w = image.shape[:2] + + if isinstance(self.output_size,int): + if h > w: + new_h, new_w = self.output_size*h/w,self.output_size + else: + new_h, new_w = self.output_size,self.output_size*w/h + else: + new_h, new_w = self.output_size + + new_h, new_w = int(new_h), int(new_w) + + # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] + # img = transform.resize(image,(new_h,new_w),mode='constant') + # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) + + img = transform.resize(image,(self.output_size,self.output_size),mode='constant') + lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True) + + return {'imidx':imidx, 'image':img,'label':lbl} + +class Rescale(object): + + def __init__(self,output_size): + assert isinstance(output_size,(int,tuple)) + self.output_size = output_size + + def __call__(self,sample): + imidx, image, label = sample['imidx'], sample['image'],sample['label'] + + h, w = image.shape[:2] + + if isinstance(self.output_size,int): + if h > w: + new_h, new_w = self.output_size*h/w,self.output_size + else: + new_h, new_w = self.output_size,self.output_size*w/h + else: + new_h, new_w = self.output_size + + new_h, new_w = int(new_h), int(new_w) + + # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] + img = transform.resize(image,(new_h,new_w),mode='constant') + lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) + + return {'imidx':imidx, 
'image':img,'label':lbl} + +class RandomCrop(object): + + def __init__(self,output_size): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + def __call__(self,sample): + imidx, image, label = sample['imidx'], sample['image'], sample['label'] + + h, w = image.shape[:2] + new_h, new_w = self.output_size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[top: top + new_h, left: left + new_w] + label = label[top: top + new_h, left: left + new_w] + + return {'imidx':imidx,'image':image, 'label':label} + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + + imidx, image, label = sample['imidx'], sample['image'], sample['label'] + + tmpImg = np.zeros((image.shape[0],image.shape[1],3)) + tmpLbl = np.zeros(label.shape) + + image = image/np.max(image) + if(np.max(label)<1e-6): + label = label + else: + label = label/np.max(label) + + if image.shape[2]==1: + tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229 + else: + tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224 + tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225 + + tmpLbl[:,:,0] = label[:,:,0] + + # change the r,g,b to b,r,g from [0,255] to [0,1] + #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) + tmpImg = tmpImg.transpose((2, 0, 1)) + tmpLbl = label.transpose((2, 0, 1)) + + return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)} + +class ToTensorLab(object): + """Convert ndarrays in sample to Tensors.""" + def __init__(self,flag=0): + self.flag = flag + + def __call__(self, sample): + + imidx, image, label =sample['imidx'], sample['image'], sample['label'] + + tmpLbl = np.zeros(label.shape) + + if(np.max(label)<1e-6): + label = label + else: + label = label/np.max(label) + + # change the color space + if self.flag == 2: # with rgb and Lab colors + tmpImg = np.zeros((image.shape[0],image.shape[1],6)) + tmpImgt = np.zeros((image.shape[0],image.shape[1],3)) + if image.shape[2]==1: + tmpImgt[:,:,0] = image[:,:,0] + tmpImgt[:,:,1] = image[:,:,0] + tmpImgt[:,:,2] = image[:,:,0] + else: + tmpImgt = image + tmpImgtl = color.rgb2lab(tmpImgt) + + # nomalize image to range [0,1] + tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0])) + tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1])) + tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2])) + tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0])) + tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1])) + tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2])) + + # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) + + tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) + tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) + tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) + tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3]) + tmpImg[:,:,4] = 
(tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4]) + tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5]) + + elif self.flag == 1: #with Lab color + tmpImg = np.zeros((image.shape[0],image.shape[1],3)) + + if image.shape[2]==1: + tmpImg[:,:,0] = image[:,:,0] + tmpImg[:,:,1] = image[:,:,0] + tmpImg[:,:,2] = image[:,:,0] + else: + tmpImg = image + + tmpImg = color.rgb2lab(tmpImg) + + # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) + + tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0])) + tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1])) + tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2])) + + tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) + tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) + tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) + + else: # with rgb color + tmpImg = np.zeros((image.shape[0],image.shape[1],3)) + image = image/np.max(image) + if image.shape[2]==1: + tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229 + else: + tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 + tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224 + tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225 + + tmpLbl[:,:,0] = label[:,:,0] + + # change the r,g,b to b,r,g from [0,255] to [0,1] + #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) + tmpImg = tmpImg.transpose((2, 0, 1)) + tmpLbl = label.transpose((2, 0, 1)) + + return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)} + +class SalObjDataset(Dataset): + def __init__(self,img_name_list,lbl_name_list,transform=None): + # self.root_dir = root_dir + # self.image_name_list = glob.glob(image_dir+'*.png') + # self.label_name_list = glob.glob(label_dir+'*.png') + self.image_name_list = img_name_list + self.label_name_list = lbl_name_list + self.transform = transform + + def __len__(self): + return len(self.image_name_list) + + def __getitem__(self,idx): + + # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx]) + # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx]) + + image = io.imread(self.image_name_list[idx]) + imname = self.image_name_list[idx] + imidx = np.array([idx]) + + if(0==len(self.label_name_list)): + label_3 = np.zeros(image.shape) + else: + label_3 = io.imread(self.label_name_list[idx]) + + label = np.zeros(label_3.shape[0:2]) + if(3==len(label_3.shape)): + label = label_3[:,:,0] + elif(2==len(label_3.shape)): + label = label_3 + + if(3==len(image.shape) and 2==len(label.shape)): + label = label[:,:,np.newaxis] + elif(2==len(image.shape) and 2==len(label.shape)): + image = image[:,:,np.newaxis] + label = label[:,:,np.newaxis] + + sample = {'imidx':imidx, 'image':image, 'label':label} + + if self.transform: + sample = self.transform(sample) + + return sample diff --git a/u_2_net/model/__init__.py b/u_2_net/model/__init__.py new file mode 100644 index 0000000..4d8fa27 --- /dev/null +++ b/u_2_net/model/__init__.py @@ -0,0 +1,2 @@ +from .u2net import U2NET +from .u2net import U2NETP diff --git a/u_2_net/model/__pycache__/__init__.cpython-37.pyc b/u_2_net/model/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..dec24c8 Binary files /dev/null and 
b/u_2_net/model/__pycache__/__init__.cpython-37.pyc differ diff --git a/u_2_net/model/__pycache__/u2net.cpython-37.pyc b/u_2_net/model/__pycache__/u2net.cpython-37.pyc new file mode 100644 index 0000000..716c04f Binary files /dev/null and b/u_2_net/model/__pycache__/u2net.cpython-37.pyc differ diff --git a/u_2_net/model/u2net.py b/u_2_net/model/u2net.py new file mode 100644 index 0000000..ece59e0 --- /dev/null +++ b/u_2_net/model/u2net.py @@ -0,0 +1,526 @@ +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +class REBNCONV(nn.Module): + def __init__(self,in_ch=3,out_ch=3,dirate=1): + super(REBNCONV,self).__init__() + + self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self,x): + + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src,tar): + + src = F.upsample(src,size=tar.shape[2:],mode='bilinear') + + return src + + +### RSU-7 ### +class RSU7(nn.Module):#UNet07DRES(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU7,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-6 ### +class RSU6(nn.Module):#UNet06DRES(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = 
REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + + hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-5 ### +class RSU5(nn.Module):#UNet05DRES(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4 ### +class RSU4(nn.Module):#UNet04DRES(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = 
REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4F ### +class RSU4F(nn.Module):#UNet04FRES(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) + hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1)) + + return hx1d + hxin + + +##### U^2-Net #### +class U2NET(nn.Module): + + def __init__(self,in_ch=3,out_ch=1): + super(U2NET,self).__init__() + + self.stage1 = RSU7(in_ch,32,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,32,128) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(128,64,256) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(256,128,512) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(512,256,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,256,512) + + # decoder + self.stage5d = RSU4F(1024,256,512) + self.stage4d = RSU4(1024,128,256) + self.stage3d = RSU5(512,64,128) + self.stage2d = RSU6(256,32,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + self.outconv = nn.Conv2d(6,out_ch,1) + + def forward(self,x): + + hx = x + + #stage 1 + hx1 = self.stage1(hx) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6,hx5) + + #-------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) + + + #side output + d1 = 
self.side1(hx1d) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,d1) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,d1) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,d1) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,d1) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,d1) + + d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6) + +### U^2-Net small ### +class U2NETP(nn.Module): + + def __init__(self,in_ch=3,out_ch=1): + super(U2NETP,self).__init__() + + self.stage1 = RSU7(in_ch,16,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,16,64) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(64,16,64) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(64,16,64) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(64,16,64) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(64,16,64) + + # decoder + self.stage5d = RSU4F(128,16,64) + self.stage4d = RSU4(128,16,64) + self.stage3d = RSU5(128,16,64) + self.stage2d = RSU6(128,16,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(64,out_ch,3,padding=1) + self.side4 = nn.Conv2d(64,out_ch,3,padding=1) + self.side5 = nn.Conv2d(64,out_ch,3,padding=1) + self.side6 = nn.Conv2d(64,out_ch,3,padding=1) + + self.outconv = nn.Conv2d(6,out_ch,1) + + def forward(self,x): + + hx = x + + #stage 1 + hx1 = self.stage1(hx) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6,hx5) + + #decoder + hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) + + + #side output + d1 = self.side1(hx1d) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,d1) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,d1) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,d1) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,d1) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,d1) + + d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6) diff --git a/u_2_net/my_u2net_test.py b/u_2_net/my_u2net_test.py new file mode 100644 index 0000000..e3d7b1d --- /dev/null +++ b/u_2_net/my_u2net_test.py @@ -0,0 +1,103 @@ +import torch +from torch.autograd import Variable +from torchvision import transforms#, utils +# import torch.optim as optim +import numpy as np +from u_2_net.data_loader import RescaleT +from u_2_net.data_loader import ToTensorLab +from u_2_net.model import U2NET # full size version 173.6 MB +from PIL import Image + + +# normalize the predicted SOD probability map +def normPRED(d): + ma = torch.max(d) + mi = torch.min(d) + dn = (d-mi)/(ma-mi) + return dn + + +def preprocess(image): + 
+    label_3 = np.zeros(image.shape)
+    label = np.zeros(label_3.shape[0:2])
+
+    if (3 == len(label_3.shape)):
+        label = label_3[:, :, 0]
+    elif (2 == len(label_3.shape)):
+        label = label_3
+    if (3 == len(image.shape) and 2 == len(label.shape)):
+        label = label[:, :, np.newaxis]
+    elif (2 == len(image.shape) and 2 == len(label.shape)):
+        image = image[:, :, np.newaxis]
+        label = label[:, :, np.newaxis]
+
+    transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
+    sample = transform({
+        'imidx': np.array([0]),
+        'image': image,
+        'label': label
+    })
+
+    return sample
+
+
+def pre_net():
+    # use the pretrained u2net model weights
+    model_name = 'u2net'
+    model_dir = 'saved_models/'+ model_name + '/' + model_name + '.pth'
+    print("...load U2NET---173.6 MB")
+    net = U2NET(3,1)
+    # load the weights onto the CPU
+    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
+    if torch.cuda.is_available():
+        net.cuda()
+    net.eval()
+    return net
+
+
+def pre_test_data(img):
+    torch.cuda.empty_cache()
+    sample = preprocess(img)
+    inputs_test = sample['image'].unsqueeze(0)
+    inputs_test = inputs_test.type(torch.FloatTensor)
+    if torch.cuda.is_available():
+        inputs_test = Variable(inputs_test.cuda())
+    else:
+        inputs_test = Variable(inputs_test)
+    return inputs_test
+
+
+def get_im(pred):
+    predict = pred
+    predict = predict.squeeze()
+    predict_np = predict.cpu().data.numpy()
+    im = Image.fromarray(predict_np*255).convert('RGB')
+    return im
+
+
+def test_seg_trimap(org,org_trimap,resize_trimap):
+    # convert the original image into a trimap
+    # org: path of the original image
+    # org_trimap: path to save the prediction at network resolution
+    # resize_trimap: path to save the trimap resized back to the original image size
+    image = Image.open(org)
+    print(image)
+    img = np.array(image)
+    net = pre_net()
+    inputs_test = pre_test_data(img)
+    d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
+    # normalization
+    pred = d1[:, 0, :, :]
+    pred = normPRED(pred)
+    # convert the prediction into an image
+    im = get_im(pred)
+    im.save(org_trimap)
+    sp = image.size
+    # resize back to the original image size
+    imo = im.resize((sp[0], sp[1]), resample=Image.BILINEAR)
+    imo.save(resize_trimap)
+
+
+if __name__ == "__main__":
+    test_seg_trimap("..\\img\\meinv.jpg","..\\img\\trimap\\meinv_org_trimap.png","..\\img\\trimap\\meinv_resize_trimap.png")
+    #pil_wait_blue()
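
Note: u_2_net/model/u2net.py calls F.upsample and F.sigmoid, which are deprecated aliases in recent PyTorch releases in favor of F.interpolate and torch.sigmoid. A minimal, behavior-preserving sketch of _upsample_like using the maintained API (an assumption for modern PyTorch, not part of this patch):

    import torch.nn.functional as F

    def _upsample_like(src, tar):
        # Bilinearly resize 'src' to the spatial size (H, W) of 'tar';
        # align_corners=False matches the default of the deprecated F.upsample call.
        return F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)

The deprecated calls still run (with warnings), so this substitution is optional.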
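The files written by test_seg_trimap are soft probability maps (each pixel holds the normalized network output), not three-level trimaps. A hypothetical post-processing sketch that bands the resized mask into background / unknown / foreground for a matting stage; the mask_to_trimap name and the 0.1/0.9 thresholds are illustrative and not taken from this code:

    import numpy as np
    from PIL import Image

    def mask_to_trimap(mask_path, trimap_path, lo=0.1, hi=0.9):
        # Load the saved mask as grayscale probabilities in [0, 1].
        prob = np.array(Image.open(mask_path).convert('L'), dtype=np.float32) / 255.0
        trimap = np.full(prob.shape, 128, dtype=np.uint8)  # unknown region by default
        trimap[prob <= lo] = 0      # confident background
        trimap[prob >= hi] = 255    # confident foreground
        Image.fromarray(trimap).save(trimap_path)

    # e.g. mask_to_trimap("img/trimap/meinv_resize_trimap.png", "img/trimap/meinv_trimap3.png")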