个性化阅读
专注于IT技术分析

PyTorch中的损失函数实例图解

在上一个主题中, 我们看到该行未正确拟合到我们的数据。为了使其最合适, 我们将使用梯度下降法更新其参数, 但是在此之前, 它需要你了解损失函数。

因此, 我们的目标是找到适合此数据的线的参数。在我们之前的示例中, 线性函数将首先使用以下参数将随机权重和偏差参数分配给我们的行。

损失函数

这条线不能很好地代表我们的数据。我们需要一些优化算法, 该算法将根据总误差来调整这些参数, 直到最终得到包含适当参数的行。

现在, 我们如何确定这些参数?

为了更好的理解, 我们将讨论限制在单个数据点。

通过从实际y值中减去该点的预测值来确定误差。

损失函数

预测值越接近该值, 误差越小。你已经知道的预测可以写成

Ax1+b

但是, 我们正在处理一个点。这样就可以画出无限量的线。为此, 我们消除了偏见。现在删除此额外的自由度, 我们通过将零偏值固定为零来取消它。

(y-y^)2
(y-(Ax+b))2
(y-(Ax+0))2
(y-Ax)2
损失函数

现在, 无论我们要处理的最佳行是哪条线, 其权重都将使此错误尽可能减少到接近零。现在, 我们正在处理点(-3, 3), 对于这种损失, 该函数将转换为

Loss=(3-A(-3))2
Loss=(3+3A)2

现在, 我们创建一个表并尝试使用不同的A值, 看看哪一个给我们最小的误差

损失函数
损失函数
损失函数
损失函数
损失函数

为了可视化目的, 我们在绘图级别中针对不同的权重绘制了不同的误差值。

损失函数

在这种情况下, 绝对最小值对应于负数的权重, 因此我们知道如何评估与线性方程式相对应的误差。

我们如何训练模型知道此处的权重?为此, 我们使用梯度下降。


赞(0)
未经允许不得转载:srcmini » PyTorch中的损失函数实例图解

评论 抢沙发

评论前必须登录!