一、 均值 mean

  1. 基本 API
  • torch.mean(input, dim=None, keepdim=False, *, dtype=None)
  • x.mean(dim=None, keepdim=False)
  1. 参数说明
  • input / x:输入 tensor
  • dim:指定在哪些维度上求均值
    • None:对所有元素求均值
    • int tuple:指定维度
  • keepdim:是否保留被 reduce 的维度
  • dtype:计算时使用的数据类型

二、方差 var

  1. 基本 API
  • torch.var(x, dim=None, keepdim=False, unbiased=True)
  • x.var(dim=None, keepdim=False, unbiased=True)
  1. 参数说明
  • dim:计算方差的维度,None 表示全部维度
  • unbiased:无偏估计(True)还是总体方差(False)
  • keepdim:是否保持维度

三、统计量中 dim 的影响

可以简单的理解成 dim 指定的维度不会体现在结果中,而是会被 reduce。

比如对于输入 NCHW 的张量来说,如果计算统计量时 dim=0,则表示沿着 Batch 维度计算,结果的尺寸为 CHW;再比如对于 None 来说,则表示对于全部维度进行计算,那么结果的尺寸就是 1(scalar)。

如果结合公式的话,可以体现为 dim 指定的维度都会出现在平均的分母中,比如 x.mean(dim=[0,2,3]),输出尺寸是 (C,),其中每一个元素(位置 c)的计算公式为: