11  Solutions

Solutions for the exercises.

11.1 Tensor

11.1.1 Solution 2.1

The input data can either be

  • Tensor
  • NDArray
  • List/Scalars

Whatever the type of the input data, we want to make sure self.data is always an NDArray (np.ndarray).

FILE : babygrad/tensor.py


import numpy as np
NDArray = np.ndarray
class Tensor:
    """Minimal Tensor whose ``data`` attribute is always an ``np.ndarray``."""

    def __init__(self, data, *, device=None, dtype="float32",
        requires_grad=True):
        """Convert ``data`` (Tensor, ndarray, list or scalar) to an ndarray.

        Args:
            data: Source values.
            device: Device tag; ``"cpu"`` when falsy.
            dtype: Target dtype. ``None`` keeps the dtype of a Tensor/ndarray
                input and falls back to ``"float32"`` for any other input.
            requires_grad: Whether autograd should track this tensor.
        """
        if isinstance(data, Tensor):
            # Read the wrapped ndarray directly: this class defines no
            # ``numpy()``/``dtype`` helpers, and ``astype`` already returns a
            # copy, so the new Tensor never shares memory with ``data``.
            if dtype is None:
                dtype = data.data.dtype
            self.data = data.data.astype(dtype)
        elif isinstance(data, np.ndarray):
            # Already an ndarray: cast (and copy) to the requested dtype.
            self.data = data.astype(dtype if dtype is not None
             else data.dtype)
        else:
            # Lists / scalars: build a fresh ndarray.
            self.data = np.array(data, dtype=dtype if dtype is not None
            else "float32")
        self.grad = None
        self.requires_grad = requires_grad
        self._op = None
        self._inputs = []
        self._device = device if device else "cpu"

11.1.2 Final Code

The final babygrad/tensor.py looks like this.

import numpy as np
NDArray = np.ndarray
def _ensure_tensor(val):
    """Return ``val`` as a Tensor.

    Tensors pass through unchanged. Anything else is wrapped in a new Tensor
    with ``requires_grad=False``, so that plain numbers or arrays mixed into
    an expression are treated as constants.
    """
    if isinstance(val, Tensor):
        return val
    return Tensor(val, requires_grad=False)
class Tensor:
    """N-dimensional array with optional gradient tracking.

    ``self.data`` is always an ``np.ndarray``. ``grad`` holds the accumulated
    gradient, ``_op`` the Function that produced this tensor (``None`` for a
    leaf) and ``_inputs`` that Function's operands.
    """

    def __init__(self, data, *, device=None, dtype="float32",
     requires_grad=False):
        """Wrap ``data`` (a Tensor, ndarray, list or scalar) as an ndarray.

        Args:
            data: Source values.
            device: Device tag; falls back to ``"cpu"`` when falsy.
            dtype: Target dtype. ``None`` keeps the dtype of a Tensor/ndarray
                input and uses ``"float32"`` for anything else.
            requires_grad: Whether autograd should track this tensor.
        """
        if isinstance(data, Tensor):
            if dtype is None:
                dtype = data.dtype
            self.data = data.numpy().astype(dtype)
        elif isinstance(data, np.ndarray):
            self.data = data.astype(dtype if dtype is not None else data.dtype)
        else:
            self.data = np.array(data, dtype=dtype if dtype is not None
             else "float32")
        self.grad = None
        self.requires_grad = requires_grad
        self._op = None
        self._inputs = []
        self._device = device if device else "cpu"

    def numpy(self):
        """Return a copy of the underlying ndarray."""
        return self.data.copy()

    def detach(self):
        """Return a new tensor with the same values and no gradient tracking."""
        return Tensor(self.data, requires_grad=False, dtype=str(self.dtype))

    @property
    def shape(self):
        return self.data.shape

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def ndim(self):
        return self.data.ndim

    @property
    def size(self):
        return self.data.size

    @property
    def device(self):
        return self._device

    @property
    def T(self):
        return self.transpose()

    def __repr__(self):
        return f"Tensor({self.data}, requires_grad={self.requires_grad})"

    def __str__(self):
        return str(self.data)

    @staticmethod
    def _normalize_shape(shape):
        """Turn ``*shape`` as ``(2, 3)`` or ``((2, 3),)`` into ``(2, 3)``."""
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            return tuple(shape[0])
        return tuple(shape)

    @classmethod
    def rand(cls, *shape, low=0.0, high=1.0, dtype="float32",
        requires_grad=True):
        """Uniform samples in ``[low, high)``."""
        shape = cls._normalize_shape(shape)
        # asarray: np.random.rand() with no dims returns a Python float.
        array = np.asarray(np.random.rand(*shape)) * (high - low) + low
        # dtype must be forwarded, otherwise __init__ recasts to float32.
        return cls(array.astype(dtype), dtype=dtype,
                   requires_grad=requires_grad)

    @classmethod
    def randn(cls, *shape, mean=0.0, std=1.0, dtype="float32",
         requires_grad=True):
        """Normal samples with the given ``mean`` and ``std``."""
        shape = cls._normalize_shape(shape)
        array = np.asarray(np.random.randn(*shape)) * std + mean
        return cls(array.astype(dtype), dtype=dtype,
                   requires_grad=requires_grad)

    @classmethod
    def constant(cls, *shape, c=1.0, dtype="float32", requires_grad=True):
        """Tensor of the given shape filled with ``c``."""
        # np.ones takes the whole shape as ONE argument; np.ones(2, 3)
        # would interpret 3 as a dtype.
        array = np.ones(cls._normalize_shape(shape)) * c
        return cls(array.astype(dtype), dtype=dtype,
                   requires_grad=requires_grad)

    @classmethod
    def ones(cls, *shape, dtype="float32", requires_grad=True):
        return cls.constant(*shape, c=1.0, dtype=dtype,
        requires_grad=requires_grad)

    @classmethod
    def zeros(cls, *shape, dtype="float32", requires_grad=True):
        return cls.constant(*shape, c=0.0, dtype=dtype,
         requires_grad=requires_grad)

    @classmethod
    def randb(cls, *shape, p=0.5, dtype="float32", requires_grad=True):
        """Bernoulli mask: each entry is 1 with probability ``p``, else 0."""
        array = np.random.rand(*cls._normalize_shape(shape)) <= p
        return cls(array, dtype=dtype, requires_grad=requires_grad)

    @classmethod
    def empty(cls, *shape, dtype="float32", requires_grad=True):
        """Uninitialised tensor of the given shape and dtype."""
        array = np.empty(cls._normalize_shape(shape), dtype=dtype)
        return cls(array, dtype=dtype, requires_grad=requires_grad)

    @classmethod
    def one_hot(cls, indices, num_classes, device=None, dtype="float32",
         requires_grad=True):
        """Rows of the ``num_classes`` identity matrix picked by ``indices``.

        ``indices`` may be a Tensor or anything ``np.array`` accepts.
        """
        raw = indices.data if isinstance(indices, Tensor) else indices
        one_hot_array = np.eye(num_classes, dtype=dtype)[np.array(
            raw, dtype=int)]
        return cls(one_hot_array, device=device, dtype=dtype,
            requires_grad=requires_grad)
        

11.2 Automatic Differentiation

Here we implement both the forward and the backward pass of each operation.

11.2.1 Basic

Power Operation

When we compute \(a^b\), there are two partial derivatives:

  • With respect to the base (\(a\)): \(\frac{\partial}{\partial a}(a^b) = b \cdot a^{b-1}\)
  • With respect to the exponent (\(b\)): \(\frac{\partial}{\partial b}(a^b) = a^b \cdot \ln(a)\).

Division Operation

Division \(a/b\) can be thought of as \(a \cdot b^{-1}\).

  • With respect to \(a\): \(\frac{\partial}{\partial a}(\frac{a}{b}) = \frac{1}{b}\).
  • With respect to \(b\): \(\frac{\partial}{\partial b}(\frac{a}{b}) = -\frac{a}{b^2}\).
class Pow(Function):
    """Element-wise power ``a ** b`` where both operands are tensors."""

    def forward(self, a, b):
        return np.power(a, b)

    def backward(self, out_grad, node):
        # d(a^b)/da = b * a^(b-1);  d(a^b)/db = a^b * ln(a)
        base, exponent = node._inputs
        scaled = multiply(out_grad, exponent)
        grad_base = multiply(scaled, power(base, add_scalar(exponent, -1)))
        through_pow = multiply(out_grad, power(base, exponent))
        grad_exponent = multiply(through_pow, log(base))
        return grad_base, grad_exponent


def power(a, b):
    """Return ``a ** b`` computed element-wise through :class:`Pow`."""
    op = Pow()
    return op(a, b)
class PowerScalar(Function):
    """Raise a tensor to a constant scalar power (the exponent gets no gradient)."""

    def __init__(self, scalar: int):
        self.scalar = scalar

    def forward(self, a: NDArray) -> NDArray:
        return np.power(a, self.scalar)

    def backward(self, out_grad, node):
        # d(x^n)/dx = n * x^(n-1)
        x = node._inputs[0]
        n = self.scalar
        local_grad = multiply(Tensor(n), power_scalar(x, n - 1))
        return multiply(out_grad, local_grad)


def power_scalar(a, scalar):
    """Return ``a ** scalar`` via :class:`PowerScalar`."""
    op = PowerScalar(scalar)
    return op(a)
class Div(Function):
    """Element-wise division ``a / b`` of two tensors."""

    def forward(self, a, b):
        quotient = a / b
        return quotient

    def backward(self, out_grad, node):
        # d(x/y)/dx = 1/y;  d(x/y)/dy = -x / y^2
        numer, denom = node._inputs
        grad_numer = divide(out_grad, denom)
        denom_squared = multiply(denom, denom)
        grad_denom = multiply(negate(out_grad), divide(numer, denom_squared))
        return grad_numer, grad_denom


def divide(a, b):
    """Return ``a / b`` via :class:`Div`."""
    op = Div()
    return op(a, b)
class DivScalar(Function):
    """Divide a tensor by a constant scalar."""

    def __init__(self, scalar):
        self.scalar = scalar

    def forward(self, a):
        # Cast the quotient back to the input's dtype.
        quotient = a / self.scalar
        return np.array(quotient, dtype=a.dtype)

    def backward(self, out_grad, node):
        # d(x/c)/dx = 1/c
        scaled = out_grad / self.scalar
        return scaled


def divide_scalar(a, scalar):
    """Return ``a / scalar`` via :class:`DivScalar`."""
    op = DivScalar(scalar)
    return op(a)
  • Negate: \(f(x) = -x \implies f'(x) = -1\)
  • Log: \(f(x) = \ln(x) \implies f'(x) = \frac{1}{x}\)
  • Exp: \(f(x) = e^x \implies f'(x) = e^x\)
  • Sqrt: \(f(x) = \sqrt{x} \implies f'(x) = \frac{1}{2\sqrt{x}}\)
class Negate(Function):
    """Element-wise negation ``-x``."""

    def forward(self, a):
        return np.negative(a)

    def backward(self, out_grad, node):
        # d(-x)/dx = -1, so the upstream gradient is simply negated.
        flipped = negate(out_grad)
        return flipped


def negate(a):
    """Return ``-a`` via :class:`Negate`."""
    op = Negate()
    return op(a)
class Log(Function):
    """Element-wise natural logarithm."""

    def forward(self, a):
        result = np.log(a)
        return result

    def backward(self, out_grad, node):
        # d(ln x)/dx = 1/x: divide the upstream gradient by the input.
        source = node._inputs[0]
        return divide(out_grad, source)


def log(a):
    """Return ``ln(a)`` via :class:`Log`."""
    op = Log()
    return op(a)
class Exp(Function):
    """Element-wise exponential ``e**x``."""

    def forward(self, a):
        result = np.exp(a)
        return result

    def backward(self, out_grad, node):
        # d(e^x)/dx = e^x, and ``node`` (this op's output) already holds e^x,
        # so no recomputation is needed.
        local_grad = node
        return multiply(out_grad, local_grad)


def exp(a):
    """Return ``e**a`` via :class:`Exp`."""
    op = Exp()
    return op(a)
class Sqrt(Function):
    """Element-wise square root."""

    def forward(self, a):
        result = np.sqrt(a)
        return result

    def backward(self, out_grad, node):
        # d(sqrt x)/dx = 1 / (2 * sqrt x); ``node`` already holds sqrt(x).
        denominator = multiply(Tensor(2.0), node)
        return divide(out_grad, denominator)


def sqrt(a):
    """Return ``sqrt(a)`` via :class:`Sqrt`."""
    op = Sqrt()
    return op(a)

11.2.2 Activations

  • ReLU: \(f(x) = \max(0, x)\). The gradient is \(1\) if \(x > 0\) and \(0\) otherwise.
  • Sigmoid: \(\sigma(x) = \frac{1}{1 + e^{-x}}\). The gradient is \(\sigma(x) \cdot (1 - \sigma(x))\).
  • Tanh: \(f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\). The gradient is \(1 - \tanh^2(x)\).
class ReLU(Function):
    """Rectified linear unit ``max(0, x)``, element-wise."""

    def forward(self, a):
        # Multiplying by the boolean mask zeroes every non-positive entry.
        positive = a > 0
        return a * positive

    def backward(self, out_grad, node):
        # The gradient passes through only where the input was positive.
        source = node._inputs[0]
        gate = Tensor((source.data > 0).astype("float32"), requires_grad=False)
        return multiply(out_grad, gate)


def relu(a):
    """Return ``max(0, a)`` via :class:`ReLU`."""
    op = ReLU()
    return op(a)

class Sigmoid(Function):
    """Logistic sigmoid ``1 / (1 + e^-x)``, element-wise."""

    def forward(self, a):
        denom = 1 + np.exp(-a)
        return 1 / denom

    def backward(self, out_grad, node):
        # sigma'(x) = sigma(x) * (1 - sigma(x)); ``node`` holds sigma(x).
        one = Tensor(1.0, requires_grad=False)
        complement = add(one, negate(node))
        return multiply(out_grad, multiply(node, complement))


def sigmoid(x):
    """Return ``sigmoid(x)`` via :class:`Sigmoid`."""
    op = Sigmoid()
    return op(x)

class Tanh(Function):
    """Hyperbolic tangent, element-wise."""

    def forward(self, a):
        result = np.tanh(a)
        return result

    def backward(self, out_grad, node):
        # tanh'(x) = 1 - tanh(x)^2; ``node`` holds tanh(x).
        one = Tensor(1.0, requires_grad=False)
        local_grad = add(one, negate(multiply(node, node)))
        return multiply(out_grad, local_grad)


def tanh(x):
    """Return ``tanh(x)`` via :class:`Tanh`."""
    op = Tanh()
    return op(x)

11.2.3 Reshape

If A.shape = (2,3) and the forward pass reshapes it to (3,2), then out_grad.shape is also (3,2). To convert the gradient back to (2,3), we simply reshape again.

Reshaping never changes the number of elements, so in the backward pass we can always reshape the gradient back to the original shape.

class Reshape(Function):
    """View a tensor with a new shape holding the same number of elements."""

    def __init__(self, shape):
        self.shape = shape

    def forward(self, a):
        target = self.shape
        return np.reshape(a, target)

    def backward(self, out_grad, node):
        # The element count is unchanged, so reshaping the gradient back to
        # the input's shape exactly undoes the forward pass.
        input_shape = node._inputs[0].shape
        return reshape(out_grad, input_shape)


def reshape(a, shape):
    """Return ``a`` reshaped to ``shape`` via :class:`Reshape`."""
    op = Reshape(shape)
    return op(a)

11.2.4 Transpose

The 2 possible inputs cases:

  • axes is None: we swap the last 2 axes.
  • axes is not None: a pair of axes is swapped; a full permutation is passed to np.transpose.

For backward pass

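  • If axes is None, we call transpose(out_grad): swapping the last 2 axes again undoes the forward swap.
  • If axes is not None, we apply the inverse permutation. Swapping two axes is its own inverse, so we swap the same pair again. For a full permutation we compute np.argsort of the (normalized, non-negative) permutation and pass that result to transpose(out_grad, result).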

class Transpose(Function):
    """Permute tensor axes.

    * ``axes=None`` swaps the last two axes.
    * A 2-element ``axes`` swaps exactly those two axes (negative indices ok).
    * Any other ``axes`` is a full permutation, as in ``np.transpose``.
    """

    def __init__(self, axes: Optional[tuple] = None):
        self.axes = axes

    def forward(self, a):
        if self.axes is None:
            return np.swapaxes(a, -1, -2)

        ndim = a.ndim
        # Normalise negative axes so they can index ``full_axes``.
        axes = tuple(ax if ax >= 0 else ndim + ax for ax in self.axes)
        if len(axes) == 2:
            full_axes = list(range(ndim))
            i, j = axes
            full_axes[i], full_axes[j] = full_axes[j], full_axes[i]
            self.full_axes = tuple(full_axes)
        else:
            self.full_axes = axes

        return np.transpose(a, self.full_axes)

    def backward(self, out_grad, node):
        if self.axes is None:
            # Swapping the last two axes is its own inverse.
            return transpose(out_grad)
        if len(self.axes) == 2:
            # A pairwise swap is its own inverse: swap the same pair again.
            # (argsort of the pair would give (1, 0), which swaps the wrong
            # axes whenever the tensor has more than two dimensions.)
            return transpose(out_grad, tuple(self.axes))
        # Full permutation: invert the *normalised* permutation recorded in
        # forward(); raw ``self.axes`` may hold negative indices that would
        # sort into the wrong order.
        inverse_axes = tuple(int(ax) for ax in np.argsort(self.full_axes))
        return transpose(out_grad, inverse_axes)


def transpose(a, axes=None):
    """Return ``a`` with its axes permuted via :class:`Transpose`."""
    return Transpose(axes)(a)

11.2.5 Summation

In the forward pass, summation is a “squashing” operation. It takes many values and reduces them into fewer values (or a single scalar).

In the backward pass, we have to do the opposite: we must spread the gradient from the smaller output back to the larger original shape. We do this in two steps:

  • Reshaping: We put “1”s back into the axes that were summed away. This aligns the gradient’s dimensions with the parent’s dimensions.
  • Broadcasting: We stretch those “1”s until they match the parent’s original size.

We have 2 cases :

  • Global Sum (axes = None) If a (3,3) matrix is summed, we get a single scalar.

    • Step 1: Reshape the scalar to (1,1).
    • Step 2: Broadcast (1,1) \(\rightarrow\) (3,3).
  • Axis-Specific Sum (axes = 0 or 1) If we sum a (3,3) matrix over axis=0, we get a shape of (3,)

    • Step 1: Reshape (3,) \(\rightarrow\) (1,3). (We put the 1 back where the axis vanished).
    • Step 2: Broadcast (1,3) \(\rightarrow\) (3,3).

Since we haven’t officially built a BroadcastTo operation yet, we can use a clever trick. Multiplying our reshaped gradient by a matrix of ones of the original shape triggers NumPy’s internal broadcasting automatically!

class Summation(Function):
    """Sum over ``axes`` (every axis when ``None``); reduced axes are dropped."""

    def __init__(self, axes: Optional[tuple] = None):
        self.axes = axes

    def forward(self, a):
        return np.sum(a, axis=self.axes)

    def backward(self, out_grad, node):
        a = node._inputs[0]

        original_shape = a.shape

        if self.axes is None:
            # Global sum: the scalar gradient becomes (1, 1, ..., 1).
            intermediate_shape = (1,) * len(original_shape)
        else:
            # Accept a bare int as well as a list/tuple of axes. The
            # parentheses are required for the conditional expression to
            # span two lines.
            axes = (self.axes if isinstance(self.axes, (list, tuple))
                    else (self.axes,))
            # Normalise negative axes against the input's rank.
            axes = [ax if ax >= 0 else ax + len(original_shape) for ax in axes]
            intermediate_shape = list(out_grad.shape)

            # Re-insert a size-1 dim wherever an axis vanished; ascending
            # order keeps later insert positions valid.
            for ax in sorted(axes):
                intermediate_shape.insert(ax, 1)

        reshaped_grad = reshape(out_grad, tuple(intermediate_shape))
        # Multiplying by ones of the input's shape broadcasts the gradient
        # back out to that shape.
        ones = Tensor(np.ones(original_shape))
        return reshaped_grad * ones

11.2.6 BroadcastTo

For the forward pass we can simply use np.broadcast_to(). In the backward pass we just need to reverse whatever has happened during the forward pass.

There are 2 operations that happened in forward pass.

Prepending: For (3,)

If you add a vector of shape (3,) to a matrix of shape (2, 3), NumPy treats the vector as (1, 3): it prepends a new dimension to the front. That means we need to sum over this new dimension (axis=0) during the backward pass.

  • If out_grad has more dimensions than the original input, sum out_grad across the leading axes (axis=0) until the number of dimensions matches.

Stretching To convert (1,3) to (3,3).

Even with the same number of dimensions, shapes might differ (e.g., converting (3, 1) to (3, 3)).

  • Locate any dimension that was originally 1 but is now N. Sum the out_grad along that specific axis.

class BroadcastTo(Function):
    """Broadcast a tensor to ``shape`` following NumPy broadcasting rules."""

    def __init__(self, shape):
        self.shape = shape

    def forward(self, a):
        return np.broadcast_to(a, self.shape)

    def backward(self, out_grad, node):
        in_shape = node._inputs[0].shape
        grad = out_grad

        # Undo prepending: sum away the leading axes that broadcasting added.
        extra_dims = len(out_grad.shape) - len(in_shape)
        for _ in range(extra_dims):
            grad = summation(grad, axes=0)

        # Undo stretching: every axis that was 1 in the input but is larger
        # now is summed, then restored as a size-1 axis.
        stretched_axes = [
            axis
            for axis, (before, after) in enumerate(zip(in_shape, grad.shape))
            if before == 1 and after > 1
        ]
        for axis in stretched_axes:
            grad = summation(grad, axes=axis)
            kept = list(grad.shape)
            kept.insert(axis, 1)
            grad = reshape(grad, tuple(kept))

        return grad

11.2.7 Matmul

In the forward pass, we have the matrix product of \(A\) and \(B\):\[C = A \cdot B\] When we calculate the backward pass, we are looking for how the Scalar Loss (\(L\)) changes with respect to our inputs.

Gradient with respect to \(A\)

To find \(\frac{\partial L}{\partial A}\), we take the incoming gradient and “multiply” it by the transpose of \(B\):\[\nabla_A L = \frac{\partial L}{\partial C} \cdot B^T\]

  • Dimensions: \((M, P) @ (P, N) = (M, N)\)

Gradient with respect to \(B\)

To find \(\frac{\partial L}{\partial B}\), we “multiply” the transpose of \(A\) by the incoming gradient: \[\nabla_B L = A^T \cdot \frac{\partial L}{\partial C}\]

  • Dimensions: \((N, M) @ (M, P) = (N, P)\)
def _reduce_grad_to_shape(grad, shape):
    """Sum ``grad`` over the batch axes that matmul broadcasting added or stretched."""
    # Leading batch axes the operand did not have at all.
    while len(grad.shape) > len(shape):
        grad = summation(grad, axes=0)
    # Batch axes where the operand had size 1 but the product did not,
    # e.g. (1, M, N) @ (B, N, P): grad_a is (B, M, N) and must become (1, M, N).
    for axis, (want, have) in enumerate(zip(shape, grad.shape)):
        if want == 1 and have != 1:
            grad = summation(grad, axes=axis)
            restored = list(grad.shape)
            restored.insert(axis, 1)
            grad = reshape(grad, tuple(restored))
    return reshape(grad, shape)


class MatMul(Function):
    """Matrix product ``a @ b`` (batched over leading axes, as ``np.matmul``)."""

    def forward(self, a, b):
        return np.matmul(a, b)

    def backward(self, out_grad, node):
        a, b = node._inputs

        if len(out_grad.shape) == 0:
            out_grad = out_grad.broadcast_to(node.shape)

        # dL/dA = dL/dC @ B^T ;  dL/dB = A^T @ dL/dC  (on the last two axes)
        grad_a = matmul(out_grad, transpose(b, axes=(-1, -2)))
        grad_b = matmul(transpose(a, axes=(-1, -2)), out_grad)

        return (_reduce_grad_to_shape(grad_a, a.shape),
                _reduce_grad_to_shape(grad_b, b.shape))


def matmul(a, b):
    """Return ``a @ b`` via :class:`MatMul`."""
    return MatMul()(a, b)

11.2.8 Implementing Backward Pass in the Tensor Class

Our solution focuses on two core principles:

  • Topological Ordering: We use a Depth-First Search (DFS) to build a “walkable” list of our graph. By processing this list in reverse, we ensure that we never calculate a parent’s gradient until we have finished gathering from its children.

  • Gradient Accumulation: In the real world, a single tensor might be used by multiple operations (e.g., \(y = x^2 + x\)). In the backward pass, \(x\) must receive gradients from both paths. Our grads dictionary acting as a ledger ensures we add contributions together rather than overwriting them.

class Tensor:
    def backward(self, grad=None):
        """Back-propagate from this tensor into every tensor it depends on.

        Args:
            grad: Upstream gradient for ``self``. Defaults to ones shaped
                like ``self.data`` (dL/dL = 1).

        Raises:
            RuntimeError: if ``self.requires_grad`` is false.

        Every reachable node's ``.grad`` (an ndarray) receives the summed
        gradient of all paths; an existing ``.grad`` is added to, not
        replaced.
        """
        if not self.requires_grad:
            raise RuntimeError("Cannot call backward on a tensor "
                               "that does not require gradients.")

        # Topological sort: post-order DFS, done iteratively so deep graphs
        # cannot hit Python's recursion limit. A node is appended only after
        # all of its inputs, in the same order a recursive DFS would produce.
        topo_order = []
        visited = {id(self)}
        stack = [(self, iter(self._inputs))]
        while stack:
            node, pending = stack[-1]
            for parent in pending:
                if id(parent) not in visited:
                    visited.add(id(parent))
                    stack.append((parent, iter(parent._inputs)))
                    break
            else:
                stack.pop()
                topo_order.append(node)

        # Gradient ledger keyed by id(): contributions from several paths
        # are summed rather than overwritten.
        grads = {}
        if grad is None:
            grads[id(self)] = Tensor(np.ones_like(self.data))
        else:
            grads[id(self)] = _ensure_tensor(grad)

        # Reverse topological order: a node's gradient is complete before it
        # is propagated to its inputs.
        for node in reversed(topo_order):
            out_grad = grads.get(id(node))
            if out_grad is None:
                continue

            if node.grad is None:
                node.grad = np.array(out_grad.data, copy=True)
            else:
                node.grad += out_grad.data

            if node._op:
                input_grads = node._op.backward(out_grad, node)
                if not isinstance(input_grads, tuple):
                    input_grads = (input_grads,)

                for i, parent in enumerate(node._inputs):
                    if parent.requires_grad:
                        parent_id = id(parent)
                        if parent_id not in grads:
                            grads[parent_id] = input_grads[i]
                        else:
                            grads[parent_id] = grads[parent_id] + input_grads[i]

11.3 nn

11.3.1 Basics


class ReLU(Module):
    """Module wrapper around ``ops.relu``."""

    def forward(self, x: Tensor) -> Tensor:
        activated = ops.relu(x)
        return activated


class Tanh(Module):
    """Module wrapper around ``ops.tanh``."""

    def forward(self, x: Tensor):
        activated = ops.tanh(x)
        return activated


class Sigmoid(Module):
    """Module wrapper around ``ops.sigmoid``."""

    def forward(self, x: Tensor):
        activated = ops.sigmoid(x)
        return activated
    

11.3.2 Flatten


class Flatten(Module):
    """Collapse all non-batch dimensions: ``(B, d1, ..., dk) -> (B, d1*...*dk)``."""

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.shape[0]
        # int(...) keeps the size integral even for 1-D input, where
        # np.prod(()) returns the float 1.0 that reshape would reject.
        flat_dim = int(np.prod(x.shape[1:]))
        return x.reshape(batch_size, flat_dim)

11.3.3 Linear


class Linear(Module):
    """Affine layer ``y = x @ W + b``.

    Args:
        in_features: Size of each input row.
        out_features: Size of each output row.
        bias: Whether to learn an additive bias.
        device: Accepted for API compatibility; not used here.
        dtype: dtype of the weight and bias tensors.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
             device: Any | None = None, dtype: str = "float32") -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight is (in_features, out_features) so forward is a plain x @ W.
        self.weight = Parameter(
            Tensor.randn(in_features, out_features, dtype=dtype))
        self.bias = None
        if bias:
            # (1, out_features) is passed as a single tuple so the whole shape
            # reaches np.ones intact; it is broadcast over the batch in forward.
            self.bias = Parameter(Tensor.zeros((1, out_features), dtype=dtype))

    def forward(self, x: Tensor) -> Tensor:
        # (bs, in) @ (in, out) -> (bs, out)
        out = x @ self.weight
        if self.bias is not None:
            # (1, out) -> (bs, out)
            out += self.bias.broadcast_to(out.shape)
        return out

11.3.4 Sequential


class Sequential(Module):
    """Chain modules: each module's output is the next module's input."""

    def __init__(self, *modules):
        super().__init__()
        self.modules = modules

    def forward(self, x: Tensor) -> Tensor:
        out = x
        for layer in self.modules:
            out = layer(out)
        return out

11.3.5 Dropout


class Dropout(Module):
    """Inverted dropout: zero each activation with probability ``p`` in training.

    Surviving activations are scaled by ``1 / (1 - p)``. In eval mode the
    input is returned unchanged.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        keep_prob = 1 - self.p
        # Each mask entry is 1 with probability keep_prob (e.g. p=0.2 keeps
        # 80%). The mask is a constant, so it must not request a gradient of
        # its own; randb would otherwise default to requires_grad=True.
        mask = Tensor.randb(*x.shape, p=keep_prob, requires_grad=False)
        return (x * mask) / keep_prob

11.3.6 LayerNorm

The solution follows the formula directly; the only subtlety is getting the shapes right, which is why we need reshape and broadcast_to.

class LayerNorm1d(Module):
    """Layer normalisation over the feature axis of a ``(batch, dim)`` input.

    Computes ``weight * (x - mean) / sqrt(var + eps) + bias``. The mean and
    variance are taken over each row, dividing by ``dim``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = Parameter(Tensor.ones(dim))
        self.bias = Parameter(Tensor.zeros(dim))

    def forward(self, x: Tensor):
        full_shape = x.shape
        column_shape = (full_shape[0], 1)

        def per_row(stat):
            # (batch,) -> (batch, 1) -> (batch, dim)
            return ops.broadcast_to(ops.reshape(stat, column_shape), full_shape)

        def per_feature(param):
            # (dim,) -> (1, dim) -> (batch, dim)
            return ops.broadcast_to(ops.reshape(param, (1, self.dim)), full_shape)

        row_mean = ops.summation(x, axes=(1,)) / self.dim
        centered = x - per_row(row_mean)
        row_var = ops.summation(centered**2, axes=(1,)) / self.dim
        normalized = centered / ops.sqrt(per_row(row_var) + self.eps)
        return per_feature(self.weight) * normalized + per_feature(self.bias)

11.3.7 BatchNorm

class BatchNorm1d(Module):
    """Batch normalisation over the batch axis of a ``(batch, dim)`` input.

    Training mode normalises with the current batch mean and (biased)
    variance and updates the running averages with weight ``momentum`` on
    the new value. Eval mode normalises with the running averages.
    """

    def __init__(self, dim: int, eps: float = 1e-5, momentum: float = 0.1,
                 device: Any | None = None, dtype: str = "float32") -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        self.weight = Parameter(Tensor.ones(dim, dtype=dtype))
        self.bias = Parameter(Tensor.zeros(dim, dtype=dtype))
        self.running_mean = Tensor.zeros(dim, dtype=dtype)
        self.running_var = Tensor.ones(dim, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        if self.training:
            batch_size = x.shape[0]
            # (batch, dim) -> (dim,)
            mean = ops.summation(x, axes=(0,)) / batch_size
            # mean: (dim,) -> (1, dim) -> (batch, dim)
            centered = x - ops.broadcast_to(mean.reshape((1, self.dim)), x.shape)
            var = ops.summation(centered**2, axes=(0,)) / batch_size

            # Running statistics are updated on the raw ndarrays, so the
            # update is not recorded in the autograd graph.
            self.running_mean.data = ((1 - self.momentum) * self.running_mean.data
                                      + self.momentum * mean.data)
            self.running_var.data = ((1 - self.momentum) * self.running_var.data
                                     + self.momentum * var.data)

            mean_to_use = mean
            var_to_use = var
        else:
            mean_to_use = self.running_mean
            var_to_use = self.running_var

        # (dim,) -> (1, dim)
        mean_reshaped = mean_to_use.reshape((1, self.dim))
        var_reshaped = var_to_use.reshape((1, self.dim))

        std = ops.sqrt(var_reshaped + self.eps)

        # (1, dim) -> (batch, dim) for both the mean and std.
        x_hat = ((x - ops.broadcast_to(mean_reshaped, x.shape))
                 / ops.broadcast_to(std, x.shape))

        # weight/bias: (dim,) -> (1, dim) -> (batch, dim)
        weight_reshaped = self.weight.reshape((1, self.dim))
        bias_reshaped = self.bias.reshape((1, self.dim))

        return (ops.broadcast_to(weight_reshaped, x.shape) * x_hat
                + ops.broadcast_to(bias_reshaped, x.shape))

11.3.8 Softmax Loss

First we will implement LogSumExp inside babygrad/ops.py and then implement SoftmaxLoss.

LogSumExp

In the forward pass we add one extra step: we subtract the maximum value from the input before exponentiating, purely to prevent overflow, and then add the maximum back at the end. We expect the output to have shape (batch_size,), not (batch_size, 1) (which is what NumPy returns with keepdims=True), so we remove the reduced axis to avoid broadcasting errors later.

\[\text{LogSumExp}(x) = \max(x) + \ln \left( \sum_{i} e^{x_i - \max(x)} \right)\]

Let's figure out how to do the backward pass for this.

Using the chain rule on \(f(x) = \ln(\sum e^{x_i})\), the derivative with respect to \(x_i\) is:

\[\frac{\partial \text{LSE}}{\partial x_i} = \frac{e^{x_i}}{\sum_j e^{x_j}}\]

Does it look like a Softmax function? Yes!

\[\text{grad} = out\_grad \times \text{Softmax}(x)\]

We must also be careful that during the forward pass we squeezed the last dimension. So in the backward pass we
have to unsqueeze it.


class LogSumExp(Function):
    """Numerically stable ``log(sum(exp(a)))`` over ``axes`` (reduced axes dropped)."""

    def __init__(self, axes):
        self.axes = axes

    def forward(self, a):
        # Subtract the max before exponentiating so exp() cannot overflow;
        # keepdims so it broadcasts against ``a``.
        max_a = np.max(a, axis=self.axes, keepdims=True)
        sum_exp = np.sum(np.exp(a - max_a), axis=self.axes, keepdims=False)
        # Drop exactly the reduced axes from max_a so its shape matches
        # sum_exp for any ``axes`` and any input rank, e.g. (bs, 1) -> (bs,).
        return np.squeeze(max_a, axis=self.axes) + np.log(sum_exp)

    def backward(self, out_grad: Tensor, node: Tensor):
        # d LSE / d a_i = softmax(a)_i along ``axes``.
        a = node._inputs[0]

        # Input shape with each reduced axis set to 1, e.g. (bs, 1).
        new_shape = list(a.shape)
        axes = self.axes
        if axes is None:
            axes = tuple(range(len(a.shape)))
        elif isinstance(axes, int):
            axes = (axes,)
        for axis in axes:
            new_shape[axis] = 1
        new_shape = tuple(new_shape)

        # Softmax, shifted by the (constant) max for stability.
        max_a_val = a.data.max(axis=self.axes, keepdims=True)
        max_a_tensor = Tensor(max_a_val, device=a.device, dtype=a.dtype)
        shifted_a = a - max_a_tensor
        exp_shifted_a = exp(shifted_a)
        sum_exp_shifted_a = summation(exp_shifted_a, self.axes)
        reshaped_sum = reshape(sum_exp_shifted_a, new_shape)
        softmax = divide(exp_shifted_a, broadcast_to(reshaped_sum, a.shape))

        # out_grad: reduced shape -> new_shape -> a.shape, then scale softmax.
        reshaped_out_grad = reshape(out_grad, new_shape)
        grad = multiply(broadcast_to(reshaped_out_grad, a.shape), softmax)

        return grad


def logsumexp(a: Tensor, axes: Optional[tuple] = None) -> Tensor:
    """Return ``log(sum(exp(a)))`` reduced over ``axes``."""
    return LogSumExp(axes=axes)(a)
class SoftmaxLoss(Module):
    """Mean softmax cross-entropy: ``mean_i(logsumexp(z_i) - z_i[y_i])``."""

    def forward(self, logits, y):
        batch, num_classes = logits.shape
        targets = Tensor.one_hot(y, num_classes, requires_grad=False)
        lse = ops.logsumexp(logits, axes=(1,))
        # Picks out the logit of the true class for every row.
        true_logit = (logits * targets).sum(axes=(1,))
        per_sample = lse - true_logit
        return per_sample.sum() / batch

11.4 Optimizer

11.4.1 SGD

We will go over each parameter and update the parameter only if parameter.grad is not None.

The reason we use parameter.data to update the weights is because we just want to have new weights and the computation graph should not be tracking this step operation.


class SGD(Optimizer):
    """Plain stochastic gradient descent: ``p <- p - lr * grad``."""

    def __init__(self, params, lr=0.01):
        super().__init__(params)
        self.lr = lr

    def step(self):
        for param in self.params:
            if param.grad is None:
                continue
            # Update the raw ndarray in place so the step itself is not
            # recorded in any computation graph.
            param.data -= self.lr * param.grad

11.4.2 Adam



class Adam(Optimizer):
    """
    Implements the Adam optimization algorithm.

    Args:
        params: Parameters to optimise.
        lr: Step size.
        beta1: Decay rate of the first-moment (mean) estimate.
        beta2: Decay rate of the second-moment (uncentered variance) estimate.
        eps: Added to the denominator for numerical stability.
        weight_decay: L2 coefficient added to the gradient when > 0.
    """
    def __init__(
        self,
        params,
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.0,
    ):
        super().__init__(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0
        # Moment estimates, keyed by the parameter object itself.
        self.m = {}
        self.v = {}

    def step(self):
        """Apply one Adam update to every parameter that has a gradient."""
        self.t += 1
        for param in self.params:
            if param.grad is None:
                continue
            grad = param.grad
            if self.weight_decay > 0:
                grad = grad + self.weight_decay * param.data
            # Exponential moving averages of the gradient and its square.
            # Parentheses let each expression span two lines.
            mt = (self.m.get(param, 0) * self.beta1
                  + (1 - self.beta1) * grad)
            self.m[param] = mt
            vt = (self.v.get(param, 0) * self.beta2
                  + (1 - self.beta2) * (grad ** 2))
            self.v[param] = vt
            # Bias correction for the zero-initialised moments.
            mt_hat = mt / (1 - (self.beta1 ** self.t))
            vt_hat = vt / (1 - (self.beta2 ** self.t))

            denom = vt_hat**0.5 + self.eps
            param.data -= self.lr * mt_hat / denom

11.5 Data Handling

11.5.1 MNIST

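The index can be:

  • A single number: we convert the sample to an np.array, reshape it to (28, 28, 1), apply the transforms if any, and return it as (x, y).
  • A slice: we convert the selected samples to an np.array, reshape them to (N, 28, 28, 1) and handle them the same way. The DataLoader later stacks the samples and wraps them in Tensors.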
class MNISTDataset(Dataset):
    """MNIST images as float32 ``(28, 28, 1)`` arrays plus their labels.

    Args:
        image_filename: Path to the MNIST image file.
        label_filename: Path to the MNIST label file.
        transforms: Optional callables applied, in order, to each
            ``(28, 28, 1)`` image.
    """

    def __init__(
        self,
        image_filename: str,
        label_filename: str,
        transforms: Optional[List] = None,
    ):
        # Passed positionally, so the call does not depend on the exact
        # spelling of parse_mnist's parameter names.
        self.images, self.labels = parse_mnist(image_filename, label_filename)
        self.transforms = transforms

    def _apply_transforms(self, image):
        """Run one ``(28, 28, 1)`` image through every transform, in order."""
        if self.transforms is not None:
            for tform in self.transforms:
                image = tform(image)
        return image

    def __getitem__(self, index) -> object:
        if isinstance(index, slice):
            # e.g. [0:5]: the (5, 784) rows become (5, 28, 28, 1) images.
            images = np.array(self.images[index], dtype=np.float32)
            images = images.reshape(-1, 28, 28, 1)
            # Apply transforms per image, matching the single-index path.
            if self.transforms and len(images):
                images = np.stack([self._apply_transforms(img) for img in images])
            labels = np.array(self.labels[index])
            return (images, labels)

        # Single index: one (28, 28, 1) image and its label.
        image = np.array(self.images[index], dtype=np.float32).reshape(28, 28, 1)
        label = np.array(self.labels[index])
        return (self._apply_transforms(image), label)

    def __len__(self) -> int:
        return len(self.images)

11.5.2 Dataloader

We have to know :

  • How many batches we can have?
  • What is the start index of each batch.
  • Then just do indices[start: start+batch_size]

class DataLoader:
    """Iterate over a Dataset in (optionally shuffled) mini-batches of Tensors.

    Each batch is built by indexing the dataset one sample at a time, stacking
    each field of the samples along a new leading axis, and wrapping every
    stacked array in a ``Tensor``.

    Args:
        dataset: object supporting ``len()`` and integer ``__getitem__``
            returning a tuple of array-likes.
        batch_size: number of samples per batch (the last batch may be smaller).
        shuffle: if True, the sample order is reshuffled on every ``iter()``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset: Dataset, batch_size: int = 1, shuffle: bool = False):
        # ``shuffle`` needs a default: a non-default parameter after
        # ``batch_size=1`` is a SyntaxError.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size

    def __iter__(self):
        self.indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(self.indices)
        self.batch_idx = 0
        # Ceiling division: a partial final batch still counts.
        self.num_batches = (len(self.dataset) + self.batch_size - 1) // self.batch_size
        return self

    def __next__(self):
        if self.batch_idx >= self.num_batches:
            raise StopIteration
        start = self.batch_idx * self.batch_size
        batch_indices = self.indices[start: start + self.batch_size]

        # Calls Dataset.__getitem__(i) for each sample, then groups fields:
        # [(x0, y0), (x1, y1)] -> [(x0, x1), (y0, y1)] -> stacked arrays.
        samples = [self.dataset[i] for i in batch_indices]
        all_arrays = [np.stack(field) for field in zip(*samples)]
        batch = tuple(Tensor(arr) for arr in all_arrays)
        self.batch_idx += 1
        return batch
    

11.6 Initialization

Get the code for Tensor.randn and Tensor.rand from the solutions of the Tensor chapter.

11.6.1 Xavier

def xavier_uniform(fan_in: int, fan_out: int, gain: float = 1.0,
                   shape=None, **kwargs):
    """Sample a Xavier/Glorot-uniform tensor from U(-a, a).

    a = gain * sqrt(6 / (fan_in + fan_out)). ``shape`` defaults to
    ``(fan_in, fan_out)``. A ``device`` keyword, if given, is discarded; the
    remaining keywords are forwarded to ``Tensor.rand``.
    """
    kwargs.pop('device', None)
    limit = gain * math.sqrt(6.0 / (fan_in + fan_out))
    dims = (fan_in, fan_out) if shape is None else shape
    return Tensor.rand(*dims, low=-limit, high=limit, **kwargs)

def xavier_normal(fan_in: int, fan_out: int, gain: float = 1.0,
                  shape=None, **kwargs):
    """Sample a Xavier/Glorot-normal tensor from N(0, std^2).

    std = gain * sqrt(2 / (fan_in + fan_out)). ``shape`` defaults to
    ``(fan_in, fan_out)``. A ``device`` keyword, if given, is discarded; the
    remaining keywords are forwarded to ``Tensor.randn``.
    """
    kwargs.pop('device', None)
    sigma = gain * math.sqrt(2.0 / (fan_in + fan_out))
    dims = (fan_in, fan_out) if shape is None else shape
    return Tensor.randn(*dims, mean=0, std=sigma, **kwargs)

11.6.2 Kaiming



def kaiming_uniform(fan_in: int, fan_out: int, nonlinearity: str = "relu",
                    shape=None, **kwargs):
    """Sample a Kaiming/He-uniform tensor from U(-bound, bound).

    bound = sqrt(2) * sqrt(3 / fan_in), i.e. the ReLU gain times
    sqrt(3 / fan_in). ``fan_out`` only sets the default shape
    ``(fan_in, fan_out)``, and ``nonlinearity`` is accepted but not used.
    A ``device`` keyword, if given, is discarded; the remaining keywords are
    forwarded to ``Tensor.rand``.
    """
    kwargs.pop('device', None)
    relu_gain = math.sqrt(2.0)
    bound = relu_gain * math.sqrt(3.0 / fan_in)
    dims = (fan_in, fan_out) if shape is None else shape
    return Tensor.rand(*dims, low=-bound, high=bound, **kwargs)


def kaiming_normal(fan_in: int, fan_out: int, nonlinearity: str = "relu",
                   shape=None, **kwargs):
    """Sample a Kaiming/He-normal tensor from N(0, std^2) with std = sqrt(2 / fan_in).

    ``fan_out`` only sets the default shape ``(fan_in, fan_out)``, and
    ``nonlinearity`` is accepted but not used. A ``device`` keyword, if given,
    is discarded; the remaining keywords are forwarded to ``Tensor.randn``.
    """
    kwargs.pop('device', None)
    sigma = math.sqrt(2.0 / fan_in)
    dims = (fan_in, fan_out) if shape is None else shape
    return Tensor.randn(*dims, mean=0, std=sigma, **kwargs)

11.7 Model Persistence

11.7.1 Save Model

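For every attribute (key, value) of the module we:

  • store dict[key] = value.data when value is a Tensor/Parameter;
  • call value.state_dict() recursively when value is a Module (or a Module inside a list/tuple), and store its entries under the prefixed keys "key.childkey" (or "key.i.childkey").

Calling state_dict() on a Module therefore returns a flat dictionary that maps dotted attribute names to tensor.data.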


    def state_dict(self):
        """Return a flat ``{dotted_name: array}`` mapping of this module's tensors.

        Tensor/Parameter attributes are stored under their attribute name.
        Child Modules contribute their own entries prefixed with ``"name."``.
        Modules held in a list/tuple attribute are prefixed with ``"name.i."``.
        The stored values are the tensors' ``.data`` objects themselves, not copies.
        """
        result = {}
        for name, attr in self.__dict__.items():
            if isinstance(attr, (Tensor, Parameter)):
                result[name] = attr.data
                continue
            if isinstance(attr, Module):
                for sub_key, arr in attr.state_dict().items():
                    result[f"{name}.{sub_key}"] = arr
                continue
            if isinstance(attr, (list, tuple)):
                for idx, child in enumerate(attr):
                    if not isinstance(child, Module):
                        continue
                    for sub_key, arr in child.state_dict().items():
                        result[f"{name}.{idx}.{sub_key}"] = arr
        return result

11.7.2 Load Model

We need to load the values from state_dict into value.data.

  • For a Tensor/Parameter we can directly assign value.data = state_dict[key], after checking that the shapes match.
  • For a child Module, remember that we saved its entries under keys of the form key.childkey. The child itself knows nothing about that prefix: it only expects entries of the form state_dict[childkey] = value.data.

So we first keep only the entries whose keys start with the prefix "key.", strip that prefix, and then recurse into the child with the filtered dictionary.

def load_state_dict(self, state_dict):
    """Load arrays from ``state_dict`` (as produced by ``state_dict()``) into this module.

    Tensor/Parameter attributes are matched by attribute name. Child Modules,
    whether direct or inside a list/tuple, receive the entries under their
    ``"name."`` / ``"name.i."`` prefix, with that prefix stripped. Keys missing
    from ``state_dict`` are skipped, and extra keys are ignored.

    Raises:
        ValueError: if a stored array's shape differs from the tensor it would replace.
    """
    def _sub_dict(prefix):
        # Keep only this child's entries and drop the prefix the parent added when saving.
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    for key, value in self.__dict__.items():
        if isinstance(value, Parameter) or isinstance(value, Tensor):
            if key in state_dict:
                new_data = state_dict[key]
                if value.shape != new_data.shape:
                    # Adjacent f-strings build a single-line message; a plain
                    # f-string cannot span lines.
                    raise ValueError(
                        f"Shape mismatch for {key}: "
                        f"expected {value.shape}, got {new_data.shape}"
                    )
                value.data = new_data

        elif isinstance(value, Module):
            value.load_state_dict(_sub_dict(f"{key}."))

        elif isinstance(value, (list, tuple)):
            for i, item in enumerate(value):
                if isinstance(item, Module):
                    item.load_state_dict(_sub_dict(f"{key}.{i}."))

11.8 Trainer


class Trainer:
    """Minimal training loop around a model, optimizer and loss function.

    Args:
        model: callable module exposing ``train()`` and ``eval()``.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        loss_fn: ``loss_fn(pred, y)`` returning a Tensor with ``backward()``.
        train_loader: iterable of batches, each either a ``(x, y)`` list/tuple
            or an object with ``.x`` and ``.y``.
        val_loader: optional iterable of batches used by ``evaluate()``.
    """

    def __init__(self, model, optimizer, loss_fn, train_loader, val_loader=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.val_loader = val_loader

    @staticmethod
    def _unpack(batch):
        """Split a batch into ``(x, y)``, wrapping ``x`` in a Tensor if needed."""
        if isinstance(batch, (list, tuple)):
            x, y = batch
        else:
            x, y = batch.x, batch.y
        if not isinstance(x, Tensor):
            x = Tensor(x)
        return x, y

    def fit(self, epochs: int):
        """Train for ``epochs`` passes over ``train_loader``, printing progress.

        Every 50th batch logs its loss and accuracy. At the end of each epoch
        the average loss is printed, plus validation accuracy when
        ``val_loader`` is set. The average loss is NaN for an empty loader.
        """
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            print(f"--- Epoch {epoch+1}/{epochs} ---")

            for batch_idx, batch in enumerate(self.train_loader):
                x, y = self._unpack(batch)

                self.optimizer.zero_grad()
                pred = self.model(x)
                loss = self.loss_fn(pred, y)
                loss.backward()
                self.optimizer.step()

                # Plain float: accumulating and formatting no longer depend on
                # the loss array's exact shape.
                loss_value = float(loss.data)
                total_loss += loss_value
                num_batches += 1

                if batch_idx % 50 == 0:
                    # Accuracy for this batch only.
                    y_np = y.data if isinstance(y, Tensor) else y
                    preds = pred.data.argmax(axis=1)
                    batch_acc = (preds == y_np).mean()
                    # Adjacent f-strings: a single f-string literal cannot span lines.
                    print(f"Batch {batch_idx:3d}: Loss = {loss_value:.4f} | "
                          f"Acc = {batch_acc*100:.2f}%")

            # Guard against an empty loader (previously ZeroDivisionError).
            avg_loss = total_loss / num_batches if num_batches else float("nan")
            print(f"End of Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}", end="")

            if self.val_loader is not None:
                val_acc = self.evaluate()
                print(f" | Val Acc: {val_acc*100:.2f}%")
            else:
                print()

    def evaluate(self, loader=None):
        """Return classification accuracy on ``loader`` (default: ``val_loader``).

        Puts the model in eval mode. Returns 0.0 when there is no loader or
        the loader yields no samples.
        """
        target_loader = loader if loader is not None else self.val_loader
        if target_loader is None:
            return 0.0

        self.model.eval()
        correct = 0
        total = 0

        for batch in target_loader:
            x, y = self._unpack(batch)
            logits = self.model(x)

            y_np = y.data if isinstance(y, Tensor) else y
            preds = logits.data.argmax(axis=1)

            correct += (preds == y_np).sum()
            total += y_np.shape[0]

        if total == 0:
            return 0.0
        return correct / total

11.9 Convolutional nn

Forward

Tip

What is a contiguous array? A contiguous array is one whose elements are stored in a single, unbroken block of memory, in the order in which they are indexed.

The view returned by as_strided is not contiguous: its windows overlap and reuse the same memory. So before reshaping it we call np.ascontiguousarray(), which copies the view into a fresh contiguous array that can be reshaped into the (patches × K·K·C_in) matrix.


class Conv(Function):
    """2-D convolution of NHWC input ``A`` with a ``(K, K, C_in, C_out)`` kernel ``B``.

    Implemented as im2col: a strided view gathers every K x K x C_in window,
    and a single matrix product applies the kernel to all windows at once.
    """

    def __init__(self, stride: int = 1, padding: int = 0):
        self.stride = stride
        self.padding = padding

    def forward(self, A, B):
        p, s = self.padding, self.stride
        if p > 0:
            # Zero-pad the two spatial axes only.
            A = np.pad(A, ((0, 0), (p, p), (p, p), (0, 0)))

        n, h, w, c_in = A.shape
        k, _, _, c_out = B.shape
        sn, sh, sw, sc = A.strides

        h_out = (h - k) // s + 1
        w_out = (w - k) // s + 1

        # windows[b, i, j, di, dj, c] == A[b, i*s + di, j*s + dj, c]
        windows = np.lib.stride_tricks.as_strided(
            A,
            shape=(n, h_out, w_out, k, k, c_in),
            strides=(sn, sh * s, sw * s, sh, sw, sc),
        )

        cols = k * k * c_in
        # The strided view is non-contiguous; materialize it before reshaping.
        patches = np.ascontiguousarray(windows).reshape((-1, cols))
        kernel = B.reshape((cols, c_out))

        return (patches @ kernel).reshape((n, h_out, w_out, c_out))

11.9.1 Flip

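We just call np.flip in the forward pass and flip in the backward pass: flipping the same axes twice restores the original order, so the gradient of a flip is the flipped upstream gradient.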

class Flip(Function):
    """Reverse the order of elements along ``axes`` (every axis when ``None``)."""

    def __init__(self, axes=None):
        self.axes = axes

    def forward(self, a):
        return np.flip(a, self.axes)

    def backward(self, out_grad, node):
        # Flip is its own inverse, so the gradient is the flipped upstream gradient.
        return flip(out_grad, self.axes)

    # The other Functions here (Conv, Dilate) implement forward/backward; keep
    # the old compute/gradient names as aliases for existing callers.
    compute = forward
    gradient = backward


def flip(a, axes=None):
    """Flip tensor ``a`` along ``axes`` (all axes when ``None``)."""
    return Flip(axes)(a)

11.9.2 Dilate

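For each dilated axis we have to find its new size: an axis of length n gets dilation zeros inserted into each of its n − 1 gaps.

Once we know the new sizes, we allocate a zero array and write the original values into it through a strided slice with step dilation + 1.

\[\text{new\_size} = \text{old\_size} + (\text{old\_size} - 1) \times \text{dilation}\]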


class Dilate(Function):
    """Insert ``dilation`` zeros between neighbouring elements along ``axes``."""

    def __init__(self, axes, dilation):
        self.axes = axes
        self.dilation = dilation

    def forward(self, a):
        step = self.dilation + 1
        new_shape = list(a.shape)
        slices = [slice(None)] * a.ndim
        for axis in self.axes:
            n = a.shape[axis]
            # n elements with `dilation` zeros in each of the n - 1 gaps;
            # max() keeps an empty axis at size 0 instead of going negative.
            new_shape[axis] = n + max(n - 1, 0) * self.dilation
            slices[axis] = slice(0, new_shape[axis], step)

        out = np.zeros(tuple(new_shape), dtype=a.dtype)
        out[tuple(slices)] = a
        return out

    def backward(self, out_grad, node):
        # The gradient keeps only the positions that held original values.
        grad = out_grad.data
        slices = [slice(None)] * grad.ndim
        for axis in self.axes:
            slices[axis] = slice(0, None, self.dilation + 1)
        undilated = grad[tuple(slices)]
        # Return a Tensor, like the other backward implementations, so gradients
        # are never a mix of Tensors and raw ndarrays.
        return Tensor(undilated, dtype=undilated.dtype)


def dilate(a, axes, dilation=1):
    """Dilate tensor ``a`` along ``axes`` by inserting ``dilation`` zeros between elements."""
    return Dilate(axes, dilation)(a)

11.9.3 Backward


def backward(self, out_grad, node):
    """Gradients of the convolution with respect to input ``A`` and kernel ``B``.

    ``A`` is (N, H, W, C_in) and ``B`` is (K, K, C_in, C_out), as unpacked by the
    forward pass; ``out_grad`` has the forward output's shape
    (N, H_out, W_out, C_out). Both gradients are computed with further ``conv``
    calls at stride 1. Returns ``(grad_A, grad_B)``.
    """
    # The two forward inputs: the (unpadded) image batch and the kernel.
    A, B = node._inputs
    
    stride = self.stride
    padding = self.padding
    
    # Only K is used below; the other names document the layouts.
    N, H, W, C_in = A.shape
    K, _, _, C_out = B.shape
    
    
    # A strided forward skipped positions: insert (stride - 1) zeros between
    # output-gradient entries along H and W so both gradient convs use stride 1.
    if stride > 1:
        grad_out_dilated = dilate(out_grad, (1, 2), stride - 1)
    else:
        grad_out_dilated = out_grad

    #grad_a 

    # Swap the channel axes, (K, K, C_in, C_out) -> (K, K, C_out, C_in), then
    # flip the kernel spatially: grad_A is a full correlation of the output
    # gradient with the flipped kernel.
    B_transposed = B.transpose((2, 3))
    B_flipped = flip(B_transposed, (0, 1)) 
    
    
    # K - 1 - padding restores the input's spatial size H (W).
    # NOTE(review): when (H + 2*padding - K) is not divisible by stride, the
    # dilated gradient is shorter and grad_A comes out smaller than A — confirm.
    grad_A_padding = K - 1 - padding
    grad_A = conv(grad_out_dilated, B_flipped, stride=1,
                    padding=grad_A_padding)
    
    #grad_B 

    # Treat C_in as the batch axis: (N, H, W, C_in) -> (C_in, H, W, N).
    A_permuted = A.transpose((0, 3))
    # The output gradient acts as the kernel, contracting over N:
    # (N, Ho, Wo, C_out) -> (Ho, N, Wo, C_out) -> (Ho, Wo, N, C_out).
    # NOTE(review): forward reads the kernel size from shape[0] only, so this
    # assumes Ho == Wo (square feature maps) — confirm.
    grad_out_permuted = grad_out_dilated.transpose((0, 1)).transpose((1, 2))
    
    # Result layout is (C_in, K, K, C_out).
    grad_B_intermediate = conv(A_permuted, grad_out_permuted, stride=1,
                            padding=padding)
    
    # (C_in, K, K, C_out) -> (K, C_in, K, C_out) -> (K, K, C_in, C_out), i.e. B's layout.
    grad_B = grad_B_intermediate.transpose((0, 1)).transpose((1, 2))
    
    return grad_A, grad_B

11.9.4 nn.Conv


class Conv(Module):
    """2-D convolution layer: NCHW input in, NCHW output out.

    The weight has shape ``(kernel_size, kernel_size, in_channels, out_channels)``.
    Padding is ``(kernel_size - 1) // 2``, so odd kernels at stride 1 preserve
    the spatial size. Tuple ``kernel_size``/``stride`` values are reduced to
    their first entry.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 bias=True, device=None, dtype="float32"):
        super().__init__()
        kernel_size = kernel_size[0] if isinstance(kernel_size, tuple) else kernel_size
        stride = stride[0] if isinstance(stride, tuple) else stride

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

        receptive = kernel_size * kernel_size
        self.weight = Parameter(init.kaiming_uniform(
            fan_in=in_channels * receptive,
            fan_out=out_channels * receptive,
            shape=(kernel_size, kernel_size, in_channels, out_channels),
            device=device,
            dtype=dtype,
        ))

        self.bias = None
        if bias:
            # Bias drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
            bound = 1.0 / math.sqrt(in_channels * receptive)
            self.bias = Parameter(init.rand(out_channels, low=-bound, high=bound,
                                            device=device, dtype=dtype))

        self.padding = (kernel_size - 1) // 2

    def forward(self, x: Tensor) -> Tensor:
        # NCHW -> NHWC: (N, C, H, W) -> (N, H, C, W) -> (N, H, W, C)
        nhwc = x.transpose((1, 2)).transpose((2, 3))
        out = ops.conv(nhwc, self.weight, self.stride, self.padding)
        if self.bias is not None:
            bias_row = ops.reshape(self.bias, (1, 1, 1, self.out_channels))
            out = out + ops.broadcast_to(bias_row, out.shape)
        # NHWC -> NCHW: (N, H, W, C) -> (N, H, C, W) -> (N, C, H, W)
        return out.transpose((2, 3)).transpose((1, 2))