
Why Use Batch Size


In deep learning, batching is an essential technique for training models efficiently, especially on large datasets. This article explains batch processing in PyTorch and how to use it with the nn.Linear module.


1. What is Batch Processing?

Batch processing refers to the practice of processing multiple input samples at once, rather than one at a time. This is important for optimizing both training and inference, especially when working with powerful hardware like GPUs.

  • Batch Size (B): The number of samples in a batch. For example, a batch size of 10 means you are processing 10 samples simultaneously.
  • Input Shape (B × Nin): If you have a batch of inputs, the shape will be [B, Nin], where B is the batch size, and Nin is the number of input features per sample.
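As a quick illustration of these shapes (the sizes and values below are arbitrary, chosen only for the example):

```python
import torch

B, N_in = 10, 4                 # batch size and features per sample (arbitrary)
batch = torch.randn(B, N_in)    # one row per sample

print(batch.shape)              # torch.Size([10, 4]) — the [B, Nin] layout
print(batch[0].shape)           # a single sample: torch.Size([4])
```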

2. Why Use Batch Size?

Using batch processing comes with several advantages:

  • Efficiency: GPUs are optimized for parallel processing. By using batches, you make use of the full computational power of the GPU, speeding up training and inference.
  • Better Statistics: Some layers (e.g., batch normalization) compute statistics such as the mean and variance over the batch. Larger batch sizes tend to give more stable estimates of these statistics, which can improve model performance.
  • Faster Convergence: Mini-batch optimizers such as stochastic gradient descent (SGD) update the weights using the gradient averaged over the batch, which reduces gradient noise and typically makes training more stable.
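The averaging behavior in the last point can be checked directly: with the default nn.MSELoss (which averages over the batch), the gradient of the batch loss equals the mean of the per-sample gradients. A small sketch with arbitrary random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(1, 1)
x = torch.randn(8, 1)            # batch of 8 samples
y = torch.randn(8, 1)
loss_fn = nn.MSELoss()           # averages the squared error over the batch

# Gradient of the batch-averaged loss
model.zero_grad()
loss_fn(model(x), y).backward()
batch_grad = model.weight.grad.clone()

# Average of the eight per-sample gradients
per_sample = []
for i in range(8):
    model.zero_grad()
    loss_fn(model(x[i:i+1]), y[i:i+1]).backward()
    per_sample.append(model.weight.grad.clone())
avg_grad = torch.stack(per_sample).mean(dim=0)

print(torch.allclose(batch_grad, avg_grad, atol=1e-6))  # True (up to rounding)
```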

3. Example with nn.Linear and Batching

Let's explore how batch processing works with nn.Linear in PyTorch. In this example, we will process a batch of 10 samples.

import torch
import torch.nn as nn

# Create a linear model: input feature size 1, output feature size 1
linear_model = nn.Linear(1, 1)

# Create a batch of inputs, size (10, 1)
x = torch.ones(10, 1)

# Pass the batch through the model
output = linear_model(x)
print(output)
        

The input tensor x has a shape of [10, 1], which means we are passing a batch of 10 samples, each with 1 feature.

When we pass this tensor through linear_model, PyTorch processes all 10 inputs in a single vectorized operation; on a GPU (after moving the model and tensor there with .to('cuda')), these computations run in parallel. The output has the same shape, [10, 1], since each sample's single input feature is mapped to a single output feature.
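Under the hood, nn.Linear applies the same affine map, y = x·Wᵀ + b, to every row of the batch. A quick check against the layer's own (randomly initialized) parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear_model = nn.Linear(1, 1)
x = torch.ones(10, 1)
output = linear_model(x)

print(output.shape)  # torch.Size([10, 1]) — one output row per sample

# Recompute the forward pass by hand from the layer's weight and bias
manual = x @ linear_model.weight.T + linear_model.bias
print(torch.allclose(output, manual))  # True
```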


4. Example with unsqueeze and Reshaping

When working with 1D tensors, such as temperature data, we often need to reshape them to meet the requirements of nn.Linear, which expects inputs to be of the form [B, Nin]. Let's look at how to reshape data using unsqueeze.

import torch

# Original 1D data as Python lists
t_c = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]
t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]

# Convert to tensors and reshape using unsqueeze
t_c = torch.tensor(t_c).unsqueeze(1)  # Reshape to [11, 1]
t_u = torch.tensor(t_u).unsqueeze(1)  # Reshape to [11, 1]

# Check the shape
print(t_c.shape)  # Output: torch.Size([11, 1])
        

The unsqueeze(1) method adds an extra dimension to each tensor, transforming them from 1D tensors of shape [11] into 2D tensors of shape [11, 1].

This reshaping is necessary because nn.Linear expects a 2D input with the shape [B, Nin], where B is the batch size (11 in this case) and Nin is the number of input features per sample (1 here).
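unsqueeze(1) is one of several equivalent ways to add the trailing feature dimension; reshape(-1, 1) and indexing with None produce the same result. A short sketch on a shortened copy of the data:

```python
import torch

t_c = torch.tensor([0.5, 14.0, 15.0])   # shortened sample of the data above
print(t_c.shape)                         # torch.Size([3])

# Three equivalent ways to go from shape [B] to shape [B, 1]
a = t_c.unsqueeze(1)
b = t_c.reshape(-1, 1)
c = t_c[:, None]
print(a.shape, b.shape, c.shape)         # all torch.Size([3, 1])
```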


5. Batch of Images Example

In the case of image data, the input tensor typically has the shape [B, C, H, W], where:

  • B: Batch size (number of images)
  • C: Number of channels (3 for RGB images)
  • H: Height of the image
  • W: Width of the image

For example, if we have 3 RGB images of size 64x64 pixels, the input tensor would have the shape [3, 3, 64, 64]. This allows us to process a batch of images at once.
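A minimal sketch of this shape convention, using an nn.Conv2d layer with arbitrary illustrative sizes (random data, 8 output channels):

```python
import torch
import torch.nn as nn

# A batch of 3 RGB images, 64x64 pixels each: [B, C, H, W]
images = torch.randn(3, 3, 64, 64)

# A convolution layer processes the whole batch and keeps B intact
# (out_channels=8 and kernel_size=3 are arbitrary choices for illustration)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
features = conv(images)

print(features.shape)  # torch.Size([3, 8, 64, 64]) — still one entry per image
```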

6. Summary

  • Batch Processing: Allows multiple samples to be processed simultaneously, making full use of GPU resources for faster computation.
  • Reshaping Input: When using nn.Linear, the input must have the shape [B, Nin], where B is the batch size and Nin is the number of features per sample.
  • Efficient Computation: By using batches, GPUs are fully utilized, and models can train and infer much faster than processing inputs one at a time.

Batch processing is a crucial concept for training and deploying machine learning models efficiently, and PyTorch provides the necessary tools to handle batched inputs easily. Understanding how to reshape data and utilize batching properly will help you make the most of your models, especially when working with large datasets and GPUs.

