Bootstrap

PyTorch:常见错误 inplace operation

操作是 PyTorch 里面一个比较常见的错误,有的时候会比较好发现,例如下面的代码:

 import torch
 w = torch.rand(4, requires_grad=True)
 w += 1
 loss = w.sum()
 loss.backward()

执行 对参数 进行求导,会出现报错:

导致这个报错的主要是第 3 行代码 ,如果把这句改成 ,再执行就不会报错了。这种写法导致的 是比较好发现的,但是有的时候同样类似的报错,会比较不好发现。例如下面的代码:

 import torch
 x = torch.zeros(4)
 w = torch.rand(4, requires_grad=True)
 x[0] = torch.rand(1) * w[0]
 for i in range(3):
     x[i+1] = torch.sin(x[i]) * w[i]
 loss = x.sum()
 loss.backward()

执行之后会出现报错:

 >>> RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: 
[torch.FloatTensor []], which is output 0 of SelectBackward, is at version 4; expected version 3 instead. 
Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

根据提示我们可以使用 来帮助我们定位具体的出错位置(这个方法会花费比较长的时间)。

 with torch.autograd.set_detect_anomaly(True):
     x = torch.zeros(4)
     w = torch.rand(4, requires_grad=True)
     x[0] = torch.rand(1) * w[0]
     for i in range(3):
         x[i+1] = torch.sin(x[i]) * w[i]
     loss = x.sum()
     loss.backward()

运行会增加这些报错:

 >>> /Users/strongnine/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:130: 
UserWarning: Error detected in SinBackward. Traceback of forward call that caused the error:

可以看到出现了 ,这句描述,我们可以猜测大概是 这个函数出现了问题。实际上,这个报错的解决办法,就是将第 6 行代码 改成 ,就行了。

 import torch
 x = torch.zeros(4)
 w = torch.rand(4, requires_grad=True)
 x[0] = torch.rand(1) * w[0]
 for i in range(3):
     x[i+1] = torch.sin(x[i].clone()) * w[i]
 loss = x.sum()
 loss.backward()

总结一下,遇到 的报错,一般可以通过:

  • 改成 ;

  • 改成 ;

  • 改成 ;

如果自己检查不出是哪里出现了问题,可以使用 来帮助我们定位具体的出错位置,但是要注意的是这个方法一般会运行比较长的时间。

参考: