
PyTorch Tabular Data Classification (Personality Type Classification)

 

In this article, we will explore how to classify rows of tabular data stored in a .csv file using a neural network built with PyTorch. Suppose you're given a dataset in which each row corresponds to one instance and contains numerical features together with a target class label such as Class1, Class2, or Class3. Your task is to train a model that predicts the correct class from the input features.

In our example, the target classes are Introvert, Extrovert, and Ambivert, and the dataset contains 29 other columns representing various input features. We aim to build a classification model using a feedforward neural network in PyTorch. This includes defining multiple layers, selecting an appropriate loss function (e.g., CrossEntropyLoss), and optimizing the model using techniques like the Adam optimizer to improve accuracy.
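
To make the loss function less abstract, here is a small sketch of what CrossEntropyLoss computes. The logits and targets are made-up numbers for illustration, not values from the dataset. The point is that the loss takes raw scores (no softmax layer in the model) and integer class indices, and it equals the mean negative log-probability of the correct class.

```python
import torch
import torch.nn as nn

# Made-up raw scores (logits) for 2 samples and 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
# Integer class indices, one per sample (dtype long).
targets = torch.tensor([0, 2])

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)

# The same value by hand: log-softmax over the class dimension, then the
# mean negative log-probability assigned to the correct class.
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss.item(), manual.item())
```

This is why the network in the code below returns the output of its last linear layer directly: the softmax is applied inside the loss.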

In the field of machine learning (ML) and deep learning (DL), machines are particularly good at detecting patterns in data. While convolutional neural networks (CNNs) are commonly used for tasks like image recognition, fully connected neural networks are well-suited for tabular classification tasks. These models can learn complex relationships in data and make predictions on abstract categories, such as sentiment, user behavior, or personality type.

In this tutorial, we will use PyTorch to build and train a neural network that classifies individuals into one of the three personality types: Introvert, Extrovert, or Ambivert.
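
Since the loss function works with integer class indices, the three personality names have to be mapped to integers before training. The sketch below shows that mapping with scikit-learn's LabelEncoder on a made-up list of labels (the list is an assumption for illustration only). LabelEncoder sorts the class names, so the indices follow alphabetical order.

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical labels, standing in for the target column of the dataset.
labels = ["Introvert", "Extrovert", "Ambivert", "Introvert"]

encoder = LabelEncoder()
encoded = encoder.fit_transform(labels)

# classes_ holds the sorted class names; the position is the integer code.
print(list(encoder.classes_))   # ['Ambivert', 'Extrovert', 'Introvert']
print(list(encoded))            # [2, 1, 0, 2]

# inverse_transform turns integer codes back into names.
decoded = encoder.inverse_transform([0, 1, 2])
print(list(decoded))
```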

The code is simple and comes with a .ipynb file (Jupyter Notebook) and a dataset so you can start from scratch. 

 

Steps to Run the Code

If you are using Google Colab:

  1. Open the .ipynb file in Google Colab.
  2. Upload the .zip file containing the dataset.
  3. Run the code cells sequentially.
  4. Test with your own data to verify that the model is working.

If you are using Jupyter Notebook locally:

  1. If not already installed, install Jupyter Notebook and the libraries the code imports using the commands:
     pip install jupyter notebook
     pip install torch pandas scikit-learn
  2. Open the notebook using the command:
     jupyter notebook
  3. Run each cell one by one to execute the code.

 

Code for personality-type classification


import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Load your dataset
# Replace this path with your own file or URL
my_df = pd.read_csv('personality_synthetic_dataset.csv')

# Split features and target
X = my_df.drop('personality_type', axis=1).values   # 29 features
y = my_df['personality_type'].values                # target

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=10
)

# Convert to PyTorch tensors
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)

# Label encode the target
label_encoder = LabelEncoder()
y_train = torch.LongTensor(label_encoder.fit_transform(y_train))
y_test = torch.LongTensor(label_encoder.transform(y_test))

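# Normalize the inputs using training-set statistics only
X_train_mean = X_train.mean(dim=0)
X_train_std = X_train.std(dim=0).clamp(min=1e-8)  # guard against division by zero for constant columns
X_train = (X_train - X_train_mean) / X_train_std
X_test = (X_test - X_train_mean) / X_train_std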

# Define the neural network model
class Model(nn.Module):
    def __init__(self, in_features=29, h1=64, h2=32, h3=16, out_features=3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, out_features)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

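# Instantiate model (sizes taken from the data rather than hard-coded)
model = Model(in_features=X_train.shape[1], out_features=len(label_encoder.classes_))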

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

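# Training loop (full-batch: every epoch uses the whole training set)
epochs = 100
losses = []

model.train()
for epoch in range(epochs):
    optimizer.zero_grad()              # clear gradients from the previous step
    y_pred = model(X_train)            # forward pass: raw class scores (logits)
    loss = criterion(y_pred, y_train)  # cross-entropy against integer labels
    loss.backward()                    # backpropagate
    optimizer.step()                   # update the weights

    losses.append(loss.item())

    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Loss = {loss.item():.4f}')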
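
The listing above stops after training, so it does not yet show whether the model works on unseen rows. One way to check this is to compute accuracy on the held-out test split. The sketch below is a minimal version of that idea: the helper name `evaluate` and the toy model and data are assumptions added here so the snippet runs on its own; in the script above you would call `evaluate(model, X_test, y_test)`.

```python
import torch
import torch.nn as nn

def evaluate(model, X, y):
    """Return the fraction of rows in X whose predicted class equals y."""
    model.eval()                          # switch to inference mode
    with torch.no_grad():                 # gradients are not needed here
        logits = model(X)
        predicted = logits.argmax(dim=1)  # index of the highest score per row
    return (predicted == y).float().mean().item()

# Toy stand-in (hypothetical) with hand-set weights so the result is known.
toy_model = nn.Linear(2, 3)
with torch.no_grad():
    toy_model.weight.copy_(torch.tensor([[1.0, 0.0],
                                         [0.0, 1.0],
                                         [-1.0, -1.0]]))
    toy_model.bias.zero_()

toy_X = torch.tensor([[2.0, 0.0], [0.0, 2.0], [-2.0, -2.0], [3.0, 1.0]])
toy_y = torch.tensor([0, 1, 2, 1])   # the last label is deliberately "wrong"

toy_accuracy = evaluate(toy_model, toy_X, toy_y)
print(f"Accuracy: {toy_accuracy:.2f}")
```

To turn predicted indices back into personality names in the script above, `label_encoder.inverse_transform(predicted.numpy())` can be applied to the argmax result.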
    

Admin & Author: Salim


  Website: www.salimwireless.com