numpy中轴的理解

最近学numpy中的ndarray,许多函数如np.mean, np.sum, np.min等都有个axis的参数,那什么是轴呢?
其实就是数组的维度,对应ndarray.ndim的值, 也即ndarray.shape的len.

一维数组有一条轴0, 二维数组有两条轴0,1, 三维数组有三条轴0,1,2,依次类推。这与空间的轴概念是对应的,轴代表了变化的方向, 比如一个二维数组:
arr=[[1, 2, 3], [4, 5, 6]]
使用arr[i][j]访问元素,i即为第0轴,j为第1轴,沿0轴变化的元素为:arr[0][j], arr[1][j], 沿1轴变化的元素为:arr[i][0], arr[i][1], arr[i][2]。 这些函数正是在这个方向上进行计算

In [39]: arr
Out[39]:
array([[1, 2, 3],
       [4, 5, 6]])

In [40]: arr.min(axis=0)
Out[40]: array([1, 2, 3]

In [41]: arr.sum(axis=0)
Out[41]: array([5, 7, 9]

In [42]: arr.min(axis=1)
Out[42]: array([1, 4])

In [43]: arr.sum(axis=1)
Out[43]: array([ 6, 15])

二维数组中0轴按列计算,1轴按行计算,并且会降低一维。
再看一个三维的例子:

In [53]: arr
Out[53]:
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]]])

In [54]: arr.sum(axis=0)
Out[54]:
array([[18, 21, 24],
       [27, 30, 33]])

In [55]: arr.sum(axis=1)
Out[55]:
array([[ 3,  5,  7],
       [15, 17, 19],
       [27, 29, 31]])

In [56]: arr.sum(axis=2)
Out[56]:
array([[ 3, 12],
       [21, 30],
       [39, 48]])

axis=0即按arr[0][j][k], a[1][j][k], a[2][j][k]求值, 0 + 6 + 12, 1 + 7 + 13,
axis=1同理arr[i][0][k], a[i][1][k],注意第二维长度为2,即:0+3,1+4, 2+7
axis=2求arr[i][j][0], arr[i][j][1], arr[i][j][2], 0+1+2, 3+4+5


转置也会有轴的概念,也和上面一致,np.transpose接受轴的参数,表示按轴的顺序重新reshape数组,arr.T属性为一种特殊的转置形式,完全倒过来,0,1,2变为2,1,0,看下面输出:
In [79]: arr
Out[79]:
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [80]: arr.shape
Out[80]: (2, 3, 4)

In [81]: arr.transpose(2, 1, 0)
Out[81]:
array([[[ 0, 12],
        [ 4, 16],
        [ 8, 20]],

       [[ 1, 13],
        [ 5, 17],
        [ 9, 21]],

       [[ 2, 14],
        [ 6, 18],
        [10, 22]],

       [[ 3, 15],
        [ 7, 19],
        [11, 23]]])

In [82]: arr.T
Out[82]:
array([[[ 0, 12],
        [ 4, 16],
        [ 8, 20]],

       [[ 1, 13],
        [ 5, 17],
        [ 9, 21]],

       [[ 2, 14],
        [ 6, 18],
        [10, 22]],

       [[ 3, 15],
        [ 7, 19],
        [11, 23]]])

shape由2,3,4变为4,3,2, arr[i][j][k]上的元素变换到arr[k][j][i]上:

In [96]: arr
Out[96]:
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [97]: arrT=arr.T
In [98]: assert arr[0][1][2] == arrT[2][1][0], "should equal"

np.swapaxes 可以指定交换两个轴,道理同上,比如将0,1轴交换:

In [101]: arr
Out[101]:
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [102]: arrS = arr.swapaxes(0, 1)

In [103]: arrS.shape
Out[103]: (3, 2, 4)

In [104]: arrS
Out[104]:
array([[[ 0,  1,  2,  3],
        [12, 13, 14, 15]],

       [[ 4,  5,  6,  7],
        [16, 17, 18, 19]],

       [[ 8,  9, 10, 11],
        [20, 21, 22, 23]]])
作者

BoostMerlin

发布于

2018-01-04

更新于

2023-04-16

许可协议