8 Model Persistence

After training your model, you turn off your computer and go to sleep or take a walk. When you turn the computer back on and look for your model, it no longer exists: it only lived in memory while the script was running.

So we need to store the state of our model. When storing it, should we store every detail of the model, or only the necessary parts? How do we decide what is important and what is not?

When loading a model, we expect to get back the exact state of the model we trained. This assumes we already have the code that defines the architecture. So what should we save? What actually changes in a model during training? The parameters do.
So only the parameters need to be stored.

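So we need a way to:

  • save a model’s parameters to a file, and
  • load them back into a model with the same architecture.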

Note

This is what our folder structure currently looks like. In this chapter we will work inside babygrad/nn.py.


project/
├─ .venv/                    
├─ babygrad/
│   ├─ __init__.py          
│   ├─ data.py              
│   ├─ init.py              
│   ├─ ops.py
│   ├─ tensor.py
│   ├─ nn.py
│   └─ optim.py            
├─ examples/
│   └─ simple_mnist.py
└─ tests/

8.1 Save Model

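The saved state will be a dictionary that maps each parameter’s name to its values, something like:

{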
  'layer1.weight': ([[0.12, -0.53, ... ]]),
  'layer1.bias': ([0.01, 0.04, ...]),
}

Inside a model we find different types of objects:

  • Tensor/Parameter objects.
  • Linear/Module objects (inside them we find Tensor/Parameter objects).
  • Sequential/List/Tuple objects.
# Ideally, we want to save our trained model like this:
model = Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
# ... train ...
model.save("my_model.pkl")

As mentioned above, we need to handle these three cases (Tensor, Module, List) and return a dictionary in which each entry has:

  • Key: a string representing the flattened path to the parameter (e.g., “layer1.weight” or “layers.0.bias”). We use dot notation to join the names as we go deeper into the recursion.
  • Value: the raw NumPy array (accessed via .data). We don’t want to store the Tensor object, which carries computation-graph details, only Tensor.data.
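To make the naming scheme concrete, here is a minimal sketch that flattens a nested structure into dot-separated keys. It does not use babygrad at all: plain dicts stand in for modules, a list stands in for Sequential’s children, and NumPy arrays stand in for tensor.data. The `flatten` helper and the `model_like` structure are hypothetical, chosen only to mirror the three cases above.

```python
import numpy as np


def flatten(obj, prefix=""):
    """Flatten nested dicts/lists of arrays into {"dotted.path": array}."""
    out = {}
    if isinstance(obj, np.ndarray):
        # Leaf: a raw array, stored under the path built so far
        out[prefix] = obj
    elif isinstance(obj, dict):
        # Named children (like module attributes): extend the path with the name
        for name, child in obj.items():
            out.update(flatten(child, f"{prefix}.{name}" if prefix else name))
    elif isinstance(obj, (list, tuple)):
        # Positional children (like Sequential): extend the path with the index
        for i, child in enumerate(obj):
            out.update(flatten(child, f"{prefix}.{i}" if prefix else str(i)))
    return out


# Hypothetical stand-in for Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
model_like = {
    "layers": [
        {"weight": np.zeros((10, 20)), "bias": np.zeros(20)},
        {},  # ReLU: no parameters, so no keys
        {"weight": np.zeros((20, 1)), "bias": np.zeros(1)},
    ]
}

for key, arr in flatten(model_like).items():
    print(key, arr.shape)
# layers.0.weight (10, 20)
# layers.0.bias (20,)
# layers.2.weight (20, 1)
# layers.2.bias (1,)
```

Notice that index 1 (the ReLU) produces no keys because it has no parameters, so gaps in the indices are expected.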

File: babygrad/nn.py

Important

Exercise 8.1

Let’s write state_dict inside the Module class.

import pickle

class Module:

    def state_dict(self):
        """
        Returns a dictionary containing the whole state of the module.
        Recursively retrieves parameters from sub-modules.
        Output: dictionary {layer_name: tensor.data}
        """
        # Initialize a dict
        # Iterate over self.__dict__
        # If value is a Parameter/Tensor, save dict[key] = value.data
        # If value is a Module, recurse (value.state_dict()),
        #     iterate over the result (child_key, v) and
        #     store dict[f"{key}.{child_key}"] = v
        # If value is a list/tuple, recurse into each Module item at index i
        #     and store dict[f"{key}.{i}.{child_key}"] = v
        # Return the dict

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump(self.state_dict(), f)
Note

Don’t store the Tensor object; store only tensor.data.

Can you use recursion here? If you find a Module, call state_dict() on it!
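If you want to compare your answer, here is one possible sketch of Exercise 8.1. To keep it self-contained it defines tiny stand-ins for Tensor, Parameter, Linear, ReLU and Sequential. These are hypothetical, not babygrad’s real classes: there is no autograd, and Sequential is assumed to keep its children in a list attribute named layers. Only state_dict and save follow the structure of the exercise.

```python
import pickle

import numpy as np


class Tensor:
    """Hypothetical stand-in: just wraps a NumPy array in .data (no autograd)."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)


class Parameter(Tensor):
    pass


class Module:
    def state_dict(self):
        state = {}
        for key, value in self.__dict__.items():
            if isinstance(value, Tensor):
                # Leaf: store the raw array, not the Tensor object
                state[key] = value.data
            elif isinstance(value, Module):
                # Sub-module: recurse, then prefix its keys with our attribute name
                for child_key, v in value.state_dict().items():
                    state[f"{key}.{child_key}"] = v
            elif isinstance(value, (list, tuple)):
                # Container: the index becomes part of the path
                for i, item in enumerate(value):
                    if isinstance(item, Module):
                        for child_key, v in item.state_dict().items():
                            state[f"{key}.{i}.{child_key}"] = v
        return state

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)


class Linear(Module):
    def __init__(self, in_features, out_features):
        self.weight = Parameter(np.random.randn(in_features, out_features))
        self.bias = Parameter(np.zeros(out_features))


class ReLU(Module):
    pass


class Sequential(Module):
    def __init__(self, *modules):
        self.layers = list(modules)  # assumed attribute name


model = Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
sd = model.state_dict()
print(list(sd))
# ['layers.0.weight', 'layers.0.bias', 'layers.2.weight', 'layers.2.bias']
```

Because this Sequential stores its children in a list, its keys pick up the list index (layers.0, layers.2), which is the list/tuple case of the naming scheme.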

8.2 Load Model

To load a model, we first build a fresh model with the same architecture and then copy the saved values into its parameters. We can’t blindly copy everything from the saved dictionary: every saved key must correspond to a parameter in the new model, and every saved array must have the same shape as that parameter.

The architecture must match.

If you load a model state into a different architecture, the keys or shapes won’t line up and loading will fail.
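A quick way to see what “matching architecture” means in practice is to compare the two dictionaries before copying anything. The sketch below assumes both sides are plain {name: NumPy array} dictionaries, such as the output of state_dict(); the check_compatible helper is hypothetical and reports missing keys, unexpected keys and shape mismatches.

```python
import numpy as np


def check_compatible(model_state, saved_state):
    """Return a sorted list of problems; an empty list means loading can proceed."""
    problems = []
    for key in model_state.keys() - saved_state.keys():
        problems.append(f"missing key: {key}")
    for key in saved_state.keys() - model_state.keys():
        problems.append(f"unexpected key: {key}")
    for key in model_state.keys() & saved_state.keys():
        if model_state[key].shape != saved_state[key].shape:
            problems.append(
                f"shape mismatch for {key}: "
                f"model {model_state[key].shape} vs saved {saved_state[key].shape}"
            )
    return sorted(problems)


# Saved from a first layer with 20 hidden units ...
saved = {"layers.0.weight": np.zeros((10, 20)), "layers.0.bias": np.zeros(20)}
# ... but the new model was built with 32 hidden units
fresh = {"layers.0.weight": np.zeros((10, 32)), "layers.0.bias": np.zeros(32)}

for problem in check_compatible(fresh, saved):
    print(problem)
# shape mismatch for layers.0.bias: model (32,) vs saved (20,)
# shape mismatch for layers.0.weight: model (10, 32) vs saved (10, 20)
```

An empty list means every saved array has a home of the right shape in the new model.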

  • For a Tensor/Parameter we can directly assign value.data = state_dict[key].
  • For a Module, remember that we saved its entries under keys of the form key.child_key inside state_dict. The sub-module itself doesn’t know about that prefix: it only expects keys of the form child_key, i.e. state_dict[child_key] = value.data.
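The first bullet assigns to value.data rather than replacing value. The sketch below shows why that matters, assuming (as is common) that the optimizer was built from a list of the model’s parameter objects. The Tensor and Linear classes here are hypothetical stand-ins, and tracked plays the role of the optimizer’s parameter list.

```python
import numpy as np


class Tensor:
    """Hypothetical stand-in: just wraps a NumPy array in .data."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)


class Linear:
    """Hypothetical one-parameter layer."""

    def __init__(self):
        self.weight = Tensor(np.zeros((2, 2)))


layer = Linear()
tracked = [layer.weight]  # plays the role of the optimizer's parameter list
saved = np.ones((2, 2), dtype=np.float32)

# In place: the tracked object and the layer's weight are still the same object
layer.weight.data = saved
print(tracked[0] is layer.weight, float(tracked[0].data.sum()))  # True 4.0

# Rebinding: the layer now points at a new Tensor the tracker never sees
layer.weight = Tensor(saved * 2)
print(tracked[0] is layer.weight, float(tracked[0].data.sum()))  # False 4.0
```

After the rebinding, anything holding the old object keeps updating a Tensor the model no longer uses.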

So we first filter state_dict down to the keys that start with the prefix (for example “layer1.”), remove that prefix, and then recurse with the resulting child_state_dict.
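The filtering step can be sketched as a single dictionary comprehension. The child_state_dict helper below is hypothetical (the exercise leaves the name up to you). It matches on prefix + "." rather than on the bare prefix, so the prefix "layers.1" does not accidentally pick up keys that belong to "layers.10".

```python
import numpy as np


def child_state_dict(state_dict, prefix):
    """Keep entries under `prefix.` and strip that prefix from their keys."""
    start = prefix + "."
    return {k[len(start):]: v for k, v in state_dict.items() if k.startswith(start)}


state = {
    "encoder.weight": np.ones((4, 3)),
    "encoder.bias": np.zeros(3),
    "layers.1.bias": np.zeros(2),
    "layers.10.bias": np.zeros(5),
}

print(sorted(child_state_dict(state, "encoder")))          # ['bias', 'weight']
print(child_state_dict(state, "layers.1")["bias"].shape)   # (2,)
print(child_state_dict(state, "layers.10")["bias"].shape)  # (5,)
```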

File: babygrad/nn.py

Important

Exercise 8.2

Let’s write load_state_dict inside the Module class.


import pickle 
class Module:
    def load_state_dict(self, state_dict) -> None:
        """
        Loads the parameters stored in state_dict into this module, in place.
        Recursively passes the relevant entries down to sub-modules.
        Output: Nothing
        """
        # Iterate over self.__dict__

        # Case 1: Value is Parameter/Tensor
        #   - Check if key exists in state_dict (safety check)
        #   - Check if shapes match
        #   - Assign value.data = state_dict[key]

        # Case 2: Value is Module
        #   - Filter state_dict for keys starting with f"{key}."
        #   - Remove the prefix from the keys
        #   - Call value.load_state_dict(child_state_dict)

        # Case 3: Value is List/Tuple
        #   - Iterate over index i and item
        #   - Filter keys starting with f"{key}.{i}."
        #   - Remove the prefix
        #   - If item is a Module, call item.load_state_dict(child_state_dict)

        
    def load(self, filename):
        with open(filename,'rb') as f:
            self.load_state_dict(pickle.load(f))
Note

Can you use recursion here?

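Just like saving, if you find a Module inside your class, you must recurse! Call value.load_state_dict(child_state_dict) to pass the data down.

Do not replace the Tensor object itself! Just update the numbers inside it (value.data = state_dict[key]). Other code, such as an optimizer, may still hold references to the existing Tensor objects and would never see a replacement object.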
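For comparison, here is one possible sketch of Exercise 8.2, bundled with a state_dict like the one from Exercise 8.1 so that a full save/load round trip runs on its own. As before, Tensor, Parameter, Linear, ReLU and Sequential are hypothetical minimal stand-ins, _strip is a hypothetical helper for the prefix filtering, and the file is written to a temporary directory.

```python
import os
import pickle
import tempfile

import numpy as np


class Tensor:
    """Hypothetical stand-in: just wraps a NumPy array in .data (no autograd)."""
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)


class Parameter(Tensor):
    pass


def _strip(state_dict, prefix):
    """Hypothetical helper: keep keys under `prefix.` and drop that prefix."""
    start = prefix + "."
    return {k[len(start):]: v for k, v in state_dict.items() if k.startswith(start)}


class Module:
    def state_dict(self):
        state = {}
        for key, value in self.__dict__.items():
            if isinstance(value, Tensor):
                state[key] = value.data
            elif isinstance(value, Module):
                for child_key, v in value.state_dict().items():
                    state[f"{key}.{child_key}"] = v
            elif isinstance(value, (list, tuple)):
                for i, item in enumerate(value):
                    if isinstance(item, Module):
                        for child_key, v in item.state_dict().items():
                            state[f"{key}.{i}.{child_key}"] = v
        return state

    def load_state_dict(self, state_dict):
        for key, value in self.__dict__.items():
            if isinstance(value, Tensor):
                # Case 1: check presence and shape, then update the data in place
                if key not in state_dict:
                    raise KeyError(f"missing key in state_dict: {key}")
                saved = state_dict[key]
                if value.data.shape != saved.shape:
                    raise ValueError(f"shape mismatch for {key}: {value.data.shape} vs {saved.shape}")
                value.data = saved
            elif isinstance(value, Module):
                # Case 2: hand the sub-module only its own keys, prefix removed
                value.load_state_dict(_strip(state_dict, key))
            elif isinstance(value, (list, tuple)):
                # Case 3: the index is part of the prefix
                for i, item in enumerate(value):
                    if isinstance(item, Module):
                        item.load_state_dict(_strip(state_dict, f"{key}.{i}"))

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def load(self, filename):
        with open(filename, "rb") as f:
            self.load_state_dict(pickle.load(f))


class Linear(Module):
    def __init__(self, in_features, out_features):
        self.weight = Parameter(np.random.randn(in_features, out_features))
        self.bias = Parameter(np.zeros(out_features))


class ReLU(Module):
    pass


class Sequential(Module):
    def __init__(self, *modules):
        self.layers = list(modules)  # assumed attribute name


# Save one model, then load it into a freshly initialized model of the same shape
trained = Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
fresh = Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
path = os.path.join(tempfile.mkdtemp(), "my_model.pkl")
trained.save(path)
fresh.load(path)

same = all(
    np.array_equal(a, b)
    for a, b in zip(trained.state_dict().values(), fresh.state_dict().values())
)
print(same)  # True
```

Loading the same file into a model built with a different hidden size, e.g. Sequential(Linear(10, 32), ReLU(), Linear(32, 1)), stops at the shape check with a ValueError instead of silently copying mismatched arrays.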

Warning

Our current implementation saves the model weights, which is perfect for inference (using the model to make predictions) or fine-tuning (training on a new dataset).

However, if you use this to pause and resume the same training run, be aware of a small catch:

We are not saving the Optimizer state.

Optimizers like Adam or SGD with momentum keep track of their own internal history (running averages of gradients). When you restart the script, the optimizer starts fresh with that history wiped, so the model might take a few batches to “warm up” and stabilize again.
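To see what losing that history means, here is a toy, NumPy-only illustration with a single weight and a constant gradient. MomentumSGD is a hypothetical minimal optimizer, not the class in babygrad/optim.py; its velocity array is the internal history. After three steps, continuing with the same optimizer and “restarting” with a fresh one from the same weights give different next updates, because the fresh optimizer has no velocity yet.

```python
import numpy as np


class MomentumSGD:
    """Hypothetical minimal SGD with momentum for a single NumPy array."""

    def __init__(self, lr=0.1, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.velocity = None  # the optimizer's internal history, empty at start

    def step(self, w, grad):
        if self.velocity is None:
            self.velocity = np.zeros_like(w)
        self.velocity = self.momentum * self.velocity + grad
        return w - self.lr * self.velocity


grad = np.array([1.0])  # pretend the gradient is constant
w = np.array([0.0])

opt = MomentumSGD()
for _ in range(3):
    w = opt.step(w, grad)

# Keep training with the same optimizer: velocity has built up
w_continued = opt.step(w, grad)
# "Restart": same weights (as if loaded from disk), but a brand-new optimizer
w_restarted = MomentumSGD().step(w, grad)

print(w_continued, w_restarted)  # the continued step moves further
```

Resuming a run exactly would therefore also require saving the optimizer’s buffers (here, velocity) alongside the model’s state_dict.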