一、 均值 mean
- 基本 API
torch.mean(input, dim=None, keepdim=False, *, dtype=None)x.mean(dim=None, keepdim=False)
- 参数说明
input / x:输入 tensordim:指定在哪些维度上求均值- None:对所有元素求均值
int或tuple:指定维度
keepdim:是否保留被 reduce 的维度dtype:计算时使用的数据类型
二、方差 var
- 基本 API
torch.var(x, dim=None, keepdim=False, unbiased=True)x.var(dim=None, keepdim=False, unbiased=True)
- 参数说明
dim:计算方差的维度,None 表示全部维度unbiased:无偏估计(True)还是总体方差(False)keepdim:是否保持维度
三、统计量中 dim 的影响
可以简单的理解成 dim 指定的维度不会体现在结果中,而是会被 reduce。
比如对于输入 NCHW 的张量来说,如果计算统计量时 dim=0,则表示沿着 Batch 维度计算,结果的尺寸为 CHW;再比如对于 None 来说,则表示对于全部维度进行计算,那么结果的尺寸就是 1(scalar)。
如果结合公式的话,可以体现为 dim 指定的维度都会出现在平均的分母中,比如 x.mean(dim=[0,2,3]),输出尺寸是 (C,),其中每一个元素(位置 c)的计算公式为: