How Neural Networks Learn
Neural networks start with random weights and learn by adjusting them so that input features are transformed into target outputs. Training repeatedly feeds the network training data, computes the error between predictions and targets, and updates the weights to reduce that error. Example: predicting a cereal's calories from its sugar, fiber, and protein content.
1. Loss Function
The loss function measures the difference between predicted and true values. Common loss functions for regression include:
- MAE (Mean Absolute Error): average absolute difference between predicted and true values.
- MSE (Mean Squared Error) and Huber loss are common alternatives; MSE penalizes large errors more heavily.
The network uses the loss function to guide weight updates.
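As a concrete illustration, MAE can be computed by hand. The numbers below are made-up stand-ins for the cereal example (true calorie counts versus model predictions):

```python
# A minimal MAE computation on made-up values
y_true = [90.0, 110.0, 100.0]   # e.g. true calorie counts
y_pred = [100.0, 100.0, 100.0]  # model predictions

# Average of the absolute differences
mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)
print(mae)  # → 6.666666666666667
```

A lower MAE means predictions are, on average, closer to the true values.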
2. Optimizer (SGD / Adam)
Optimizers adjust the weights to minimize the loss. Steps for Stochastic Gradient Descent (SGD):
- Sample a minibatch of training data.
- Run the network to make predictions.
- Calculate the loss and adjust weights to reduce it.
Repeat over many minibatches and epochs. Parameters affecting training include:
- Learning rate: step size for weight updates.
- Batch size: number of examples per minibatch.
The Adam optimizer is an adaptive variant of SGD with a self-tuning learning rate, so it often works well without manual hyperparameter tuning.
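The SGD loop above can be sketched in plain Python on a toy problem. This is a minimal illustration, not Keras internals: a one-weight-plus-bias "network" is fit to made-up data generated from y = 2x + 1, using the MSE gradient for the weight updates:

```python
import random

random.seed(0)
# Toy data: y = 2*x + 1, x in [0, 1)
data = [(i / 100, 2 * (i / 100) + 1) for i in range(100)]

w, b = 0.0, 0.0      # start from (effectively) random weights
learning_rate = 0.1  # step size for weight updates
batch_size = 8       # number of examples per minibatch

for step in range(1000):
    # 1. Sample a minibatch of training data
    batch = random.sample(data, batch_size)
    # 2. Predictions are w*x + b; 3. compute the MSE gradient and
    #    adjust the weights a small step in the direction that reduces loss
    grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / batch_size
    grad_b = sum(2 * (w * x + b - y) for x, y in batch) / batch_size
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

print(round(w, 2), round(b, 2))  # w and b should approach 2.0 and 1.0
```

Real optimizers apply the same loop to millions of weights, with gradients computed by backpropagation.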
3. Compiling a Keras Model
model.compile(
    optimizer='adam',
    loss='mae',
)
4. Example: Red Wine Quality Dataset
Inputs: 11 physicochemical features of wines. Output: quality rating. Neural network architecture:
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Dense(512, activation='relu', input_shape=[11]),
    layers.Dense(512, activation='relu'),
    layers.Dense(512, activation='relu'),
    layers.Dense(1),
])
Training parameters:
history = model.fit(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    batch_size=256,
    epochs=10,
)
The loss decreases over epochs, which shows the network is learning. Plotting the loss over time indicates convergence: the curve flattens once additional epochs no longer reduce the loss.
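A common way to inspect the learning curve is to convert the Keras history to a DataFrame and plot it. The loss values below are hypothetical stand-ins for `history.history` from a real `model.fit` run:

```python
import pandas as pd

# history.history (returned by model.fit) maps metric names to
# per-epoch values; these numbers are made up for illustration
history_dict = {"loss": [0.92, 0.51, 0.36, 0.31, 0.29],
                "val_loss": [0.97, 0.62, 0.43, 0.39, 0.38]}
history_df = pd.DataFrame(history_dict)

print(history_df)
# history_df.plot() would draw the learning curves (requires matplotlib);
# the curve flattening in later epochs indicates convergence
```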
Summary
Neural networks learn by iteratively adjusting weights using an optimizer to minimize a loss function. Training proceeds over minibatches and multiple epochs. Monitoring loss helps determine when the network has learned sufficiently.