# Using pytorch+matplotlib to realize linear regression visualization

created at 07-29-2021 views: 15

## libraries and tools¶

import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import datetime


Version

• python 3.7
• torch == 1.9.0
• matplotlib == 3.3.4

## Use PyTorch to generate random sample data

# Build a synthetic dataset for the regression: y = 5 * x + noise.
# x: 200 evenly spaced points from -1 to 1 (an arithmetic sequence),
# reshaped into a column of shape (200, 1) so each row is one sample.
x = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1)
# y: the line 5 * x plus uniform noise drawn from [0, 0.8)
y = 5 * x + 0.8 * torch.rand(x.size())

# Tensors can be fed to the model directly: torch.autograd.Variable has been
# merged into Tensor since PyTorch 0.4, so no wrapper is needed. The names
# X and Y are kept because the training code below uses them.
X = x
Y = y



## Create a model¶

# Training hyperparameters: number of passes over the data and SGD step size
epoch = 1000
learning_rute = 0.0001
# One input feature mapped to one output: y_hat = w * x + b
model = nn.Linear(in_features=1, out_features=1)
# Squared error summed (not averaged) over all samples
square_loss = nn.MSELoss(reduction='sum')
# Plain stochastic gradient descent over the model's weight and bias
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rute)


## Training and live plotting

# Train the linear model with SGD and redraw the fitted line after every step.

# Switch matplotlib to interactive mode so the figure refreshes during training
plt.ion()
# The sample points never change, so convert them to numpy once up front
x_points = X.detach().numpy()
y_points = Y.detach().numpy()
for i in range(epoch):
    # Forward pass: predicted value for every sample
    y_hat = model(X)
    # Summed squared error between predictions and targets
    loss = square_loss(y_hat, Y)
    # Print the loss once every 100 iterations
    if (i + 1) % 100 == 0:
        print(loss)
    # Reset gradients from the previous iteration; PyTorch accumulates them
    # across backward() calls, which would otherwise corrupt every update
    optimizer.zero_grad()
    # Backpropagation
    loss.backward()
    # Parameter update
    optimizer.step()
    # Clear the current axes before redrawing
    plt.cla()
    # Draw the sample points
    plt.scatter(x_points, y_points)
    # Draw the current fitted line (detach: y_hat is part of the autograd graph)
    plt.plot(x_points, y_hat.detach().numpy(), 'r-', lw=2)
    # Pause for 0.05s so the update can be observed
    plt.pause(0.05)
# Leave interactive mode
plt.ioff()


## result¶





## Complete code¶


import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import datetime

t_start = datetime.datetime.now()

x = torch.unsqueeze(torch.linspace(-1, 1, 200), dim=1)
print(x.size())
y = 5 * x + 0.8 * torch.rand(x.size())
X = Variable(x)
Y = Variable(y)

epoch = 1000
learning_rute = 0.0001
model = nn.Linear(1, 1)
square_loss = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rute)

plt.ion()

for i in range(epoch):

y_hat = model(X)
loss = square_loss(y_hat, Y)
if (i + 1) % 100 == 0:
print(loss)