9  Trainer

For any model we train, the training loop ends up looking something like this:

model = MyModel()
optimizer = MyOptimizer(model.parameters())
loss_fn = MyLossFn()
dataloader = MyDataLoader()

for epoch in range(10):
    for batch in dataloader:
        optimizer.zero_grad()          # reset accumulated gradients
        pred = model(batch.x)          # forward pass
        loss = loss_fn(pred, batch.y)  # compute the loss
        loss.backward()                # backward pass
        optimizer.step()               # update the parameters

This works fine, but it quickly becomes repetitive: every model and every experiment requires the same boilerplate. To keep our workflow clean, we can wrap this logic inside a Trainer class.

Note

This is what our folder structure currently looks like. In this chapter we will work inside babygrad/trainer.py.


project/
├─ .venv/                    
├─ babygrad/
│   ├─ trainer.py
│   ├─ __init__.py          
│   ├─ data.py              
│   ├─ init.py              
│   ├─ ops.py
│   ├─ tensor.py
│   ├─ nn.py
│   └─ optim.py            
├─ examples/
│   └─ simple_mnist.py
└─ tests/

What does the Trainer class need? Everything that appears in the training loop above: the model, the optimizer, the loss function, and the data loaders.

trainer = Trainer(model, optimizer, loss_fn, train_loader, val_loader=test_loader)

print("Starting Training...")
trainer.fit(EPOCHS)
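
After training, we can check how the model does on the held-out data. A minimal usage sketch, assuming the evaluate method we write below returns the accuracy as a float:

accuracy = trainer.evaluate()   # defaults to the val_loader passed to the constructor
print(f"Validation accuracy: {accuracy:.2%}")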

File: babygrad/trainer.py

Exercise 9.1

Let's write the fit method inside the Trainer class.

from babygrad.tensor import Tensor

class Trainer:
    def __init__(self, model, optimizer, loss_fn, train_loader,
                 val_loader=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.val_loader = val_loader

    def fit(self, epochs: int):
        """
        Runs the training loop for the specified number of epochs.
        """
        for epoch in range(epochs):
            self.model.train() # Set mode to training

            total_loss = 0
            
            # Your solution here:
            # 1. Iterate over self.train_loader
            # 2. Get batch data (x, y)
            # 3. Zero Gradients
            # 4. Forward Pass
            # 5. Compute Loss
            # 6. Backward Pass
            # 7. Optimizer Step
            
            print(f"Epoch {epoch+1} Done.")

    def evaluate(self, loader=None):
        """
        Calculates accuracy on the validation set.
        """
        target_loader = loader if loader is not None else self.val_loader
        if target_loader is None:
            return 0.0
        # Hint:
        # 1. Set the model to eval mode: self.model.eval()
        # 2. Loop over target_loader
        # 3. Forward pass only (no backward pass)
        # 4. Compare predictions to the true labels (use argmax(axis=1))
        # 5. Sum the correct predictions and divide by the total to get the accuracy.
        pass
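
If you get stuck, here is one possible solution sketch. It is not the only correct implementation: it assumes the data loaders yield (x, y) pairs, that loss.data holds a scalar value, and that predictions expose a NumPy-style .data array for argmax(axis=1). Adjust these details to match your own Tensor and DataLoader implementations.

    # One possible solution, written as methods inside the Trainer class.
    def fit(self, epochs: int):
        for epoch in range(epochs):
            self.model.train()                      # set mode to training
            total_loss = 0.0
            num_batches = 0
            for x, y in self.train_loader:          # assumes the loader yields (x, y) pairs
                self.optimizer.zero_grad()          # zero gradients
                pred = self.model(x)                # forward pass
                loss = self.loss_fn(pred, y)        # compute loss
                loss.backward()                     # backward pass
                self.optimizer.step()               # optimizer step
                total_loss += float(loss.data)      # assumes loss.data is a scalar value
                num_batches += 1
            avg_loss = total_loss / max(num_batches, 1)
            print(f"Epoch {epoch+1} Done. Avg loss: {avg_loss:.4f}")

    def evaluate(self, loader=None):
        target_loader = loader if loader is not None else self.val_loader
        if target_loader is None:
            return 0.0
        self.model.eval()                           # set mode to evaluation
        correct, total = 0, 0
        for x, y in target_loader:                  # forward pass only, no backward
            pred = self.model(x)
            pred_labels = pred.data.argmax(axis=1)  # assumes .data is a NumPy array
            true_labels = getattr(y, "data", y)     # works whether y is a Tensor or a raw array
            correct += int((pred_labels == true_labels).sum())
            total += len(true_labels)
        return correct / total                      # fraction of correct predictions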