
PyTorch Deep Learning Framework: Speed + Usability

A new paper from original PyTorch developers explores the inspiration behind the library, and makes the case for its unique marriage of speed and usability.

Deep learning has achieved human-level performance in reading radiology scans, describing images in natural-language sentences, playing complex games, and more. Deep learning is, however, compute-intensive, and framework developers have faced something of a trade-off: improving usability tends to hurt speed, and vice versa.

Popular frameworks such as Caffe, TensorFlow and Theano build a static dataflow graph of the computation ahead of time, which helps performance but costs ease of use and flexibility. Then there’s PyTorch, developed primarily by Facebook AI and introduced in 2016. A new paper from original PyTorch developers Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan and 17 other researchers explores the inspiration behind the library and makes the case for its unique marriage of speed and usability.

PyTorch is an efficient library for building deep learning models. It lets models be written as idiomatic Python code, which is a major plus for usability.
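To give a flavor of what "idiomatic Python" means here, the sketch below defines a small two-layer network as an ordinary Python class. It assumes PyTorch is installed; the class name `TinyMLP` and the layer sizes are arbitrary, hypothetical choices for illustration and do not come from the paper.

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    """A hypothetical two-layer perceptron written as a plain Python class."""

    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        # Layers are ordinary attributes; nn.Module registers their parameters.
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward pass is just Python code calling tensor operations.
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyMLP(in_features=8, hidden=16, out_features=2)
batch = torch.randn(4, 8)  # 4 examples with 8 features each
out = model(batch)
print(out.shape)  # torch.Size([4, 2])
```

Calling `model(batch)` runs `forward` through `nn.Module.__call__`. Because the forward pass is plain Python, it can be stepped through with a standard debugger or inspected with print statements.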

Four main principles informed PyTorch’s design. First, since most data scientists are familiar with Python, PyTorch stays within that ecosystem and keeps its interfaces simple and consistent. Second, it hides much of the complexity inherent to machine learning behind intuitive APIs. Third, it aims for pragmatic performance, letting researchers manually control how their code executes. Lastly, PyTorch’s implementation was deliberately kept simple but slightly “incomplete,” in contrast to more comprehensive and complex frameworks. The engineering time saved by this simplicity can go toward implementing new features and adapting to new situations, which increases PyTorch’s flexibility.

Rather than relying on graph metaprogramming, PyTorch treats a deep learning model as just another program: the code for defining layers, composing models, loading data and running optimizers is all written with ordinary general-purpose programming constructs. Because models are plain Python code, new neural network architectures can be implemented directly, without first being mapped onto a separate graph-definition language.
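As an illustration of this "everything is a program" style, the sketch below (assuming PyTorch is installed) fits a line with a plain Python training loop: the model, loss, optimizer and loop are ordinary Python objects and statements. The synthetic task y = 3x + 1, the learning rate and the stopping threshold are arbitrary, hypothetical choices; nothing here comes from the paper.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic data for a hypothetical regression task: y = 3x + 1 plus noise.
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 3 * x + 1 + 0.05 * torch.randn_like(x)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

losses = []
for step in range(200):
    # Each iteration re-runs the Python code below; autograd records the
    # operations as they execute, so there is no separate graph-building step.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Ordinary Python control flow can depend on tensor values.
    if loss.item() < 1e-2:
        break

print(f"stopped after {step + 1} steps, loss={losses[-1]:.4f}")
print(model.weight.item(), model.bias.item())  # should be near 3 and 1
```

The early `break` is the kind of data-dependent control flow that is awkward to express in a static graph but is simply an `if` statement here.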

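Interoperability and extensibility are also top priorities. PyTorch supports bidirectional exchange of data with external libraries: NumPy arrays and tensors can be converted into each other with torch.from_numpy() and the Tensor.numpy() method, and similar exchange is available through the DLPack format. Users are also free to replace PyTorch components to better serve their specific project needs, for example by writing custom datasets or custom differentiable functions as subclasses of torch.autograd.Function.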
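The sketch below, assuming PyTorch and NumPy are installed, shows that conversion in both directions shares memory rather than copying: an in-place change on one side is visible on the other. Variable names are illustrative only.

```python
import numpy as np
import torch

# NumPy -> PyTorch: torch.from_numpy shares memory with the source array.
array = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor = torch.from_numpy(array)

array[0, 0] = 100.0          # modify the NumPy array in place...
print(tensor[0, 0].item())   # ...and the tensor sees the change: 100.0

# PyTorch -> NumPy: Tensor.numpy() on a CPU tensor also shares memory.
t = torch.zeros(3)
view = t.numpy()
t += 1                       # in-place update on the tensor
print(view)                  # [1. 1. 1.]
```

Because no copy is made, the exchange is cheap even for large arrays; use `.clone()` or `np.copy()` when an independent copy is needed. Note that `Tensor.numpy()` only works on CPU tensors that do not require gradients.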

PyTorch owes much of its performance to the following five strategies:

  1. The PyTorch core library (libtorch) implements the tensor data structure, the CPU and GPU operators, basic parallel primitives and the automatic differentiation system. Because the most computation-intensive operations are handled by this core, they are written in C++ for performance.
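  2. A strict separation is maintained between control flow (branches and loops) and data flow (tensors and the operations on them). Control flow is resolved by Python and optimized C++ code running on the host CPU, which issues a linear sequence of operator calls to the device. On the GPU, these operators run asynchronously via CUDA streams, so Python code on the CPU can keep executing while tensor operations on the GPU are still in progress.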
  3. A custom allocator incrementally builds up a cache of CUDA memory and reassigns it to later allocations without further calls to CUDA APIs. This matters because cudaFree can block its caller until all previously queued GPU work has finished.
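  4. The torch.multiprocessing module extends Python’s multiprocessing so that tensors sent to other processes are moved to shared memory instead of being copied through the communication channel, making it easier to write heavily parallel programs, including ones that use multiple GPUs.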
  5. A reference counting scheme tracks the number of uses of each tensor; once the count reaches zero, the underlying memory is freed immediately.
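Strategies 3 and 5 work together: memory released as soon as a tensor's reference count drops to zero goes back into the allocator's cache, where the next allocation can pick it up. The pure-Python sketch below is a hypothetical toy model of that interplay, not PyTorch's actual CUDA caching allocator, which is considerably more involved (for example, it handles blocks of differing sizes and CUDA streams). The classes `ToyCachingAllocator` and `ToyTensor` and the counter `raw_alloc_calls` are invented for illustration; the counter stands in for calls to the expensive device API.

```python
from collections import defaultdict


class ToyCachingAllocator:
    """Hypothetical, greatly simplified model of a caching allocator.

    Device memory is simulated; raw_alloc_calls counts how often we would have
    had to call the expensive device API (standing in for cudaMalloc).
    """

    def __init__(self):
        self.free_blocks = defaultdict(list)  # size -> list of cached block ids
        self.raw_alloc_calls = 0
        self._next_id = 0

    def allocate(self, size):
        # Reuse a cached block of the same size if one is available.
        if self.free_blocks[size]:
            return self.free_blocks[size].pop()
        # Otherwise fall back to the (simulated) expensive device allocation.
        self.raw_alloc_calls += 1
        block = self._next_id
        self._next_id += 1
        return block

    def release(self, size, block):
        # Blocks go back into the cache instead of being returned to the device.
        self.free_blocks[size].append(block)


class ToyTensor:
    """Tracks its own reference count and releases memory when it hits zero."""

    def __init__(self, allocator, size):
        self.allocator = allocator
        self.size = size
        self.block = allocator.allocate(size)
        self.refcount = 1

    def incref(self):
        self.refcount += 1

    def decref(self):
        self.refcount -= 1
        if self.refcount == 0:
            # Memory is handed back as soon as the last user lets go.
            self.allocator.release(self.size, self.block)
            self.block = None


alloc = ToyCachingAllocator()
for _ in range(1000):  # e.g. one temporary buffer per training step
    tmp = ToyTensor(alloc, size=4096)
    tmp.decref()       # last reference dropped -> block returns to the cache

print(alloc.raw_alloc_calls)  # 1: every later allocation reused the cached block
```

In this toy, a thousand short-lived buffers cost a single simulated device allocation, which is the effect the caching strategy is designed to achieve.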

Figure: Deep learning framework training speed for six models

Compared with the mainstream deep learning frameworks Chainer, CNTK, MXNet, PaddlePaddle and TensorFlow, PyTorch’s training throughput was within 17 percent of the fastest framework on all six benchmark models.

The paper “PyTorch: An Imperative Style, High-Performance Deep Learning Library” is on arXiv.

Author: Hecate He | Editor: Michael Sarazen
