- Python Deep Learning Cookbook
- Indra den Bakker
- 2021-07-02 15:43:12
How to do it...
- First, we install PyTorch in our Anaconda environment, as follows:
conda install pytorch torchvision cuda80 -c soumith
If you want to install PyTorch on another platform, or with a different CUDA version, the PyTorch website lists the exact install command for each configuration: http://pytorch.org/.
- Let's import PyTorch into our Python environment:
import torch
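Later steps assume a CUDA-capable GPU (they call `.cuda()`), so before continuing it can help to confirm which PyTorch version was imported and whether a GPU is visible. This is a minimal sketch; the helper name `describe_torch_setup` is hypothetical.

```python
import torch

def describe_torch_setup() -> dict:
    """Report the installed PyTorch version and basic GPU availability."""
    info = {
        "version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        # Name of the first visible GPU (device index 0).
        info["device_name"] = torch.cuda.get_device_name(0)
    return info

print(describe_torch_setup())
```

If `cuda_available` is `False`, the `.cuda()` calls later in this recipe will fail, and the CPU-only install command from the PyTorch website is the better fit.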
- Whereas Keras is a higher-level API layered on top of a separate backend, PyTorch has this kind of abstraction built in. This means we can build networks from higher-level building blocks (the torch.nn module) or write the forward and backward pass manually. In this introduction, we will use the higher-level abstraction. First, we set the sizes of our random training data:
batch_size = 32
input_shape = 5
output_shape = 10
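To illustrate the lower-level route mentioned above, here is a minimal CPU-only sketch that builds a one-hidden-layer network from raw tensors, writes the forward pass by hand, and lets autograd run the backward pass. It then derives the bias gradient of the output layer by hand to show that it matches. The weight names `w1`, `b1`, `w2`, `b2` are illustrative, not part of the recipe.

```python
import torch

torch.manual_seed(0)  # reproducible random data and weights
batch_size, input_shape, hidden, output_shape = 32, 5, 32, 10

X = torch.randn(batch_size, input_shape)
y = torch.randn(batch_size, output_shape)

# Parameters created by hand; requires_grad=True asks autograd to track them.
w1 = torch.randn(input_shape, hidden, requires_grad=True)
b1 = torch.zeros(hidden, requires_grad=True)
w2 = torch.randn(hidden, output_shape, requires_grad=True)
b2 = torch.zeros(output_shape, requires_grad=True)

# Forward pass written out explicitly.
h = torch.relu(X @ w1 + b1)
y_pred = h @ w2 + b2
loss = ((y_pred - y) ** 2).mean()

# Backward pass via autograd: fills w1.grad, b1.grad, w2.grad, b2.grad.
loss.backward()

# The same gradient for b2 derived by hand:
# dL/dy_pred = 2 * (y_pred - y) / numel, summed over the batch dimension.
manual_grad_b2 = 2 * (y_pred.detach() - y).sum(dim=0) / y_pred.numel()
print(torch.allclose(manual_grad_b2, b2.grad, atol=1e-6))
```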
- To make use of GPUs, we set the default tensor type to the CUDA float type:
torch.set_default_tensor_type('torch.cuda.FloatTensor')
This makes newly created floating-point tensors default to CUDA tensors, so they are allocated on, and computed with, the attached GPU.
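`torch.set_default_tensor_type` is deprecated as of PyTorch 2.1 in favor of `torch.set_default_dtype()` and `torch.set_default_device()`. A common alternative, sketched below under the assumption of a recent PyTorch release, is to pick a device once and pass it explicitly, which also lets the same script fall back to the CPU when no GPU is present. The variable names are illustrative.

```python
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, input_shape, output_shape = 32, 5, 10
# Allocate tensors directly on the chosen device.
X = torch.randn(batch_size, input_shape, device=device)
y = torch.randn(batch_size, output_shape, device=device)

# Modules are moved with .to(device), the device-agnostic counterpart of .cuda().
model = torch.nn.Linear(input_shape, output_shape).to(device)
print(X.device, next(model.parameters()).device)
```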
- Next, we generate random training data; with the default tensor type set above, these tensors are created on the GPU:
from torch.autograd import Variable
X = Variable(torch.randn(batch_size, input_shape))
y = Variable(torch.randn(batch_size, output_shape), requires_grad=False)
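Since PyTorch 0.4, `Variable` has been merged into `Tensor`: wrapping a tensor in `Variable` simply hands back a tensor, and `requires_grad` is set directly on tensors (it defaults to `False`). A small check of that behavior, assuming PyTorch 0.4 or later:

```python
import torch
from torch.autograd import Variable

t = torch.randn(3, 2)
wrapped = Variable(t)

# In PyTorch >= 0.4, Variable(...) returns an ordinary Tensor.
print(isinstance(wrapped, torch.Tensor))  # True

# Tensors created directly do not track gradients unless asked to.
data = torch.randn(3, 2)
weights = torch.randn(3, 2, requires_grad=True)
print(data.requires_grad, weights.requires_grad)  # False True
```

On a recent release, `torch.randn(batch_size, input_shape)` on its own is therefore equivalent to the `Variable(...)` calls above.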
- We use a simple neural network with one hidden layer of 32 units, followed by an output layer:
model = torch.nn.Sequential(
    torch.nn.Linear(input_shape, 32),
    torch.nn.ReLU(),  # nonlinearity; without it the two layers collapse into one linear map
    torch.nn.Linear(32, output_shape),
).cuda()
The .cuda() method moves the model's parameters to the GPU, so the forward pass runs there.
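As a quick sanity check on the architecture: each `Linear(in, out)` layer holds an `in × out` weight matrix plus `out` biases, so the two layers have 5·32 + 32 = 192 and 32·10 + 10 = 330 parameters, i.e. 522 in total (an activation layer such as `ReLU` adds none). The CPU-only sketch below confirms the count and also shows why a nonlinearity between the layers matters: two stacked linear layers with nothing in between compute exactly the same function as a single linear layer with W = W2·W1 and b = W2·b1 + b2.

```python
import torch

torch.manual_seed(0)
input_shape, hidden, output_shape = 5, 32, 10

# Two linear layers with no activation in between.
model = torch.nn.Sequential(
    torch.nn.Linear(input_shape, hidden),
    torch.nn.Linear(hidden, output_shape),
)

# Total number of trainable parameters.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)  # 522

# Fold the two affine maps into one: W = W2 @ W1, b = W2 @ b1 + b2.
first, second = model[0], model[1]
with torch.no_grad():
    W = second.weight @ first.weight
    b = second.weight @ first.bias + second.bias
    x = torch.randn(4, input_shape)
    # nn.Linear computes x @ weight.T + bias.
    same = torch.allclose(model(x), x @ W.T + b, atol=1e-5)
print(same)  # True
```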
- Next, we define the mean squared error (MSE) loss function:
loss_function = torch.nn.MSELoss()
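With its default `reduction='mean'`, `MSELoss` averages the squared differences over every element: MSE = (1/n) Σ (ŷᵢ − yᵢ)². For example, predictions [1, 2] against targets [3, 2] give ((1 − 3)² + (2 − 2)²) / 2 = 2.0. The sketch below checks this against a hand-written version; the helper `manual_mse` is illustrative.

```python
import torch

def manual_mse(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Mean of squared differences over all elements (reduction='mean')."""
    return ((y_pred - y) ** 2).mean()

loss_function = torch.nn.MSELoss()

y_pred = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 2.0])
print(loss_function(y_pred, y).item(), manual_mse(y_pred, y).item())  # 2.0 2.0

# The two also agree on random batches shaped like the recipe's data.
a, b = torch.randn(32, 10), torch.randn(32, 10)
print(torch.allclose(loss_function(a, b), manual_mse(a, b)))  # True
```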
- We are now ready to train our model for 10 epochs. Because the training set is a single batch, each epoch is one gradient step:
learning_rate = 0.001
for i in range(10):
    # Forward pass
    y_pred = model(X)
    loss = loss_function(y_pred, y)
    # .item() extracts the Python number from the 0-dim loss tensor
    print(loss.item())
    # Zero gradients, then backpropagate
    model.zero_grad()
    loss.backward()
    # Update weights with plain gradient descent
    for param in model.parameters():
        param.data -= learning_rate * param.grad.data
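The manual loop above is plain gradient descent, θ ← θ − η ∇θL with η = `learning_rate`. The same update is available as `torch.optim.SGD`. The sketch below uses it, reads the scalar loss with `.item()`, and runs on whichever device is available. The learning rate of 0.01 and the 200 steps are assumptions chosen so the drop in loss is easy to see; they are not values from the recipe.

```python
import torch

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, input_shape, output_shape = 32, 5, 10
X = torch.randn(batch_size, input_shape, device=device)
y = torch.randn(batch_size, output_shape, device=device)

model = torch.nn.Sequential(
    torch.nn.Linear(input_shape, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, output_shape),
).to(device)
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for epoch in range(200):
    optimizer.zero_grad()              # clear gradients from the previous step
    loss = loss_function(model(X), y)  # forward pass
    loss.backward()                    # backward pass
    optimizer.step()                   # theta <- theta - lr * grad
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Delegating the update to an optimizer object makes it a one-line change to switch to, say, `torch.optim.Adam`, without touching the loop.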
The PyTorch framework gives a lot of freedom to implement both simple neural networks and more complex deep learning models. What we didn't demonstrate in this introduction is the use of dynamic graphs in PyTorch, a powerful feature that we will demonstrate in other chapters of this book.
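As a small preview of what "dynamic graph" (define-by-run) means: PyTorch records operations as they execute, so ordinary Python control flow decides the shape of the graph at run time. In the illustrative sketch below, the number of multiplications depends on a value `n` that could just as well come from the data (for example a sequence length), and autograd differentiates whatever was actually executed. The helper `power_by_loop` is hypothetical.

```python
import torch

def power_by_loop(x: torch.Tensor, n: int) -> torch.Tensor:
    """Compute x ** (n + 1) with a Python loop; the graph grows with n."""
    result = x
    for _ in range(n):
        result = result * x
    return result

x = torch.tensor(2.0, requires_grad=True)
n = 3                           # known only at run time in a real model
y = power_by_loop(x, n)         # y = x ** 4
y.backward()                    # dy/dx = 4 * x ** 3
print(y.item(), x.grad.item())  # 16.0 32.0
```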