tf.strided_slice()

hearthougan發表於2018-11-06
# -*- coding: utf-8 -*-
"""
Created on Tue Nov  6 10:46:32 2018

@author: Abner_hg
"""
import numpy as np

import tensorflow as tf


data = [[[1, 1, 1], [2, 2, 2]],
            [[3, 3, 3], [4, 4, 4]],
            [[5, 5, 5], [6, 6, 6]]]

x = tf.strided_slice(data,[0,0,0],[1,1,1])
y = tf.strided_slice(data,[0,0,0],[2,2,2],[1,1,1])
z = tf.strided_slice(data,[0,0,0],[2,2,2],[1,2,1])

with tf.Session() as sess:
    print('x = ', sess.run(x))
    print('y = \n', sess.run(y))
    print('z = \n', sess.run(z))

輸出:

x =  [[[1]]]
y = 
 [[[1 1]
  [2 2]]

 [[3 3]
  [4 4]]]
z = 
 [[[1 1]]

 [[3 3]]]

那麼 tf.strided_slice()對data,怎麼切的呢?首先看一下, tf.strided_slice()的引數。

strided_slice(
    input_,
    begin,
    end,
    strides=None,
    begin_mask=0,
    end_mask=0,
    ellipsis_mask=0,
    new_axis_mask=0,
    shrink_axis_mask=0,
    var=None,
    name=None
)

主要的引數有input、begin、end以及strides四個。其中begin、input和strides要與input的維數保持一致。begin,end和strides共同決定input的每一維是如何剪下的。注意,這裡end是開區間。那麼如何切片的,接下來,我將以上述x和z例子做分析。

data = 
            [#第一維
               [#第二維
                  [1, 1, 1], #第三維
                  [2, 2, 2]
               ],
              
               [
                  [3, 3, 3],
                  [4, 4, 4]
               ],
               
               [
                  [5, 5, 5], 
                  [6, 6, 6]
               ] 
            ]

1、x = tf.strided_slice(data,[0,0,0],[1,1,1])

其中input_ = data, begin = [0,0,0],end = [1,1,1],strides = None。

(1)、首先第一維是begin=0,end=1,也就是從外向內,對第一個括號取第一個元素,即:

data = 
            [#第一個括號,
               [
                  [1, 1, 1],
                  [2, 2, 2]
               ],#第一個括號的第一個元素
              
               [
                  [3, 3, 3],
                  [4, 4, 4]
               ],#第一個括號的第二個元素
               
               [
                  [5, 5, 5], 
                  [6, 6, 6]
               ] #第一個括號的第三個元素
            ]


=====>

x1 = 
            [#第一個括號,
               [
                  [1, 1, 1],
                  [2, 2, 2]
               ],#第一個括號的第一個元素
            ]

(2)、第二維是begin=0,end=1,也就是從外向內,對第二個括號取第一個元素,即:

x1 = 
            [
               [#第二個括號,
                  [1, 1, 1],#第二個括號的第一個元素
                  [2, 2, 2] #第二個括號的第二個元素
               ]
            ]

=====>

x2 = 
            [
               [#第二個括號,
                  [1, 1, 1]#第二個括號的第一個元素
               ]
            ]

(3)、第二維是begin=0,end=1,也就是從外向內,對第三個括號取第一個元素,即:

x2 = 
            [
               [
                  [1, 1, 1]#第三個括號
                #第 1  2  3個元素
               ]
            ]
=====>

x3 = 
            [
               [
                  [1]
               ]
            ]

2、z = tf.strided_slice(data,[0,0,0],[2,2,2],[1,2,1])

其中input_ = data, begin = [0,0,0],end = [2,2,2],strides = [1,2,1]。

(1)、首先第一維是begin=0,end=2,stride = 1也就是從外向內,對第一個括號取第一個第二個元素,即:

data = 
            [#第一個括號,
               [
                  [1, 1, 1],
                  [2, 2, 2]
               ],#第一個括號的第一個元素
              
               [
                  [3, 3, 3],
                  [4, 4, 4]
               ],#第一個括號的第二個元素
               
               [
                  [5, 5, 5], 
                  [6, 6, 6]
               ] #第一個括號的第三個元素
            ]


=====>

x1 = 
            [#第一個括號,
               [
                  [1, 1, 1],
                  [2, 2, 2]
               ]#第一個括號的第一個元素
              
               [
                  [3, 3, 3],
                  [4, 4, 4]
               ]#第一個括號的第二個元素
               
            ]

(2)、第二維是begin=0,end=2,stride = 2,也就是從外向內,對所有第二個括號取第一個元素,注意:end是開區間,即:

x1 = 
            [
               [#第二個括號,
                  [1, 1, 1],#第二個括號的第一個元素
                  [2, 2, 2] #第二個括號的第二個元素
               ],
              
               [
                  [3, 3, 3],#第二個括號的第一個元素
                  [4, 4, 4] #第二個括號的第二個元素
               ]
               
            ]


=====>

x2 = 
            [
               [#第二個括號,
                  [1, 1, 1] #第二個括號的第一個元素

               ],
              
               [
                  [3, 3, 3] #第二個括號的第一個元素

               ]
               
            ]

(3)、首先第三維是begin=0,end=2,stride = 1也就是從外向內,對所有第三個括號取第一個第二個元素,即:

x2 = 
            [
               [
                  [1, 1, 1] #第三個括號,
                #第 1  2  3個元素
               ],
              
               [
                  [3, 3, 3] #第三個括號,
                #第 1  2  3個元素
               ]
               
            ]

=====>

x3 = 
            [
               [#第二個括號,
                  [1, 1] #第二個括號的第一個元素
               #第 1  2 個元素
               ],
              
               [
                  [3, 3] #第二個括號的第一個元素
               #第 1  2 個元素
               ]
               
            ]