
How to do it...

  1. First, we install PyTorch in our Anaconda environment, as follows:
conda install pytorch torchvision -c pytorch

If you want to install PyTorch on another platform or with a specific CUDA version, see the installation instructions on the PyTorch website: http://pytorch.org/.

  1. Let's import PyTorch into our Python environment:
import torch
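Before going further, you can quickly check the installation. The following is a small sketch. The printed values depend on your machine and build, and the `device` variable is an illustrative name that is not part of the recipe:

```python
import torch

# Report the installed PyTorch version and whether a CUDA-capable GPU is visible
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Fall back to the CPU when no GPU is present
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
```

If `CUDA available` prints `False`, the GPU-specific steps below will fail. This usually points to a CPU-only build or a missing driver.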
  1. While Keras provides a higher-level abstraction for building neural networks, PyTorch has this feature built in through its torch.nn module. This means you can build models from higher-level building blocks or even implement the forward and backward passes manually. In this introduction, we will use the higher-level abstraction. First, we set the size of our random training data:
batch_size = 32
input_shape = 5
output_shape = 10
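The lower-level route looks like this. The following sketch is an illustration only and is not part of the recipe. It writes the forward pass of a single bias-free linear layer by hand and derives the gradient of the mean squared error manually. It then checks that gradient against autograd. It runs on the CPU, and the names `W` and `manual_grad` are hypothetical.

```python
import torch

torch.manual_seed(0)
batch_size, input_shape, output_shape = 32, 5, 10

X = torch.randn(batch_size, input_shape)
y = torch.randn(batch_size, output_shape)
# Weight matrix of a bias-free linear layer; autograd tracks it
W = torch.randn(input_shape, output_shape, requires_grad=True)

# Forward pass written out by hand
y_pred = X @ W
loss = ((y_pred - y) ** 2).mean()

# Backward pass computed by autograd
loss.backward()

# Backward pass by hand: d/dW mean((XW - y)^2) = 2 X^T (XW - y) / N,
# where N is the total number of elements the mean runs over
with torch.no_grad():
    n = y.numel()
    manual_grad = 2.0 * X.t() @ (X @ W - y) / n

print(torch.allclose(W.grad, manual_grad, atol=1e-5))
```

Writing gradients by hand like this quickly becomes impractical for deeper models. That is why the recipe relies on torch.nn and autograd instead.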
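  1. To make use of GPUs, we set the default tensor type to CUDA float tensors:
torch.set_default_tensor_type('torch.cuda.FloatTensor')

From now on, newly created floating-point tensors are allocated on the attached GPU by default, so computations on them run there. In recent PyTorch releases, this function is deprecated in favor of torch.set_default_dtype() and torch.set_default_device().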
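Setting a global default assumes that a GPU is present. You can instead choose the device explicitly and pass it wherever tensors and modules are created. This is a different approach from the recipe's and is shown here only as an assumption about how you might structure such code. With this pattern, the same script also runs on a CPU-only machine:

```python
import torch

# Pick the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors can be created directly on the chosen device...
a = torch.randn(4, 3, device=device)
# ...and modules are moved there with .to(device)
layer = torch.nn.Linear(3, 2).to(device)

out = layer(a)
print(out.device, out.shape)
```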

  1. With the default tensor type set, we can generate random training data:
# Plain tensors work directly with autograd; no Variable wrapper is needed
X = torch.randn(batch_size, input_shape)
y = torch.randn(batch_size, output_shape)  # targets: requires_grad is False by default
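Because the data is random, each run produces different numbers. If you want repeatable runs, you can fix the random seed with torch.manual_seed, as in the sketch below. The sketch also shows that freshly created tensors do not track gradients unless you ask them to, so the targets need no special flag:

```python
import torch

batch_size, input_shape, output_shape = 32, 5, 10

torch.manual_seed(42)                      # fix the random number generator
X1 = torch.randn(batch_size, input_shape)
torch.manual_seed(42)                      # same seed again...
X2 = torch.randn(batch_size, input_shape)

print(torch.equal(X1, X2))                 # ...gives identical data: True
y = torch.randn(batch_size, output_shape)
print(X1.shape, y.shape, y.requires_grad)  # requires_grad is False by default
```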
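  1. We will use a simple neural network with one hidden layer of 32 units (followed by a ReLU activation) and an output layer:
model = torch.nn.Sequential(
    torch.nn.Linear(input_shape, 32),
    torch.nn.ReLU(),  # non-linearity for the hidden layer
    torch.nn.Linear(32, output_shape),
).cuda()

We call the .cuda() method to move the model's parameters to the GPU, so the model runs there. With the default tensor type set above, the parameters are already created on the GPU; the call makes the placement explicit.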
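To check what the model contains, you can print it and count its trainable parameters. For the layer sizes used here, the hidden layer has 5 × 32 weights plus 32 biases, and the output layer has 32 × 10 weights plus 10 biases. That gives 522 parameters in total. An activation layer such as torch.nn.ReLU adds no parameters. The following sketch runs on the CPU:

```python
import torch

input_shape, output_shape = 5, 10

# Same layer sizes as the recipe, built on the CPU
model = torch.nn.Sequential(
    torch.nn.Linear(input_shape, 32),
    torch.nn.Linear(32, output_shape),
)

print(model)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters:", n_params)  # 5*32 + 32 + 32*10 + 10 = 522
```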

  1. Next, we define the MSE loss function:
loss_function = torch.nn.MSELoss()
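By default, MSELoss averages the squared differences over all elements (reduction='mean'):

$$\mathrm{MSE}(\hat{y}, y) = \frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2$$

Here N is the total number of elements (batch_size × output_shape). A short sketch that compares the built-in loss with this formula written by hand:

```python
import torch

torch.manual_seed(0)
loss_function = torch.nn.MSELoss()  # default reduction='mean'

y_pred = torch.randn(32, 10)
y = torch.randn(32, 10)

builtin = loss_function(y_pred, y)
manual = ((y_pred - y) ** 2).mean()  # mean over all 32 * 10 elements

print(builtin.item(), manual.item())
print(torch.allclose(builtin, manual))
```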
  1. We are now ready to train our model for 10 epochs with the following code:
learning_rate = 0.001
for i in range(10):
    y_pred = model(X)                  # forward pass
    loss = loss_function(y_pred, y)
    print(loss.item())
    # Zero gradients
    model.zero_grad()
    loss.backward()                    # backward pass

    # Update weights without recording the update in the autograd graph
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad
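The weight update in this loop is plain gradient descent. PyTorch also provides this update as an optimizer in torch.optim. The following CPU sketch trains the same two-layer architecture with torch.optim.SGD. With momentum and weight decay left at their defaults of zero, SGD applies the same update, param -= lr * param.grad. The function name `train` and its list of per-epoch losses are illustrative choices, not part of the recipe.

```python
import torch

def train(epochs=10, learning_rate=0.001, seed=0):
    """Train the recipe's model with torch.optim.SGD; return the loss per epoch."""
    torch.manual_seed(seed)
    batch_size, input_shape, output_shape = 32, 5, 10
    X = torch.randn(batch_size, input_shape)
    y = torch.randn(batch_size, output_shape)

    model = torch.nn.Sequential(
        torch.nn.Linear(input_shape, 32),
        torch.nn.Linear(32, output_shape),
    )
    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    losses = []
    for _ in range(epochs):
        y_pred = model(X)
        loss = loss_function(y_pred, y)
        losses.append(loss.item())
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # compute new gradients
        optimizer.step()       # param -= lr * param.grad for every parameter
    return losses

losses = train()
print(losses[0], losses[-1])
```

Swapping torch.optim.SGD for another optimizer, such as torch.optim.Adam, changes only the line that constructs the optimizer.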
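The PyTorch framework gives a lot of freedom to implement both simple neural networks and more complex deep learning models. What we didn't demonstrate in this introduction is the use of dynamic graphs in PyTorch. PyTorch builds the computation graph on the fly during each forward pass, so ordinary Python control flow can change the model's structure from one iteration to the next. This is a really powerful feature that we will demonstrate in other chapters of this book.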
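As a brief preview of dynamic graphs, the sketch below chooses how many times to apply the hidden layer anew on every forward pass, using a plain Python loop. Autograd still differentiates whatever graph was built during that pass. The class name `DynamicNet` and the random depth between 0 and 3 are hypothetical choices made for illustration.

```python
import random
import torch

class DynamicNet(torch.nn.Module):
    """Hypothetical example: network depth varies on each forward pass."""

    def __init__(self, input_shape, hidden, output_shape):
        super().__init__()
        self.input_layer = torch.nn.Linear(input_shape, hidden)
        self.hidden_layer = torch.nn.Linear(hidden, hidden)
        self.output_layer = torch.nn.Linear(hidden, output_shape)

    def forward(self, x):
        h = torch.relu(self.input_layer(x))
        # Plain Python control flow: reuse the hidden layer 0 to 3 times
        for _ in range(random.randint(0, 3)):
            h = torch.relu(self.hidden_layer(h))
        return self.output_layer(h)

torch.manual_seed(0)
random.seed(0)
model = DynamicNet(5, 32, 10)
X = torch.randn(32, 5)
y = torch.randn(32, 10)

for step in range(3):
    loss = torch.nn.functional.mse_loss(model(X), y)
    model.zero_grad()
    loss.backward()  # backpropagates through the graph built in this pass
    print(step, loss.item())
```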