Many different solutions to this error are posted online, because it can have several different causes. In my case, the cause was in this method:
    def step(self, X, Y):
        X = X.flatten(1)
        b, _ = X.shape
        deltaW = self.lr * (torch.mm(X.T, Y) - self.inhibit) / b
        self.weight.data.add_(deltaW)
Here, X and Y are inputs, and the model runs on the GPU. When preparing the inputs, I moved X to the GPU but forgot to move Y. As a result, the computation of deltaW already fails, because torch.mm(X.T, Y) receives one tensor on the GPU and one on the CPU.
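The fix is to move Y to the GPU as well, with Y = Y.cuda() or Y = Y.to(device). Note that both calls return a new tensor rather than modifying Y in place, so the result must be assigned back.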
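As a minimal sketch of the fix, the method can move both inputs to the device of the weights before doing any arithmetic, so a forgotten transfer in the calling code no longer matters. The class name HebbianLayer, its constructor arguments, and the zero-initialised inhibit tensor are assumptions made for illustration only; they are not taken from the original model. The example falls back to the CPU when no GPU is available.

```python
import torch


class HebbianLayer:
    """Hypothetical stand-in for the class that owns step(); names are assumed."""

    def __init__(self, n_in, n_out, lr=0.01, device="cpu"):
        self.lr = lr
        # Weights and inhibition term live on the model's device.
        self.weight = torch.zeros(n_in, n_out, device=device)
        self.inhibit = torch.zeros(n_in, n_out, device=device)

    def step(self, X, Y):
        # Move both inputs to the same device as the weights.
        # .to() returns a new tensor, so the result is assigned back.
        device = self.weight.device
        X = X.to(device).flatten(1)
        Y = Y.to(device)
        b, _ = X.shape
        deltaW = self.lr * (torch.mm(X.T, Y) - self.inhibit) / b
        self.weight.data.add_(deltaW)


# Use the GPU when one is present, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = HebbianLayer(n_in=4, n_out=3, device=device)

X = torch.ones(2, 4, device=device)  # already on the model's device
Y = torch.ones(2, 3)                 # deliberately left on the CPU
layer.step(X, Y)                     # works because step() moves Y itself
```

When debugging the original code, printing X.device and Y.device just before the torch.mm call shows quickly which tensor was left behind.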