10  Convolutional NN

Let's take an image with dimensions (64, 64, 6) in (H, W, C) order. Can we train a model on this image so that it predicts the desired output? One obvious and simple approach is to use a Linear layer and train on the flattened pixels.

import torch
import torch.nn as nn

model = nn.Linear(64*64*6, 10)
image = torch.randn(1, 64, 64, 6)
image = image.reshape(1, -1)   # flatten to (1, 24576)
output = model(image)

This is doable but not preferable. Why? Let's think in terms of the number of weights required. Flattening gives

\[ 64 \times 64 \times 6 = 24576 \]

input features, so the Linear layer above needs \(24576 \times 10 = 245760\) weights.

That is a lot of weights for a single small image. What if the image were bigger?
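A quick sanity check on the model above confirms the count (including the 10 bias terms):

sum(p.numel() for p in model.parameters())
# 245770  -> 24576 * 10 weights + 10 biases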

Tip: What is an Image Tensor?

While a standard Tensor is often just a 2D matrix (rows \(\times\) columns), an Image Tensor is 4-dimensional so that batches of visual data can be handled efficiently.

It follows the NHWC convention (which we will use):

  • N (Batch Size): The number of images in the stack.
  • H (Height): The number of pixel rows.
  • W (Width): The number of pixel columns.
  • C (Channels): The depth (e.g., 3 for RGB, 1 for Grayscale).

Some libraries use NCHW!
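Converting between the two layouts is just a transpose; a quick NumPy sketch:

import numpy as np

x_nhwc = np.zeros((8, 64, 64, 3))       # (N, H, W, C)
x_nchw = x_nhwc.transpose(0, 3, 1, 2)   # (N, C, H, W)
x_nchw.shape
# (8, 3, 64, 64)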

Tip: What are Spatial Dimensions?

The dimensions that represent physical space: (H, W).

Why are big weights a problem? You might ask: “Aren’t big models better? Don’t we want more parameters to learn more things?”

When we use a Linear layer on image pixels, we expect it to learn meaningful features of the image. However, this approach has major limitations. Linear layers on images lead to:

  • A parameter count that explodes with image size, as computed above.
  • Loss of spatial structure: flattening destroys the neighborhood relationships between pixels.
  • No weight sharing: a pattern learned at one location must be re-learned at every other location.

10.1 Convolutional Neural Networks

To understand an image's features we can't look at the whole image at once; we look at it little by little. CNNs capture spatial features and use far fewer parameters than the model above. They use a kernel (or filter), a small matrix of shape (K, K). This small matrix moves (convolves) over the spatial dimensions of the image and computes a dot product to produce a single output pixel value.

At every step:

  • Look at the current (K, K) patch of the image.
  • Compute the dot product of this patch with the kernel.
  • Produce a single output pixel.
  • Move the kernel to the next position.

Does it really use fewer weights and less computation? Yes: we reuse the same kernel for every patch, instead of learning separate weights for every pixel. For our (64, 64, 6) image, a single (3, 3) kernel over 6 input channels has only \(3 \times 3 \times 6 = 54\) weights, shared across all positions.

10.2 Hyperparameters

The kernel slides, but can we control how it slides and by how much?

10.2.1 Padding

When a kernel moves over an image, it starts from the top-left corner and slides across. But we can observe that the edge pixels are used less often, and the output gets smaller. If we want to use the edge pixels equally and keep the output the same size, we use padding.

Padding just adds extra zeros around the spatial dimensions.

How much do we pad? Can we decide it? If we want the Output Size to equal the Input Size (assuming Stride = 1), this simple formula tells us the padding:

\[ \text{Padding} = \frac{\text{Kernel Size} - 1}{2} \]

10.2.2 Stride

We mentioned that the kernel moves (convolves) to the next (K, K) patch of the image after each step. But how far should it move?

The default is Stride = 1. The kernel moves 1 pixel at a time. This preserves the most information but keeps the output large.

However, sometimes we want to decrease the spatial dimensions of the image to save memory. We can increase the step size.

  • Stride 1: Moves 1 pixel at a time.
  • Stride 2: Moves 2 pixels at a time, skipping every other position.

Using all the available hyperparameters, we can calculate the output size with the formula below (the division is rounded down when it is not exact):

\[ \text{Output Size} = \left\lfloor \frac{\text{Input} + 2 \times \text{Padding} - \text{Kernel}}{\text{Stride}} \right\rfloor + 1 \]
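The formula is easy to check in code. A small helper (integer division handles the cases where the kernel does not fit evenly):

def conv_output_size(input_size, kernel, padding=0, stride=1):
    return (input_size + 2 * padding - kernel) // stride + 1

conv_output_size(64, kernel=3, padding=1, stride=1)   # 64 (same padding)
conv_output_size(64, kernel=3, padding=0, stride=2)   # 31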

10.3 Simple Convolution

Let's start implementing a simple convolution using plain loops. Before we write the code, we need to address a common point of confusion. We previously said a kernel has size (K, K). However, when working with actual code and RGB images, our tensors are 4-dimensional.

Note: Why 4 dimensions for the kernel too?

You might ask: “I thought the kernel was just a small (K, K) square?” It is, but images also have depth (Channels):

  • Input Depth Match: If the input image has C_in channels, the kernel must also have C_in channels so it can process all of them at once.
  • Output Depth: We learn C_out different kernels, one per output feature map, which is where the fourth dimension comes from.

We will work with:

  • Image: (N, H, W, C_in)
  • Kernel: (K, K, C_in, C_out)
  • Padding = 0
  • Stride = 1

Let's first find the shape of the output after one pass through the function. We know that the shape of the output must be

\[ (N,\; \text{new\_height},\; \text{new\_width},\; C_{\text{out}}) \]

What should the dimensions new_height and new_width be? Let's use our formula from above:

\[ \text{Output Size} = \left\lfloor \frac{\text{Input} + 2 \times \text{Padding} - \text{Kernel}}{\text{Stride}} \right\rfloor + 1 \]

Substituting Padding = 0 and Stride = 1, we get

\[ \begin{aligned} \text{new\_height} &= \text{Height} - \text{Kernel} + 1 \\ \text{new\_width} &= \text{Width} - \text{Kernel} + 1 \end{aligned} \]

import numpy as np

def simple_conv(image, weight):
    N, H, W, C_in = image.shape
    K, _, _, C_out = weight.shape
    H_out = H - K + 1
    W_out = W - K + 1
    out = np.zeros((N, H_out, W_out, C_out))

    # 1. Iterate over each image in the batch
    for n in range(N):
        # 2. Iterate over each output feature map
        for cout in range(C_out):
            # 3. Iterate over each output pixel position
            for x in range(H_out):
                for y in range(W_out):
                    # 4. Dot product of the (K, K, C_in) patch with its kernel
                    for cin in range(C_in):
                        for i in range(K):
                            for j in range(K):
                                out[n, x, y, cout] += (
                                    image[n, x+i, y+j, cin]
                                    * weight[i, j, cin, cout]
                                )
    return out
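A quick shape check with random data (using the function above):

img = np.random.randn(2, 8, 8, 3)   # N=2, H=W=8, C_in=3
w = np.random.randn(3, 3, 3, 4)     # K=3, C_in=3, C_out=4
simple_conv(img, w).shape
# (2, 6, 6, 4)   since 8 - 3 + 1 = 6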
    

The code above does the job, but it is slow. Let's use some magic to make it faster.

10.4 Convolution as Matrix Multiplication

Matrix multiplication is fast and heavily optimized on modern hardware. Can we convert our simple_conv into a matrix multiplication and remove the for loops?

What happens if we transform the function, and why does it get faster? Instead of computing each output pixel one at a time (output-centric), we shift our focus to how each kernel weight affects the entire output at once (kernel-centric).

10.5 Example:

Input Image (3×3)

\[ \mathrm{Image} = \begin{bmatrix} 1 & 4 & 7 \\ 2 & 5 & 8 \\ 3 & 6 & 9 \end{bmatrix} \]

Kernel (2×2)

\[ \mathrm{Kernel} = \begin{bmatrix} 1 & 10 \\ 100 & 1000 \end{bmatrix} \]

Each kernel weight will generate a shifted copy of the image.

Step-by-Step: Shifted Copies

Step 1: Top-Left Weight (1). This weight acts on the top-left of every sliding window, so we take the top-left slice of the image and multiply by 1. Why only this patch? Why are we not including the other numbers? Imagine sliding the kernel over the image while focusing only on the kernel[0,0] entry: which image values does it touch? Instead of handling them one by one, we handle them all at once.

\[ \mathrm{Output} += \begin{bmatrix} 1 & 4 \\ 2 & 5 \end{bmatrix} \times 1 \]

Step 2: Top-Right Weight (10). We shift the image window one column to the right and multiply by 10.

\[ \mathrm{Output} += \begin{bmatrix} 4 & 7 \\ 5 & 8 \end{bmatrix} \times 10 \]


Step 3: Bottom-Left Weight (100). We shift the image window one row down and multiply by 100.

\[ \mathrm{Output} += \begin{bmatrix} 2 & 5 \\ 3 & 6 \end{bmatrix} \times 100 \]


Step 4: Bottom-Right Weight (1000). We shift the image window down and right and multiply by 1000.

\[ \mathrm{Output} += \begin{bmatrix} 5 & 8 \\ 6 & 9 \end{bmatrix} \times 1000 \]


Python Implementation

import numpy as np

image = np.array([
    [1, 4, 7],
    [2, 5, 8],
    [3, 6, 9]
])

kernel = np.array([
    [1, 10],
    [100, 1000]
])

output = np.zeros((2, 2), dtype=int)

output += image[0:2, 0:2] * kernel[0, 0]   # top-left
output += image[0:2, 1:3] * kernel[0, 1]   # top-right
output += image[1:3, 0:2] * kernel[1, 0]   # bottom-left
output += image[1:3, 1:3] * kernel[1, 1]   # bottom-right

output

# array([
#    [5241, 8574],
#    [6352, 9685]
#    ])

Now that we understand this shift in focus, we can remove most of the for loops used in simple_conv and instead use matrix multiplication to improve performance.

def conv_matrix_multiplication(image, weight):
    N, H, W, C_in = image.shape
    K, _, _, C_out = weight.shape    
    H_out = H - K + 1
    W_out = W - K + 1
    out = np.zeros((N, H_out, W_out, C_out))

    for i in range(K):
        for j in range(K):
            # (N, H_out, W_out, C_in) @ (C_in, C_out)
            # -> (N, H_out, W_out, C_out), for all batches and channels at once
            out += image[:, i:i+H_out, j:j+W_out, :] @ weight[i, j]
    return out 
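We can check that the two implementations agree on random data:

img = np.random.randn(2, 8, 8, 3)
w = np.random.randn(3, 3, 3, 4)
np.allclose(simple_conv(img, w), conv_matrix_multiplication(img, w))
# True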

This is a massive improvement over what we had before. But we still have two loops! Can we do something even better?

10.6 The im2col Trick

We take every \(K \times K\) patch that the kernel would see, and we stretch it out into a single row (or column).

1. Flatten the Kernel

First, we take our \((2 \times 2)\) kernel and flatten it into a single column vector.

\[ \text{Kernel} = \begin{bmatrix} 1 & 10 \\ 100 & 1000 \end{bmatrix} \xrightarrow{\text{Reshape}} \begin{bmatrix} 1 \\ 10 \\ 100 \\ 1000 \end{bmatrix} \]

2. Flatten the Input Patches (im2col)

Now, we look at our input image. We need to find the four \(2 \times 2\) patches that the kernel slides over.

\[ \text{Image} = \begin{bmatrix} 1 & 4 & 7 \\ 2 & 5 & 8 \\ 3 & 6 & 9 \end{bmatrix} \] Patch 1 (Top-Left): \[ \begin{bmatrix} 1 & 4 \\ 2 & 5 \end{bmatrix} \] → Stretch to Row: [1, 4, 2, 5]

Patch 2 (Top-Right): \[ \begin{bmatrix} 4 & 7 \\ 5 & 8 \end{bmatrix} \] → Stretch to Row: [4, 7, 5, 8]

Patch 3 (Bottom-Left): \[ \begin{bmatrix} 2 & 5 \\ 3 & 6 \end{bmatrix} \] → Stretch to Row: [2, 5, 3, 6]

Patch 4 (Bottom-Right): \[ \begin{bmatrix} 5 & 8 \\ 6 & 9 \end{bmatrix} \] → Stretch to Row: [5, 8, 6, 9]

3. The Matrix Multiplication

Now we stack these rows into a large Input Matrix (\(X_{\mathrm{col}}\)) and multiply it by our Weight Vector (\(W_{\mathrm{col}}\)).

\[ \underbrace{\begin{bmatrix} 1 & 4 & 2 & 5 \\ 4 & 7 & 5 & 8 \\ 2 & 5 & 3 & 6 \\ 5 & 8 & 6 & 9 \end{bmatrix}}_{\text{Input Matrix (N, 4)}} \times \underbrace{\begin{bmatrix} 1 \\ 10 \\ 100 \\ 1000 \end{bmatrix}}_{\text{Weights (4, 1)}} = \underbrace{\begin{bmatrix} 5241 \\ 8574 \\ 6352 \\ 9685 \end{bmatrix}}_{\text{Result (N, 1)}} \]

Check the math for the first row: \((1 \times 1) + (4 \times 10) + (2 \times 100) + (5 \times 1000) = 5241\).

4. Reshape

Finally, we take our result vector and fold it back into the output shape \((2 \times 2)\).

\[ \begin{bmatrix} 5241 \\ 8574 \\ 6352 \\ 9685 \end{bmatrix} \xrightarrow{\text{Reshape}} \begin{bmatrix} 5241 & 8574 \\ 6352 & 9685 \end{bmatrix} \]

import numpy as np

image = np.array([
    [1, 4, 7],
    [2, 5, 8],
    [3, 6, 9]
])

kernel = np.array([
    [1, 10],
    [100, 1000]
])

# Flatten the kernel into a column vector
# (row-major order, matching the manual example above)
W_col = kernel.flatten().reshape(-1, 1)  # [1, 10, 100, 1000]

patches = np.stack([
    image[0:2, 0:2].flatten(),  # Top-left: [1,4,2,5]
    image[0:2, 1:3].flatten(),  # Top-right: [4,7,5,8]
    image[1:3, 0:2].flatten(),  # Bottom-left: [2,5,3,6]
    image[1:3, 1:3].flatten()   # Bottom-right: [5,8,6,9]
], axis=0)

result_vector = patches @ W_col

# Reshape to 2x2 output
output = result_vector.reshape(2, 2)
print(output)
# [[5241 8574]
#  [6352 9685]]

We are just:

  • Flattening the kernel and the patches.
  • Matrix multiplying.
  • Reshaping.

Before moving to the actual Python implementation, let's first understand a little about how a matrix is laid out in memory.

10.6.1 Matrix and memory

When you create a (3×3) matrix in NumPy, how is it stored? As a grid? Some magic layout? The matrix is stored in a 1D array. It doesn't matter how many dimensions the matrix has; all of them are stored linearly, because the memory in our computer is just a linear tape of bytes. Let's take an example.

A = np.array([
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8]
], dtype=np.int64)

The matrix above is stored simply as [0, 1, 2, 3, 4, 5, 6, 7, 8]. NumPy does something very beautiful under the hood to map the 2D coordinates (row, col) onto this 1D tape.

Tip: What are Strides (Memory)?

Strides tell the computer: “How many bytes do I need to skip to move one step in each dimension?”

In the example above, to move to the next row, we must skip 3 items. To move to the next column, we skip 1 item.

Since we used int64, every number takes up 8 bytes.

  • Step Row: Skip 3 items (3×8 bytes) = 24 bytes.
  • Step Column: Skip 1 item (1×8 bytes) = 8 bytes.

So, A.strides is (24, 8).

Mathematically, NumPy finds the address of \(A[i, j]\) using:

\[ \text{Address} = \text{Start} + (i \times 24) + (j \times 8) \]
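We can check this on the matrix A defined above:

A.strides   # (24, 8): 24 bytes to the next row, 8 bytes to the next column
A.ravel()   # the underlying 1D tape: [0 1 2 3 4 5 6 7 8]
A[1, 2]     # 5, found at Start + 1*24 + 2*8 = Start + 40 bytes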

Now that we know what strides are, we can focus on the implementation. As we said earlier, we need to reshape the images and weights so that they allow a matrix multiplication. That means we need a (num_patches, inner_dim) image @ (inner_dim, C_out) kernel to get (num_patches, C_out), which we can then reshape back into (N, H_out, W_out, C_out).

The question is: how can we reshape our image so that we can do this matrix multiplication? Can we use reshape?

Reshape doesn't work when the total number of elements in the output doesn't match the input. Our output pixel count changes with the kernel size, padding, and stride, so we can't use reshape.

as_strided from NumPy lets us view the image with our desired shape. Under the hood, as_strided doesn't create a new array; it just reinterprets the existing data with different shape and strides.

Hence the view itself uses no extra memory (the final reshape into rows does make a copy, since the overlapping view can't be expressed as contiguous memory).


def conv_im2col(image, weight):
    N, H, W, C_in = image.shape
    K, _, _, C_out = weight.shape
    Ns, Hs, Ws, Cs = image.strides

    inner_dim = K * K * C_in
    # 6D view: (batch, patch row, patch col, row in patch, col in patch, channel)
    A = np.lib.stride_tricks.as_strided(
        image,
        shape=(N, H-K+1, W-K+1, K, K, C_in),
        strides=(Ns, Hs, Ws, Hs, Ws, Cs)
    ).reshape(-1, inner_dim)
    # (-1, inner_dim) @ (inner_dim, C_out)
    out = A @ weight.reshape(-1, C_out)
    return out.reshape(N, H-K+1, W-K+1, C_out)

Two obvious questions:

  • Why is a 6D view created? Axes (1, 2) select which patch we are in (the window's top-left position), while axes (3, 4) select the position inside that patch. Flattening the last three axes gives exactly one im2col row per patch.
  • Why do the strides contain (Hs) and (Ws) twice? Moving one row down inside a patch crosses the same number of bytes as moving the whole window one row down, so both axis pairs reuse the same memory strides.
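A small demo makes the 6D view concrete, and lets us check conv_im2col against our loop version:

img = np.arange(16, dtype=np.float64).reshape(1, 4, 4, 1)
Ns, Hs, Ws, Cs = img.strides

view = np.lib.stride_tricks.as_strided(
    img, shape=(1, 3, 3, 2, 2, 1), strides=(Ns, Hs, Ws, Hs, Ws, Cs))
view[0, 1, 2, :, :, 0]   # the 2x2 patch whose top-left corner is (1, 2)
# [[ 6.  7.]
#  [10. 11.]]

w = np.random.randn(2, 2, 1, 3)
np.allclose(conv_im2col(img, w), simple_conv(img, w))
# True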

We will now focus on adding Conv to our library.

10.7 Conv Forward

FILE : baby/ops.py

Important: Exercise 2.1


class Conv(Function):
    def __init__(self, stride=1, padding=0):
        self.stride = stride
        self.padding = padding

    def forward(self, A, B):
        # your solution
        # If padding > 0, use np.pad on the spatial dimensions only.
        # Then apply the im2col trick.

    def backward(self, out_grad, node):
        # next section

def conv(a, b, stride=1, padding=1):
    return Conv(stride, padding)(a, b)

Note: Use A.pad if needed.

10.8 Conv Backward

Let's implement the backward function for Conv. Before that, we need to understand what we are trying to do in this backward function, what the shape of out_grad is, and what our backward function returns here.

Understanding Gradient Flow in Convolution

When computing gradients for convolution, we need to propagate out_grad back to both inputs:

  • Gradient w.r.t. input (A): How does changing A affect the output?
  • Gradient w.r.t. kernel (B): How does changing B affect the output?

The shape of out_grad is (N, H_out, W_out, C_out). Our backward function must return two tensors:

  • Gradient w.r.t. A with shape (N, H, W, C_in)
  • Gradient w.r.t. B with shape (K, K, C_in, C_out)

But how do we compute these gradients? It turns out that the gradient of a convolution is itself a convolution! However, we need to carefully adjust the inputs for the convolution to make this work.

Why Flip?

In the forward pass, our kernel slides over the input. In the backward pass, we need to figure out how each input pixel contributed to the output gradients.

In the forward pass, the top-left kernel weight multiplies many different input pixels as it slides. In the backward pass, we need to “reverse” this process. The mathematical relationship requires us to flip the kernel: what was top-left becomes bottom-right.

10.8.1 Flip Operation

The flip operation reverses a tensor along specified axes.

Example: 3×3 Flip. Original:

\[ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} \]

After flip(axes=(0,1)):

\[ \begin{bmatrix} 9 & 8 & 7 \\ 6 & 5 & 4 \\ 3 & 2 & 1 \end{bmatrix} \]
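NumPy gives us this operation directly as np.flip, which accepts a tuple of axes; this is worth knowing for the exercise below:

import numpy as np

a = np.arange(1, 10).reshape(3, 3)
np.flip(a, axis=(0, 1))
# [[9 8 7]
#  [6 5 4]
#  [3 2 1]]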

Let's implement Flip.

FILE : baby/ops.py

Important: Exercise 11.2

class Flip(Function):
    def __init__(self, axes=None):
        self.axes = axes 
    def forward(self,a):
        #your solution
    def backward(self,out_grad,node):
        #your solution
    
def flip(a, axes):
    return Flip(axes)(a)

Note: Can we use np.flip for both forward and backward?

10.8.2 Why Dilate?

When stride > 1 in the forward pass, the kernel “skips” pixels. For example, with stride=2, the kernel moves 2 pixels at a time.

In the backward pass, out_grad is spatially smaller because of this striding. To properly propagate gradients back, we need to “undo” this compression by inserting zeros between the gradient values. This matches the spacing the kernel used during the forward pass.

10.8.3 Dilate Operation

The dilate operation inserts zeros between elements along specified axes.

Example: 3×3 Dilate with dilation=1

Original:

\[ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix} \]

After dilate(axes=(0,1), dilation=1):

\[ \begin{bmatrix} 1 & 0 & 2 & 0 & 3 \\ 0 & 0 & 0 & 0 & 0 \\ 4 & 0 & 5 & 0 & 6 \\ 0 & 0 & 0 & 0 & 0 \\ 7 & 0 & 8 & 0 & 9 \end{bmatrix} \]

Notice how each value is now separated by zeros. If the forward pass used stride=2, we need dilation=1 (which means “1 zero between each element”) to restore the proper spacing.

How do we implement a dilate function, then? Assume we have a zero-filled array of the new, dilated shape; we just need to fill the right indices with the original values, skipping positions as needed.

Let's say we have a 1D array:

a = [1,2,3]

We want dilation=1, meaning one 0 between consecutive numbers.

new_a = [0, 0, 0, 0, 0]
# from index 0 to the end, stepping by 2 (dilation + 1)
new_a[0:len(new_a):2] = a

>>> [1, 0, 2, 0, 3]

How did we decide which positions to skip? And how do the old dimensions change, and by how much?

We just need to find the new size of each dilated axis and then use a slice:

\[ \text{new\_size} = \text{old\_size} + (\text{old\_size} - 1) \times \text{dilation} \]
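The same slicing idea in 2D, reproducing the 3×3 example above (a concrete instance of what Dilate.forward computes, not the full general solution):

a = np.arange(1, 10).reshape(3, 3)

new_size = 3 + (3 - 1) * 1          # 5, from the formula above
out = np.zeros((new_size, new_size), dtype=a.dtype)
out[::2, ::2] = a                   # step = dilation + 1
out
# [[1 0 2 0 3]
#  [0 0 0 0 0]
#  [4 0 5 0 6]
#  [0 0 0 0 0]
#  [7 0 8 0 9]]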

Let's implement Dilate.

FILE : baby/ops.py

Important: Exercise 11.3

class Dilate(Function):
    def __init__(self,axes, dilation ):
        self.axes = axes 
        self.dilation = dilation

    def forward(self, a):
        new_shape = list(a.shape)
        slices = [slice(None)] * a.ndim 

        # your solution here:
        # loop over each axis in self.axes, compute its new size
        # (using the formula above), and build the matching slice,
        # just like the 1D example.
        out = np.zeros(tuple(new_shape),dtype=a.dtype)
        out[tuple(slices)] = a 
        return out 
    def backward(self,out_grad, node):
        slices = [slice(None)] * out_grad.ndim
        for axis in self.axes:
            slices[axis] = slice(0, None, self.dilation + 1)
        return out_grad.data[tuple(slices)] 
        

def dilate(a, axes, dilation=1):
    return Dilate(axes, dilation)(a)

Note: There is no np.dilate; you must implement it yourself.

Let's now implement conv.backward. We have two things to return: grad_a and grad_b.

Because of the convolution, out_grad is spatially smaller than the input. And whenever stride > 1, we must first dilate.

The question then becomes: which axes do we dilate, and what should dilation= be?

The dilation should be just (stride - 1): if stride=2, we need to insert 1 zero between each pair of values. We dilate axes (1, 2), the spatial dimensions only.

Finding grad_a

To find the gradient with respect to our input \(A\), we need to answer: “How did each pixel in \(A\) contribute to the final error?”

We need to bring out_grad back to the original shape of A. We already have a dilated out_grad; we just need to convolve it with a flipped kernel.

grad_A = conv(dilated_out_grad, flipped_kernel, padding=?)

What should the padding be?

Lets suppose :

  • Input (\(A\)): 7 pixels
  • Kernel (\(K\)): 3 pixels
  • padding : 0
  • Stride: 1
  • Forward Pass: \((7 + 0 - 3) + 1 = \mathbf{5}\) pixels. (We lost 2 pixels).

Now our out_grad is 5 pixels wide. We need to do something to it so that the result becomes 7 pixels again.

To get from the 5-pixel out_grad back to a 7-pixel grad_a, we use:

\(P_{backward} = (K-1) - P_{forward}\).

  • Calculate \(P_{back}\): \((3 - 1) - 0 = \mathbf{2}\).
  • Adding padding: \(5 + (2 \times 2) = \mathbf{9}\).
  • Perform convolution: \((9 - 3) + 1 = \mathbf{7}\). Success!
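The conv_output_size helper from earlier confirms the arithmetic:

conv_output_size(5, kernel=3, padding=2, stride=1)
# 7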

In our case

  • A: (N, H, W, C_in)
  • B: (K, K, C_in, C_out)
  • out_grad: (N, H_out, W_out, C_out)

If you look carefully, to perform this convolution we need the channel dimensions to match. That means we must first transpose the kernel and then flip it.

\[(N, H_{out}, W_{out}, \mathbf{C_{out}}) \;@\; \text{flipped}(K, K, \mathbf{C_{out}}, \mathbf{C_{in}}) \to (N, H, W, \mathbf{C_{in}})\]

The steps to find grad_A are:

  • Dilate out_grad using dilation = stride - 1.
  • Transpose the kernel B on axes \((2, 3)\) to swap \(C_{in}\) and \(C_{out}\).
  • Flip the transposed kernel on \((0, 1)\).
  • Calculate \(P_{back} = (K - 1) - P_{forward}\).
  • Convolve the dilated gradient with the flipped/transposed kernel using padding= \(P_{back}\)

Finding grad_b

How do we find exactly which weight in our \(K \times K\) kernel caused the error? We treat the input \(A\) as our new “image” and the dilated out_grad as our “kernel.”

We need to get to \((K, K, C_{in}, C_{out})\). To do this, we rearrange our tensors so that the batch dimension (\(N\)) acts as the internal channel that gets summed out.

\[(C_{in}, H, W, \mathbf{N}) \;@\; (H_{out\_dil}, W_{out\_dil}, \mathbf{N}, C_{out}) \to (C_{in}, K, K, C_{out})\]

In the forward pass, we sum over channels. In the backward pass for the weights, we sum over the batch. By swapping the batch and channel dimensions, we can “trick” a standard convolution into doing the weight summation for us.

The steps to find grad_b (a sketch follows the list):

  • Transpose A: (N, H, W, C_in) \(\to\) (C_in, H, W, N)
  • Transpose the dilated out_grad: (N, H_out, W_out, C_out) \(\to\) (H_out, W_out, N, C_out)
  • Convolve: Use stride=1 and forward_padding.
  • Transpose the result back to (K, K, C_in, C_out).
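Putting both recipes together, here is a rough sketch of the backward logic. This is only a sketch: permute is a hypothetical helper that reorders all four axes at once (your library's transpose may only swap two axes at a time, in which case chain the swaps), and the variables match the exercise skeleton below.

# Hypothetical sketch -- mirrors the step lists above, not a drop-in solution.
K = B.shape[0]

# grad_A: dilate, transpose + flip the kernel, convolve with backward padding.
grad = dilate(out_grad, axes=(1, 2), dilation=self.stride - 1)
kernel = flip(permute(B, (0, 1, 3, 2)), (0, 1))       # (K, K, C_out, C_in)
grad_A = conv(grad, kernel, stride=1, padding=(K - 1) - self.padding)

# grad_B: treat A as the image and the dilated out_grad as the kernel.
A_t = permute(A, (3, 1, 2, 0))             # (C_in, H, W, N)
g_t = permute(grad, (1, 2, 0, 3))          # (H_out_dil, W_out_dil, N, C_out)
grad_B = conv(A_t, g_t, stride=1, padding=self.padding)   # (C_in, K, K, C_out)
grad_B = permute(grad_B, (1, 2, 0, 3))     # back to (K, K, C_in, C_out)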

FILE : baby/ops.py

Now we can use the above helper functions to do the magic.

Important: Exercise 2.2


class Conv(Function):
    def __init__(self, stride=1, padding=0):
        self.stride = stride
        self.padding = padding

    def forward(self, A, B):
        #your solution

    def backward(self, out_grad, node):
        #your solution 
        #hint: use flip and dilate operations

        A, B = node.inputs
        
        stride = self.stride
        padding = self.padding
        
        N, H, W, C_in = A.shape
        K, _, _, C_out = B.shape

        return grad_A, grad_B

        
def conv(a, b, stride=1, padding=1):
    return Conv(stride, padding)(a, b)

Note: Use the flip and dilate operations along with convolution to compute the gradients. Use conv for both grads, and be careful to follow the shapes.

10.9 nn.Conv

This will be pretty simple to implement: we just need to use the ops.conv we implemented earlier, and that's it.

FILE : baby/nn.py

Many image datasets return data in NCHW, so we must first transpose to NHWC, call the ops.conv written above, and then transpose back to the original form.

Important: Exercise 10.2



class Conv(Module):
    """
    Most datasets use the NCHW format, but our ops.conv expects NHWC,
    so be careful to transpose.
    Accepts inputs in NCHW format; outputs are also in NCHW format.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True,
                 device=None, dtype="float32"):
        super().__init__()
        if isinstance(kernel_size, tuple):
            kernel_size = kernel_size[0]
        if isinstance(stride, tuple):
            stride = stride[0]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.weight = Parameter(init.kaiming_uniform(
            fan_in=in_channels * kernel_size * kernel_size, 
            fan_out=out_channels * kernel_size * kernel_size, 
            shape=(self.kernel_size, self.kernel_size, self.in_channels,
                 self.out_channels), 
            device=device, 
            dtype=dtype
        ))        
        if bias:
            self.bias = Parameter(init.zeros(out_channels))
        else:
            self.bias = None
        self.padding = (self.kernel_size - 1) // 2
    def forward(self, x: Tensor) -> Tensor:
        # transpose NCHW -> NHWC
        # call conv
        # add bias (reshape + broadcast_to)
        # transpose back to NCHW
        # return

Note: Most datasets use the NCHW format and our ops.conv uses NHWC, so just transpose the input images and we are good to go.

You should use reshape and broadcast_to for the bias.
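Once the forward pass is implemented, usage should look roughly like this (a sketch assuming the library's Tensor wraps a NumPy array and numpy is imported as np):

conv = Conv(in_channels=3, out_channels=8, kernel_size=3, stride=1)
x = Tensor(np.random.randn(2, 3, 32, 32).astype("float32"))   # NCHW
y = conv(x)
y.shape
# (2, 8, 32, 32): spatial size preserved by the built-in "same" padding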