个性化阅读
专注于IT技术分析

PyTorch均方误差图解分析

均方误差的计算方法与之前的一般损耗方程式大致相同。我们还将考虑偏差值, 因为这也是在训练过程中需要更新的参数。

(y-Ax + b)2

均方误差最好用图示说明。

假设我们有一组值, 我们首先像以前一样绘制一些回归线参数, 其大小由一组随机的权重和偏差值确定。

均方误差

误差对应于实际值与预测值之间的距离-它们之间的实际距离。

均方误差

对于每个点, 通过使用以下公式将线模型得出的预测值与实际值进行比较来计算误差

均方误差

每个点都与一个错误相关联, 这意味着我们必须对每个点进行误差求和。我们知道预测可以重写为

均方误差

在计算均方误差时, 我们必须通过除以数据点数得出平均值。现在前面提到的误差函数的梯度应该将我们带入误差最大增加的方向。

朝着成本函数的梯度的负值移动, 我们朝着最小误差的方向移动。我们将使用此坡度作为指南针, 使我们始终走下坡路。在梯度下降中, 我们忽略了偏差的存在, 但是对于误差, 参数A和b都需要定义。

现在, 我们下一步将为每个值计算偏导数, 并且像从任何A和b值对开始一样。

均方误差

基于上面提到的两个偏导数, 我们使用梯度下降算法在最小误差方向上更新A和b。对于每次迭代, 新的权重等于

A1 = A0-∝ f'(A)

并且新的偏差值等于

b1 = b0-∝ f'(b)

编写代码的主要思想, 即, 我们从具有随机一组权重和偏差值参数的随机模型开始。该随机模型倾向于具有较大的误差函数, 较大的成本函数, 然后使用梯度下降沿最小误差的方向更新模型的权重。最小化该错误以返回最佳结果。


赞(0)
未经允许不得转载:srcmini » PyTorch均方误差图解分析

评论 抢沙发

评论前必须登录!