Preliminaries on Partial Derivatives
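Suppose a scalar variable $J$ depends on some variable $z$ (which could be a scalar, vector, matrix, or higher-order tensor). We write $\frac{\partial J}{\partial z}$ for the partial derivatives of $J$ with respect to $z$. We stress that, by convention, $\frac{\partial J}{\partial z}$ has exactly the same dimension as $z$ itself. For example, if $z \in \mathbb{R}^{m \times n}$, then $\frac{\partial J}{\partial z} \in \mathbb{R}^{m \times n}$, and the $(i,j)$-entry of $\frac{\partial J}{\partial z}$ is equal to $\frac{\partial J}{\partial z_{ij}}$.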
Chain rule
Consider a scalar variable $J$ which is obtained by composing $f$ and $g$ on some variable $z$:
$$z \in \mathbb{R}^m, \qquad u = g(z) \in \mathbb{R}^n, \qquad J = f(u) \in \mathbb{R}.$$
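Let $u = (u_1, \dots, u_n)$ and let $g(z) = (g_1(z), \dots, g_n(z))$. Then the standard chain rule gives us that
$$\frac{\partial J}{\partial z_i} = \sum_{j=1}^{n} \frac{\partial J}{\partial u_j} \, \frac{\partial g_j(z)}{\partial z_i}, \qquad \forall i \in \{1, \dots, m\},$$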
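or, in vectorized notation,
$$\frac{\partial J}{\partial z} = \begin{bmatrix} \frac{\partial g_1}{\partial z_1} & \cdots & \frac{\partial g_n}{\partial z_1} \\ \vdots & \ddots & \vdots \\ \frac{\partial g_1}{\partial z_m} & \cdots & \frac{\partial g_n}{\partial z_m} \end{bmatrix} \frac{\partial J}{\partial u}.$$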
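In other words, since the matrix above depends only on $g$ and $z$ (not on $\frac{\partial J}{\partial u}$), the backward function is always a linear map from $\frac{\partial J}{\partial u}$ to $\frac{\partial J}{\partial z}$.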
Key interpretation of the chain rule
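We can view the formula above as a way to compute $\frac{\partial J}{\partial z}$ from $\frac{\partial J}{\partial u}$, even when $f$ is unknown or very complicated:
$$\frac{\partial J}{\partial u} \;\xrightarrow{\text{chain rule}}\; \frac{\partial J}{\partial z}.$$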
Moreover, this formula only involves knowledge about $g$ (more precisely, the partial derivatives $\frac{\partial g_j}{\partial z_i}$).
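We use $\mathcal{B}[g, z]$ to denote the function that maps $\frac{\partial J}{\partial u}$ to $\frac{\partial J}{\partial z}$, and write
$$\frac{\partial J}{\partial z} = \mathcal{B}[g, z]\!\left(\frac{\partial J}{\partial u}\right).$$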
We call $\mathcal{B}[g, z]$ the backward function for the module $g$. Note that when $z$ is fixed, $\mathcal{B}[g, z]$ is merely a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$.
General strategy of back-propagation
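We take the viewpoint that neural networks are complex compositions of small building blocks such as MM (matrix multiplication), $\sigma$ (element-wise activation), Conv2D, LN (layer normalization), etc. We can then abstractly write the loss function $J$ (on a single example $(x, y)$) as a composition of many modules:
$$J = M_k(M_{k-1}(\cdots M_1(x) \cdots)).$$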
We assume that each $M_i$ involves a set of parameters $\theta^{[i]}$, though $\theta^{[i]}$ could be an empty set when $M_i$ is a fixed operation such as a non-linear activation.
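We introduce the intermediate variables for the composition:
$$u^{[0]} = x, \quad u^{[1]} = M_1(u^{[0]}), \quad u^{[2]} = M_2(u^{[1]}), \quad \dots, \quad J = u^{[k]} = M_k(u^{[k-1]}).$$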
Back-propagation consists of two passes. In the forward pass, the algorithm simply computes $u^{[1]}, \dots, u^{[k]}$ from $u^{[0]} = x$, in the order $i = 1, \dots, k$, and saves all the intermediate variables $u^{[i]}$ in memory.
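In the backward pass, we first compute the derivatives with respect to the intermediate variables, that is, $\frac{\partial J}{\partial u^{[k]}}, \dots, \frac{\partial J}{\partial u^{[1]}}$, sequentially in this backward order (the first is simply $\frac{\partial J}{\partial u^{[k]}} = 1$, since $J = u^{[k]}$). We then compute the derivatives of the parameters, $\frac{\partial J}{\partial \theta^{[i]}}$, from $\frac{\partial J}{\partial u^{[i]}}$ and $u^{[i-1]}$. These two types of computation can also be interleaved, because $\frac{\partial J}{\partial \theta^{[i]}}$ depends only on $\frac{\partial J}{\partial u^{[i]}}$ and $u^{[i-1]}$.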
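We first see why $\frac{\partial J}{\partial u^{[i-1]}}$ can be computed efficiently from $\frac{\partial J}{\partial u^{[i]}}$ and $u^{[i-1]}$ by invoking the discussion on the chain rule. We instantiate that discussion by setting $u = u^{[i]}$, $z = u^{[i-1]}$, $f(u) = M_k(M_{k-1}(\cdots M_{i+1}(u) \cdots))$, and $g(\cdot) = M_i(\cdot)$. Note that $f$ is very complex, but we don't need any concrete information about $f$. The conclusive equation then corresponds to
$$\frac{\partial J}{\partial u^{[i]}} \;\xrightarrow{\text{chain rule}}\; \frac{\partial J}{\partial u^{[i-1]}}.$$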
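More precisely, we can write
$$\frac{\partial J}{\partial u^{[i-1]}} = \mathcal{B}[M_i, u^{[i-1]}]\!\left(\frac{\partial J}{\partial u^{[i]}}\right).$$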
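Instantiating the chain rule with $z = \theta^{[i]}$ and $u = u^{[i]}$, we also have
$$\frac{\partial J}{\partial \theta^{[i]}} = \mathcal{B}[M_i, \theta^{[i]}]\!\left(\frac{\partial J}{\partial u^{[i]}}\right).$$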
Example Code
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
// Base class for a module
/// Abstract building block M_i of a network.
///
/// Concrete modules provide three operations: the forward computation, the
/// backward function (mapping the gradient w.r.t. the module's output to the
/// gradient w.r.t. its input), and a parameter update step.
class Module {
public:
    /// Polymorphic base: deleting a derived module through a Module pointer is safe.
    virtual ~Module() = default;

    /// @brief Evaluates the module on @p input.
    /// @return The module's output vector.
    virtual std::vector<double> forward(std::vector<double> const& input) = 0;

    /// @brief Back-propagates @p grad_output (gradient w.r.t. this module's output).
    /// @return Gradient w.r.t. this module's input. Implementations may also
    ///         record gradients for their own parameters here.
    virtual std::vector<double> backward(std::vector<double> const& grad_output) = 0;

    /// @brief Adjusts the module's parameters using @p learning_rate
    ///        (modules without parameters may do nothing).
    virtual void update_parameters(double learning_rate) = 0;
};
// Neural network class
class NeuralNetwork {
private:
std::vector<std::shared_ptr<Module>> modules;
std::vector<std::vector<double>> intermediate_outputs;
public:
void add_module(std::shared_ptr<Module> module) {
modules.push_back(module);
}
// Forward pass
std::vector<double> forward(const std::vector<double>& input) {
intermediate_outputs.clear();
std::vector<double> current_output = input;
intermediate_outputs.push_back(current_output); // Save input
for (const auto& module : modules) {
current_output = module->forward(current_output);
intermediate_outputs.push_back(current_output); // Save intermediate output
}
return current_output; // Final output (loss)
}
// Backward pass
void backward(const std::vector<double>& grad_loss) {
std::vector<double> current_grad = grad_loss;
for (int i = modules.size() - 1; i >= 0; --i) {
current_grad = modules[i]->backward(current_grad);
}
}
// Update parameters
void update_parameters(double learning_rate) {
for (const auto& module : modules) {
module->update_parameters(learning_rate);
}
}
};
// Example usage
int main() {
NeuralNetwork network;
// Add modules to the network (e.g., Linear, ReLU, etc.)
// network.add_module(std::make_shared<Linear>(...));
// network.add_module(std::make_shared<ReLU>());
// Forward pass
std::vector<double> input = { /* input data */ };
std::vector<double> output = network.forward(input);
// Compute loss gradient (e.g., using a loss function)
std::vector<double> grad_loss = { /* gradient of loss w.r.t. output */ };
// Backward pass
network.backward(grad_loss);
// Update parameters
double learning_rate = 0.01;
network.update_parameters(learning_rate);
return 0;
}
#include <vector>
#include <memory>
#include <random>
#include <iostream>
#include <stdexcept>
/// Fully connected (affine) layer computing y = W x + b.
///
/// W is stored as output_size rows of input_size entries, and b has
/// output_size entries. Both are drawn from a normal distribution with
/// mean 0 and standard deviation 0.01 at construction. backward() accumulates
/// the parameter gradients; update_parameters() applies one gradient-descent
/// step with them and then resets them to zero. backward() must be preceded
/// by forward(), because it uses the input cached there.
class Linear : public Module {
private:
    std::size_t in_features = 0;                    // Expected input length (columns of W)
    std::vector<std::vector<double>> weights;       // Weight matrix W, row i -> output i
    std::vector<double> bias;                       // Bias vector b
    std::vector<double> saved_input;                // Input cached by forward() for backward()
    std::vector<std::vector<double>> grad_weights;  // Accumulated dJ/dW
    std::vector<double> grad_bias;                  // Accumulated dJ/db

public:
    /// @brief Creates a layer with randomly initialized weights and biases.
    /// @param input_size  Length of the input vector x (0 is allowed).
    /// @param output_size Length of the output vector y (0 is allowed).
    /// @throws std::invalid_argument if either size is negative.
    Linear(int input_size, int output_size) {
        // A negative int would otherwise wrap to a huge size_t inside resize().
        if (input_size < 0 || output_size < 0) {
            throw std::invalid_argument("Layer sizes must be non-negative.");
        }
        in_features = static_cast<std::size_t>(input_size);
        const std::size_t out_features = static_cast<std::size_t>(output_size);

        std::random_device rd;
        std::mt19937 gen(rd());
        std::normal_distribution<double> dist(0.0, 0.01);

        weights.assign(out_features, std::vector<double>(in_features));
        for (auto& row : weights) {
            for (double& w : row) {
                w = dist(gen);
            }
        }
        bias.resize(out_features);
        for (double& b : bias) {
            b = dist(gen);
        }

        // Gradients start at zero and are accumulated by backward().
        grad_weights.assign(out_features, std::vector<double>(in_features, 0.0));
        grad_bias.assign(out_features, 0.0);
    }

    /// @brief Forward pass: computes y = W x + b and caches x for backward().
    /// @throws std::invalid_argument if input.size() != input_size.
    std::vector<double> forward(const std::vector<double>& input) override {
        if (input.size() != in_features) {
            throw std::invalid_argument("Input size does not match weight matrix dimensions.");
        }
        saved_input = input;
        std::vector<double> output(weights.size(), 0.0);
        for (std::size_t i = 0; i < weights.size(); ++i) {
            for (std::size_t j = 0; j < in_features; ++j) {
                output[i] += weights[i][j] * input[j];
            }
            output[i] += bias[i];
        }
        return output;
    }

    /// @brief Backward pass for y = W x + b.
    /// @param grad_output dJ/dy, one entry per output.
    /// @return dJ/dx = W^T dJ/dy.
    /// Side effect: adds dJ/dy x^T to the W gradient and dJ/dy to the b gradient.
    /// @throws std::invalid_argument if grad_output.size() != output_size.
    /// @throws std::logic_error if forward() has not been called yet.
    std::vector<double> backward(const std::vector<double>& grad_output) override {
        if (grad_output.size() != weights.size()) {
            throw std::invalid_argument("Gradient output size does not match weight matrix dimensions.");
        }
        if (saved_input.size() != in_features) {
            // Without a cached input, indexing saved_input below would be out of bounds.
            throw std::logic_error("Linear::backward() called before forward().");
        }
        std::vector<double> grad_input(in_features, 0.0);
        // Row-major traversal; each grad_input[j] still sums over i in ascending order.
        for (std::size_t i = 0; i < weights.size(); ++i) {
            const double g = grad_output[i];
            for (std::size_t j = 0; j < in_features; ++j) {
                grad_input[j] += weights[i][j] * g;        // dJ/dx = W^T dJ/dy
                grad_weights[i][j] += g * saved_input[j];  // dJ/dW += dJ/dy x^T
            }
            grad_bias[i] += g;                              // dJ/db += dJ/dy
        }
        return grad_input;
    }

    /// @brief One gradient-descent step (W -= lr * dJ/dW, b -= lr * dJ/db),
    ///        then clears the accumulated gradients.
    void update_parameters(double learning_rate) override {
        for (std::size_t i = 0; i < weights.size(); ++i) {
            for (std::size_t j = 0; j < in_features; ++j) {
                weights[i][j] -= learning_rate * grad_weights[i][j];
                grad_weights[i][j] = 0.0;
            }
            bias[i] -= learning_rate * grad_bias[i];
            grad_bias[i] = 0.0;
        }
    }

    /// @brief Prints W (one row per line) followed by b to std::cout.
    void print_parameters() const {
        std::cout << "Weights:" << '\n';
        for (const auto& row : weights) {
            for (const double val : row) {
                std::cout << val << " ";
            }
            std::cout << '\n';
        }
        std::cout << "Biases:" << '\n';
        for (const double val : bias) {
            std::cout << val << " ";
        }
        std::cout << std::endl;  // single flush at the end instead of one per line
    }
};
// Example usage
int main() {
Linear linear_layer(3, 2); // Input size = 3, Output size = 2
std::vector<double> input = {1.0, 2.0, 3.0};
// Forward pass
std::vector<double> output = linear_layer.forward(input);
std::cout << "Output:" << std::endl;
for (double val : output) {
std::cout << val << " ";
}
std::cout << std::endl;
// Backward pass (dummy gradient)
std::vector<double> grad_output = {0.1, 0.2};
linear_layer.backward(grad_output);
// Update parameters
linear_layer.update_parameters(0.01);
// Print updated parameters
linear_layer.print_parameters();
return 0;
}