
Select, Attend, and Transfer: Light, Learnable Skip ...hamarneh/ecopy/miccai_mlmi2019b.pdf


Select, Attend, and Transfer: Light, Learnable Skip Connections

Saeid Asgari Taghanaki 1,2, Aicha Bentaieb 1, Anmol Sharma 1, S. Kevin Zhou 2, Yefeng Zheng 2, Bogdan Georgescu 2, Puneet Sharma 2, Zhoubing Xu 2, Dorin Comaniciu 2, and Ghassan Hamarneh 1

1 Medical Image Analysis Lab, School of Computing Science, Simon Fraser University, Canada

2 Medical Imaging Technologies, Siemens Healthineers, Princeton, NJ, USA

Abstract. Skip connections in deep networks have improved both segmentation and classification performance by facilitating the training of deeper network architectures and reducing the risk of vanishing gradients. Skip connections equip encoder-decoder-like networks with richer feature representations, but at the cost of higher memory usage and computation, and they may transfer non-discriminative feature maps. In this paper, we focus on improving the skip connections used in segmentation networks. We propose light, learnable skip connections that learn first to select the most discriminative channels and then to aggregate the selected channels into a single channel that attends to the most discriminative regions of the input. We evaluate the proposed method on 3 different 2D and volumetric datasets and demonstrate that the proposed skip connections can outperform the traditional heavy skip connections of 4 different models in terms of segmentation accuracy (2% Dice), memory usage (at least 50%), and the number of network parameters (up to 70%).

Keywords: Deep neural networks · skip connections · image segmentation

1 Introduction

Recent work in image segmentation has shown that networks stacking tens of convolutional layers generally perform better than shallower networks [7], due to their ability to learn more complex nonlinear functions. However, they are difficult to train because of their high number of parameters and vanishing gradients. One of the recent key ideas for easing training and handling vanishing gradients is to introduce skip connections between subsequent layers in the network, which has been shown to improve several encoder-decoder segmentation networks (e.g., 2D U-Net [19], 3D U-Net [2], 3D V-Net [18], and The One Hundred Layers Tiramisu (DenseNet) [12]). Skip connections help the training process by recovering spatial information lost during down-sampling, as well as by reducing the risk of vanishing gradients [11]. However, directly transferring feature maps from previous layers to subsequent layers may also transfer redundant and non-discriminative feature maps. Also, as the transferred feature maps are concatenated to the feature maps in subsequent layers, the memory usage increases manyfold.

Complexity reduction. Recently, there have been several efforts to reduce the training and runtime computations of deep classification networks [9,16]. A few other works have attempted to simplify the structure of deep networks, e.g., by tensor factorization [13], channel/network pruning [23], or applying sparsity to connections [6]. However, the non-structured connectivity and irregular memory access caused by sparsity regularization and connection pruning methods adversely impact practical speedup [23]. On the other hand, tensor factorization is not compatible with recently designed networks, e.g., GoogLeNet and ResNet, and many such methods may even end up with more computations than the original architectures [8].

Gates and attention. Attention can be viewed as using information transferred from several subsequent layers/feature maps to localize the most discriminative (or salient) part of the input signal. Srivastava et al. [21] introduced gates that control the flow of information through a connection; their gates control the level of contribution between the unmodified input and the activations passed to a consecutive layer. Hu et al. [10] proposed a selection mechanism where feature maps are first aggregated using global average pooling and reduced to a single channel descriptor, then an activation gate is used to highlight the most discriminative features. Recently, Wang et al. [22] added an attention module to ResNet for image classification. Their attention module consists of several encoding-decoding layers, which, although they helped improve image classification accuracy, also increased the computational complexity of the model by an order of magnitude [22]. Lin et al. adopted one-step fusion without filtering the channels. Chu et al. [1] added a non-learnable extra branch, called hourglass residual units, to the original residual connections.

In this paper, we propose a modification of the traditional skip connections, using a novel select-attend-transfer gate, which explicitly enforces learnability and aims at simultaneously improving segmentation accuracy and reducing memory usage and network parameters (Fig. 1). We focus on skip connections in encoder-decoder architectures (i.e., as opposed to skip connections in residual networks) designed for segmentation tasks. Our proposed select-attend-transfer gate favours sparse feature representations and uses this property to select and attend to the most discriminative feature channels and spatial regions within a skip connection. Specifically, we first learn to identify the most discriminative feature maps in a skip connection, using a set of trainable weights under a sparsity constraint. Then, we reduce the feature maps to a single channel using a convolutional filter followed by an attention layer that identifies salient spatial locations within the produced feature map. This compact representation forms the final feature map of the skip connection.


2 The Select-Attend-Transfer (SAT) gate

Notation: We define f as an input feature map of size H × W × D × C, where H, W, and D refer to the height, width, and depth of the volumetric data, and C to the number of channels. The notation in the paper is defined for 3D (volumetric) input images, but the method can be easily adapted to 2D images by removing the extra dimension and applying 2D convolutions instead of 3D. An overview of the proposed SAT gate is shown in Fig. 1. It consists of the following modules: 1) Select: re-weighting the channels of the input feature maps f, using a learned weight vector with sparsity constraints, to encourage sparse feature map representations, that is, only those channels with non-zero weights are selected; 2) Attend: discovering the most salient spatial locations within the final feature map; and 3) Transfer: transferring the output of the gate into subsequent layers via a skip connection.

[Figure 1: diagram of the SAT gate inside a dense block. Select (S): the input channels f_1, ..., f_C are weighted by W, with ReLU(W) = max(0, min(W, 1)), giving f_{S,1}, ..., f_{S,C}. Attend (A): a Conv 1x1x1 produces U, and a sigmoid produces f_a. Transfer (T): the gate output is passed on through the skip connection.]

Fig. 1: The proposed select-attend-transfer (SAT) gate placed in a dense block for volumetric (3D) data. The different shades of green show the result of applying the proposed soft channel selection, i.e., some channels can be totally turned off (grey) or partially turned on.

Selection via sparsification. We weight the channels of an input feature map f using a weight vector W of C scalars, trained along with all other network parameters. The weight vector is defined such that we encourage sparse feature maps, resulting in feature maps being turned off or on, completely or partially. Instead of relying on complex loss functions, we clip negative weights to zero and positive weights to at most one using a truncated ReLU function, trelu(w) = max(0, min(w, 1)). Each channel of the feature map block f is multiplied by its weight as f_{s,c} = f_c · trelu(W_c), where W_c is a scalar weight associated with the input feature map channel f_c. The output f_s is of size H × W × D × C and is a sparse representation of f. Zero weights turn off the corresponding feature maps completely, whereas positive weights fully or partially turn on feature maps, i.e., they implement soft feature map selection.
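The selection step can be sketched in a few lines. This is a minimal illustration rather than the authors' implementation: channels are stored as plain Python lists of voxel values instead of framework tensors, the weight values are made up, and the helper name `select_channels` is hypothetical.

```python
def trelu(w):
    """Truncated ReLU: clip a scalar weight to the interval [0, 1]."""
    return max(0.0, min(w, 1.0))


def select_channels(f, weights):
    """Scale each channel f[c] (a flat list of voxel values) by trelu(weights[c])."""
    if len(f) != len(weights):
        raise ValueError("one weight per channel is required")
    return [[trelu(w) * v for v in channel] for channel, w in zip(f, weights)]


# Three channels with two voxels each; W stands in for learned weights.
f = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
W = [-0.3, 0.5, 1.7]
f_s = select_channels(f, W)
print(f_s)  # [[0.0, 0.0], [1.5, 2.0], [5.0, 6.0]]
```

The negative weight switches the first channel off, the weight 0.5 passes the second channel at half strength, and the weight above one is clipped so the third channel passes unchanged.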

Attention via filtering. The output f_s is further filtered to identify the most discriminative linear combination of its C feature channels. For this, we employ a convolution filter K of size 1 × 1 × 1 × C (1 × 1 × C for 2D data), which learns how best to aggregate the different channels of the feature map f_s. The output of this step, U = K ⋆ f_s, where ⋆ is the convolution operation, is of size H × W × D × 1. To identify salient spatial locations within U, we introduce an attention gate f_a = σ(U) composed of a sigmoid activation function σ. Using the sigmoid as the activation function allows us to identify multiple discriminative spatial locations within the feature map U (as opposed to softmax, which identifies one single location). The computed f_a forms a compact summary of the input feature map f.

Transfer via skip connection. The computed f_a is transferred to subsequent layers via a skip connection.

Special examples. There are two special cases of the proposed SAT gate. One is the ST gate, which skips the A part, that is, only channel selection is performed and no attention. The signal f_s is fed directly to subsequent layers. The other is the AT gate, which skips the S part by setting all weights to one. This way, there is no channel selection and only attention is performed.

Training and implementation details. To set the hyperparameters, we started with the values proposed in the U-Net, V-Net, and DenseNet papers. However, we found experimentally that the Adadelta optimizer (with its default parameters: η = 1, ρ = 0.95, ε = 1e−8, and decay = 0) with the Glorot uniform initializer works best for all the networks. After each convolution layer we use batch normalization.
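To make the data flow through the gate concrete, the sketch below chains selection, attention, and the ST and AT special cases. It is an illustration under stated assumptions, not the authors' code: feature maps are plain lists (C channels of V flattened voxels), the 1 × 1 × 1 convolution is written as a per-voxel weighted sum over channels with no bias term, batch normalization is omitted, and the function name `sat_gate` and its flags are hypothetical.

```python
import math


def trelu(w):
    """Truncated ReLU: clip a scalar weight to the interval [0, 1]."""
    return max(0.0, min(w, 1.0))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sat_gate(f, W, K, select=True, attend=True):
    """f: C channels, each a flat list of V voxel values.
    W: C selection weights. K: C weights of the 1x1x1 filter (no bias)."""
    C = len(f)
    # Select: the AT variant skips this step by fixing every weight to one.
    scale = [trelu(w) for w in W] if select else [1.0] * C
    f_s = [[s * v for v in channel] for s, channel in zip(scale, f)]
    if not attend:
        return f_s  # ST: the selected channels are transferred directly
    # Attend: a 1x1x1 convolution is a per-voxel weighted sum over channels.
    V = len(f[0])
    U = [sum(K[c] * f_s[c][v] for c in range(C)) for v in range(V)]
    return [sigmoid(u) for u in U]  # f_a: a single channel of V values


f = [[1.0, -2.0], [3.0, 4.0]]  # C = 2 channels, V = 2 voxels
W = [0.0, 1.0]                 # illustrative selection weights
K = [1.0, 0.5]                 # illustrative filter weights
f_a = sat_gate(f, W, K)                  # SAT: one channel is transferred
f_at = sat_gate(f, W, K, select=False)   # AT: attention only
f_st = sat_gate(f, W, K, attend=False)   # ST: still C channels
```

SAT and AT both hand a single channel to the decoder, whereas ST keeps all C channels (some of them zeroed), which is why only SAT and AT shrink the skip connection.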

3 Experiments

In this section, we evaluate the performance of the proposed method on three commonly used fully convolutional segmentation networks that leverage skip connections: U-Net (both 2D and 3D), V-Net (both 2D and 3D), and DenseNet (both 2D and 3D). We test the proposed method on three datasets (Fig. 2): (i) a magnetic resonance imaging (MRI) dataset; (ii) a skin lesion dataset; and (iii) a computed tomography (CT) dataset.

[Figure 2: three image panels labelled CT, MRI, and Skin.]

Fig. 2: Three samples from the datasets used. From left to right: the liver (contoured in red) in a CT scan, the prostate in an MRI scan, and sample skin images containing lesions highlighted in red.

To analyze the performance of the proposed method, we performed the following experiments: a) We tested the performance of the proposed method, in terms of segmentation accuracy, on datasets i and ii. b) We applied the proposed SAT gate to the recently introduced deep image-to-image network (DI2IN) [24] for liver segmentation (Section 3.1). c) We quantitatively and qualitatively analyzed the proposed skip connections in terms of the amount of data transferred, and we visualized the outputs of both the channel selection and attention layers. We also compared the proposed method with the original networks in terms of memory usage and number of parameters (Section 3.2).

3.1 Volumetric and 2D binary segmentation

Volumetric CT liver segmentation: In this experiment, the goal is to segment the liver from CT images. We used more than 2000 CT scans (to the best of our knowledge, this is the largest CT dataset used in the literature so far) of different resolutions, collected internally and from The Cancer Imaging Archive (TCIA) QIN-HEADNECK dataset [4], which were resampled to isotropic voxels of size 2 × 2 × 2 mm. We picked 61 volumes of the whole dataset for testing and trained the networks on the remaining volumes.

Volumetric MRI prostate segmentation: In this experiment, we test the proposed method on a volumetric MRI prostate dataset of 1029 MRI volumes of different resolutions, which were collected internally and from the TCIA ProstateX dataset [5] and were resampled to isotropic voxels of size 2 × 2 × 2 mm. We used 770 images for training and 258 images for testing.

2D RGB skin lesion segmentation: For this experiment, we used the 2D RGB skin lesion dataset from the 2017 IEEE ISBI International Skin Imaging Collaboration (ISIC) Challenge [3]. We train on a dataset of 2000 images and test on a different set of 150 images.

As reported in Table 1, overall, equipping U-Net and V-Net with SAT improved Dice results for all 4 experiments (2 modalities times 2 networks). Specifically: (i) For the MRI dataset, the SAT gate improves U-Net performance by 1.15% (0.87 to 0.88) in Dice. Using SAT with V-Net improved Dice results by 2.4% (0.85 to 0.87). Note that, although for V-Net (MRI data) the Dice improvement is small, the proposed method transfers only one attention map instead of all the channels, which greatly reduces memory usage. (ii) For the skin dataset, equipping U-Net with SAT improved Dice by 2.53% (0.79 to 0.81), and similarly V-Net with SAT resulted in a 2.5% (0.81 to 0.83) improvement in Dice.
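For readers unfamiliar with the metric, the sketch below computes the standard Dice overlap for binary masks and checks how the quoted percentages relate to the Dice scores. Two assumptions are ours rather than the paper's: the percentages are read as relative gains over the baseline (which matches the quoted numbers after rounding), and two empty masks are scored as 1.0 by convention. The function names are hypothetical.

```python
def dice(pred, target):
    """Dice overlap 2|A ∩ B| / (|A| + |B|) of two binary masks (flat 0/1 lists)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of elements")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # Convention assumed here: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


def relative_gain(before, after):
    """Improvement of `after` over `before`, in percent of `before`."""
    return 100.0 * (after - before) / before


print(round(dice([1, 1, 0, 0], [1, 0, 0, 0]), 3))  # 0.667
print(round(relative_gain(0.87, 0.88), 2))         # 1.15 (3D U-Net, MRI)
print(round(relative_gain(0.79, 0.81), 2))         # 2.53 (2D U-Net, skin)
```

Under this reading, the V-Net gains come out as about 2.35% (0.85 to 0.87) and 2.47% (0.81 to 0.83), consistent with the rounded 2.4% and 2.5% quoted above.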

To select the winner in Table 1, we consider different criteria: accuracy, number of parameters, and memory usage. We argue that the winner is SAT, as with SAT we are able to reduce memory and # params without sacrificing accuracy. When there is no attention, i.e. only ST, multiple channels are transferred, which implies higher memory usage and # params, and our results show that, even then, there is no clear gain in accuracy compared to AT and SAT. That said, ST does match or outperform the standard models (i.e. U-Net, V-Net, and DenseNet). When there is attention but no channel selection, i.e. only AT, although we manage to reduce memory and # params, the results show erroneous segmentations and lower Dice (Table 1). So, the main competition is between AT and SAT, as both transfer a single channel, resulting in a reduction in # params and memory usage. However, looking at the quantitative results in Table 1, SAT outperforms AT in terms of Dice in 3 out of 4 cases. Also, looking at the qualitative results in Fig. 3, SAT wins, as AT tends to attend to several wrong areas for different datasets.

Table 1: U-Net and V-Net results for prostate (MRI) and skin lesion segmentation. N is the total number of channels C before the skip connection.

Data   Model           Method   # C      Dice
MRI    3D U-Net [2]    ORIG     C = N    0.87±0.05
MRI    3D U-Net [2]    AT       C = 1    0.88±0.05
MRI    3D U-Net [2]    ST       C ≤ N    0.88±0.05
MRI    3D U-Net [2]    SAT      C = 1    0.88±0.05
MRI    3D V-Net [18]   ORIG     C = N    0.85±0.05
MRI    3D V-Net [18]   AT       C = 1    0.86±0.05
MRI    3D V-Net [18]   ST       C ≤ N    0.86±0.05
MRI    3D V-Net [18]   SAT      C = 1    0.87±0.04
Skin   2D U-Net [19]   ORIG     C = N    0.79±0.001
Skin   2D U-Net [19]   AT       C = 1    0.79±0.001
Skin   2D U-Net [19]   ST       C ≤ N    0.79±0.001
Skin   2D U-Net [19]   SAT      C = 1    0.81±0.001
Skin   2D V-Net        ORIG     C = N    0.81±0.001
Skin   2D V-Net        AT       C = 1    0.82±0.002
Skin   2D V-Net        ST       C ≤ N    0.82±0.001
Skin   2D V-Net        SAT      C = 1    0.83±0.001

As reported by Huang et al. [11], DenseNet outperforms Highway Networks [21], Network in Network [17], ResNet [7], FractalNet [14], Deeply Supervised Nets [15], and "Striving for simplicity: The all convolutional net" [20], all of which attempted to optimize the data flow and number of parameters in deep networks. We tested the proposed SAT gate on the segmentation version of the winning model, i.e. The One Hundred Layers Tiramisu network [12] (i.e. DenseNet). Note that we apply the proposed SAT gate to more skip connections in this network, i.e. both the long skip connections between the encoder and decoder and the long skip connections inside each dense block (Fig. 1). While our proposed SAT gate reduces the number of parameters in DenseNet by ∼70%, it improves Dice results by 3.7% (0.79±0.12 to 0.82±0.07) for MRI and by 2.5% (0.79±0.002 to 0.81±0.001) for Skin.

As the next experiment, we applied our proposed SAT gate to the DI2IN [24] method for liver segmentation from CT images and achieved the same performance as the original DI2IN network, i.e. a Dice score of 0.96, while reducing the total number of network parameters by 12%. This was done by applying SAT only to the skip connections, reducing the number of channels in each skip connection to a single channel.

3.2 Quantitative and qualitative analysis of the proposed skip connections

In this section, we visualize the outputs of both the channel selection and attention layers. As shown in Fig. 3, after applying the selection step, some of the less discriminative channels are completely turned off by the proposed channel selection. As can be seen in Fig. 3, a model with only an attention layer (i.e., only AT) tends to attend to several areas of the image, including both where the object is present and where it is absent (note the red colour visible over the whole image). However, applying channel selection before the attention layer (i.e., SAT) keeps the model from attending to less discriminative regions.

[Figure 3: attention maps, segmentations, and ground truth for AT and SAT.]

Fig. 3: Attention maps, segmentation, and ground truth for 2 skip connections of 2D U-Net for only attention (AT) and channel selection + attention (SAT) for skin images.

We also quantitatively analyzed the proposed learnable skip connections in terms of the percentage of channels turned off (i.e., channels i for which W_i is zero in the channel selection layer). For U-Net, 55.7±5% and 66.7±7%, and for V-Net, 57.0±12.9% and 54.5±6.9% of channels were off for the MRI and skin datasets, respectively. After applying SAT, the number of parameters of only the skip connections is reduced by 28%, 50%, and 70% for the U-Net, V-Net, and DenseNet networks, respectively. Since transferring only one channel (instead of N channels) as the output of the SAT gate reduces the number of convolution kernel parameters needed on the other side of the skip connections, we further report the total number of parameters before and after applying the proposed method (i.e., SAT).
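The percentage of switched-off channels in one selection layer can be read directly from its weight vector, as in the following sketch. The weights below are made-up values for illustration, and the helper name `percent_channels_off` is hypothetical; a channel counts as off when its truncated-ReLU weight is exactly zero.

```python
def trelu(w):
    """Truncated ReLU: clip a scalar weight to the interval [0, 1]."""
    return max(0.0, min(w, 1.0))


def percent_channels_off(weights):
    """Percentage of channels whose effective selection weight is zero."""
    off = sum(1 for w in weights if trelu(w) == 0.0)
    return 100.0 * off / len(weights)


# Illustrative weights for an 8-channel skip connection (not measured values).
W = [-0.4, 0.0, 0.3, 1.2, -1.0, 0.8, -0.1, 0.05]
print(percent_channels_off(W))  # 50.0
```

Averaging this quantity over all gated skip connections of a network would give one figure per network and dataset, which is how we read the mean ± deviation values quoted above.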

Further note that the proposed method reduces the number of convolution operations after each concatenation to almost 50%, as the following example shows. After an original skip connection that carries 256 channels and concatenates them with another 256 channels on the other side of the network, the convolution layer immediately after the concatenation needs 512 (= 256 + 256) convolution operations. However, as the proposed skip connections carry only one channel, for the same example only 257 (= 1 + 256) convolutions are needed.
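The same arithmetic can be expressed as a weight count for the convolution layer that follows the concatenation. The 3 × 3 × 3 kernel and the 256 output filters are assumptions chosen for illustration (the paper does not state them for this example), and bias terms are ignored; the ratio does not depend on either choice.

```python
def conv_weights(in_channels, out_channels, kernel_volume):
    """Number of weights in a convolution layer, ignoring biases."""
    return in_channels * out_channels * kernel_volume


KERNEL_VOLUME = 3 * 3 * 3  # assumed 3x3x3 kernel
OUT_CHANNELS = 256         # assumed number of output filters

original = conv_weights(256 + 256, OUT_CHANNELS, KERNEL_VOLUME)  # heavy skip
light = conv_weights(1 + 256, OUT_CHANNELS, KERNEL_VOLUME)       # SAT skip
ratio = light / original

print(original, light)  # 3538944 1776384
print(round(ratio, 3))  # 0.502
```

The ratio 257/512 is what "almost 50%" refers to: each output filter sums over 257 input channels instead of 512.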

4 Conclusions

We proposed a novel architecture for skip connections in fully convolutional segmentation networks. Our proposed skip connection involves a channel selection step followed by an attention gate. Equipping popular segmentation networks with the proposed skip connections allowed us to reduce computations and network parameters (the proposed method transfers only one feature channel instead of many) and to consistently obtain more accurate segmentation results (it attends to the most discriminative channels and regions within the feature maps of a skip connection).

References

1. Chu, X., et al.: Multi-context attention for human pose estimation. arXiv preprint arXiv:1702.07432 1(2) (2017)

2. Cicek, O., et al.: 3D U-Net: learning dense volumetric segmentation from sparse annotation. In: MICCAI. pp. 424–432. Springer (2016)

3. Codella, N.C., et al.: Skin lesion analysis toward melanoma detection: ISBI, hosted by ISIC. arXiv preprint arXiv:1710.05006 (2017)

4. Fedorov, A., et al.: DICOM for quantitative imaging biomarker development: a standards based approach to sharing clinical data and structured PET/CT analysis results in head and neck cancer research. PeerJ 4:e2057, https://doi.org/10.7717/peerj.2057 (2016)

5. Litjens, G., et al.: ProstateX challenge data. The Cancer Imaging Archive (2017)

6. Han, S., et al.: EIE: efficient inference engine on compressed deep neural network. In: ISCA, ACM/IEEE. pp. 243–254. IEEE (2016)

7. He, K., et al.: Deep residual learning for image recognition. In: CVPR. pp. 770–778 (2016)

8. He, Y., et al.: Channel pruning for accelerating very deep neural networks. In: ICCV. vol. 2, p. 6 (2017)

9. Howard, A.G., et al.: MobileNets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861 (2017)

10. Hu, J., et al.: Squeeze-and-excitation networks. arXiv preprint arXiv:1709.01507 (2017)

11. Huang, G., et al.: Densely connected convolutional networks. In: CVPR (2017)

12. Jegou, S., et al.: The one hundred layers tiramisu: Fully convolutional DenseNets for semantic segmentation. In: CVPRW (2017)

13. Kim, Y.D., et al.: Compression of deep convolutional neural networks for fast and low power mobile applications. arXiv preprint arXiv:1511.06530 (2015)

14. Larsson, G., et al.: FractalNet: Ultra-deep neural networks without residuals. arXiv preprint arXiv:1605.07648 (2016)

15. Lee, C.Y., et al.: Deeply-supervised nets. In: AISTATS. pp. 562–570 (2015)

16. Leroux, S., et al.: IamNN: Iterative and adaptive mobile neural network for efficient image classification. arXiv preprint arXiv:1804.10123 (2018)

17. Lin, M., et al.: Network in network. arXiv preprint arXiv:1312.4400 (2013)

18. Milletari, F., et al.: V-Net: Fully convolutional neural networks for volumetric medical image segmentation. In: 3DV. pp. 565–571. IEEE (2016)

19. Ronneberger, O., et al.: U-Net: Convolutional networks for biomedical image segmentation. In: MICCAI. pp. 234–241. Springer (2015)

20. Springenberg, J.T., et al.: Striving for simplicity: The all convolutional net. arXiv preprint arXiv:1412.6806 (2014)

21. Srivastava, R.K., et al.: Highway networks. arXiv preprint arXiv:1505.00387 (2015)

22. Wang, F., et al.: Residual attention network for image classification. arXiv preprint arXiv:1704.06904 (2017)

23. Wen, W., et al.: Learning structured sparsity in deep neural networks. In: NIPS. pp. 2074–2082 (2016)

24. Yang, D., et al.: Automatic liver segmentation using an adversarial image-to-image network. In: MICCAI. pp. 507–515. Springer (2017)