Created with Solve It.

Note: 154

Learning RoPE (Rotary Positional Embeddings)

RoPE (Rotary Positional Embeddings) encodes relative position. GPT-2 uses learned absolute positional embeddings, which are added to the token embeddings. RoPE works differently: it is applied to the queries and keys inside the attention mechanism.

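RoPE uses rotation matrices to encode positional information. Before the attention scores are computed, each pair of dimensions in the query and key vectors is rotated by an angle proportional to the token's position. As a result, the dot product between a query at position m and a key at position n depends only on the offset m - n, which is how the model gets relative positional information.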
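A minimal sketch of that relative-position property using a single 2D pair, assuming a made-up rotation speed of 0.5 radians per position (not a value from Llama). Rotating the query by m·freq and the key by n·freq gives the same dot product whenever m - n is the same, because R(m·freq)ᵀ R(n·freq) = R((n - m)·freq).

```python
import math

import torch


def rotate_2d(v: torch.Tensor, angle: float) -> torch.Tensor:
    """Rotate the 2D vector `v` counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[c, -s], [s, c]], dtype=v.dtype)
    return rot @ v


torch.manual_seed(0)
q = torch.randn(2, dtype=torch.float64)  # one query pair
k = torch.randn(2, dtype=torch.float64)  # one key pair
freq = 0.5  # hypothetical rotation speed, radians per position

# Query at position m is rotated by m * freq, key at position n by n * freq.
score_a = rotate_2d(q, 5 * freq) @ rotate_2d(k, 2 * freq)    # m=5,  n=2  (offset 3)
score_b = rotate_2d(q, 13 * freq) @ rotate_2d(k, 10 * freq)  # m=13, n=10 (offset 3)
score_c = rotate_2d(q, 6 * freq) @ rotate_2d(k, 2 * freq)    # m=6,  n=2  (offset 4)
print(score_a.item(), score_b.item(), score_c.item())  # a and b match; c generally differs
```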

Let's start by implementing the RoPE algorithm.

Code: 24 (4:57:20p)

import torch
import torch.nn as nn
from typing import Optional, Tuple

Code: 39 (4:57:20p)

# Settings
batch_size = 1
context_len = 3
num_heads = 2
head_dim = 16

Note: 60

Let's create queries and keys as they look right after being split by num_heads and head_dim, so they have the shape (batch_size, context_len, num_heads, head_dim).

Code: 64 (4:57:20p)

torch.manual_seed(123)
queries = torch.randn(batch_size, context_len, num_heads, head_dim)
keys = torch.randn(batch_size, context_len, num_heads, head_dim)
queries.shape, keys.shape

Output: 60

(torch.Size([1, 3, 2, 16]), torch.Size([1, 3, 2, 16]))

Note: 66

Llama

Below is the code from the Llama 2 implementation. Let's try to understand it.

Note: 10

precompute_freqs_cis

Code: 391 (4:57:20p)

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

Code: 81 (4:57:20p)

dim = head_dim
end = context_len

theta = 10000.0

freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs

Output: 129

tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])

Code: 30 (4:57:20p)

freq_cis = precompute_freqs_cis(dim, end, theta)
freq_cis.shape

Output: 30

torch.Size([3, 8])

Note: 7

Breaking down into smaller steps

Code: 15 (4:57:20p)

torch.arange(0, dim, 2)

Output: 63

tensor([ 0,  2,  4,  6,  8, 10, 12, 14])

Code: 28 (4:57:20p)

torch.arange(0, dim, 2)[: (dim // 2)].float()

Output: 64

tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14.])

Prompt: 31

Why does it have [: (dim // 2)]? What is the purpose of this?

Assistant: 369

The [: (dim // 2)] slice guarantees exactly dim // 2 frequencies, one per rotation pair. For an even dim it changes nothing; it only matters if dim is odd.

Looking at your example:

  • dim = 16 (head_dim)
  • torch.arange(0, dim, 2) gives [0, 2, 4, 6, 8, 10, 12, 14] - that's 8 elements
  • dim // 2 = 8
  • So [: (dim // 2)] takes the first 8 elements, which is all of them in this case

The reason for this slice is that RoPE works in pairs - each pair of dimensions gets rotated together. So for a 16-dimensional vector, you have 8 pairs: (0,1), (2,3), (4,5), etc.

The torch.arange(0, dim, 2) generates the even indices [0, 2, 4, 6, 8, 10, 12, 14], and the slice ensures you get exactly dim//2 frequency values - one for each rotation pair. Each frequency will be applied to rotate one pair of dimensions.
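A quick check of when the slice actually matters (pair_exponents is a hypothetical helper name used only for this illustration). In practice head_dim is even, so the slice is a safeguard rather than something that normally changes the result.

```python
import torch


def pair_exponents(dim: int) -> torch.Tensor:
    """Even indices 0, 2, 4, ..., trimmed to exactly dim // 2 entries (one per pair)."""
    return torch.arange(0, dim, 2)[: (dim // 2)]


for d in (16, 17):
    untrimmed = torch.arange(0, d, 2)
    print(d, len(untrimmed), len(pair_exponents(d)))
# 16 8 8  -> the slice is a no-op for an even dim
# 17 9 8  -> the slice drops index 16, which has no partner dimension
```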

Code: 33 (4:57:20p)

(torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)

Output: 91

tensor([0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750])

Prompt: 15

Why do we divide everything by dim?

Assistant: 217

The division by dim normalizes the exponents so that they range from 0 up to just under 1.

Together with the reciprocal taken later, this means:

  1. The frequencies always span the same range, from 1.0 down toward 1/theta, whatever the head dimension size
  2. Lower dimension-pair indices get higher frequencies (faster rotation)
  3. Higher dimension-pair indices get lower frequencies (slower rotation)

When you raise theta (10000.0) to these normalized powers and take the reciprocal, you get a geometric progression of frequencies that decays exponentially. The first pairs rotate fastest (good for distinguishing nearby tokens), while the last pairs rotate very slowly (good for capturing long-range dependencies).

Without the /dim normalization, the exponents would be 0, 2, 4, ..., so theta ** 2 would already be 1e8 and every pair after the first would barely rotate. The range of frequencies would also depend on the head size, making the model behavior inconsistent across architectures.
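A small comparison, assuming theta = 10000 as in Llama. base_freqs is a hypothetical helper that mirrors the formula in precompute_freqs_cis with the /dim step made optional.

```python
import torch


def base_freqs(dim: int, theta: float = 10000.0, normalize: bool = True) -> torch.Tensor:
    """Per-pair frequencies as in precompute_freqs_cis, with the /dim step optional."""
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float()
    if normalize:
        exponents = exponents / dim  # exponents now lie in [0, 1)
    return 1.0 / (theta ** exponents)


# With /dim: always from 1.0 down toward 1/theta, whatever the head size.
for d in (16, 64, 128):
    f = base_freqs(d)
    print(d, f[0].item(), f[-1].item())

# Without /dim: exponents 0, 2, 4, ... so the second pair is already 1/theta**2 = 1e-8.
# In float32 the largest powers overflow to inf, so those frequencies become exactly 0.
print(base_freqs(16, normalize=False))
```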

Code: 37 (4:57:20p)

(theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

Output: 129

tensor([1.0000e+00, 3.1623e+00, 1.0000e+01, 3.1623e+01, 1.0000e+02, 3.1623e+02,
        1.0000e+03, 3.1623e+03])

Prompt: 46

Why do we compute (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))?

Assistant: 222

We raise theta to these normalized powers to create a geometric sequence that grows exponentially. Each value is proportional to a wavelength: pair i completes one full rotation every 2π · theta^(2i/dim) positions.

Taking the reciprocal 1.0 / (theta ** ...) gives us frequencies that decay exponentially:

  • First pair: frequency = 1.0 (fastest rotation)
  • Second pair: frequency ≈ 0.316 (slower)
  • Third pair: frequency = 0.1 (even slower)
  • And so on...

This creates a spectrum where:

  • Lower dimension pairs rotate quickly (capture short-range patterns)
  • Higher dimension pairs rotate very slowly (capture long-range dependencies)

The exponential decay ensures each dimension pair operates at a distinctly different timescale, allowing the model to encode positional information across multiple ranges simultaneously.
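One way to see the geometric progression: the ratio between consecutive frequencies is constant, theta^(-2/dim). This sketch uses head_dim = 16 and theta = 10000 as in the notebook, with new variable names so it does not overwrite freqs.

```python
import torch

d, base = 16, 10000.0  # head_dim and theta from the notebook
inv_freq = 1.0 / (base ** (torch.arange(0, d, 2)[: (d // 2)].float() / d))

ratios = inv_freq[1:] / inv_freq[:-1]  # each frequency divided by the previous one
print(ratios)                          # every entry is about 0.3162
print(base ** (-2 / d))                # theta ** (-2 / dim) = 10 ** -0.5
```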

Prompt: 52

Can you create a simple plot with (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))?

Assistant: 160

import matplotlib.pyplot as plt

powers = (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_powers = (theta ** powers)

plt.figure(figsize=(10, 6))
plt.plot(range(len(theta_powers)), theta_powers, 'bo-')
plt.xlabel('Dimension pair index')
plt.ylabel('theta^(power)')
plt.title('Exponential growth of theta powers')
plt.yscale('log')
plt.grid(True)
plt.show()

Code: 156 (4:57:21p)

import matplotlib.pyplot as plt

powers = (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_powers = (theta ** powers)

plt.figure(figsize=(10, 6))
plt.plot(range(len(theta_powers)), theta_powers, 'bo-')
plt.xlabel('Dimension pair index')
plt.ylabel('theta^(power)')
plt.title('Exponential growth of theta powers')
# plt.yscale('log')
plt.grid(True)
plt.show()

Output: 100

Code: 45 (4:57:21p)

1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

Output: 129

tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])

Prompt: 18

Why do we take the inverse of what we had before?

Assistant: 180

We take the inverse to convert wavelengths into frequencies.

The theta ** powers values grow exponentially and are proportional to wavelengths (how many positions a pair needs to complete one full rotation, up to a factor of 2π). But for RoPE, we need frequencies (how many radians a pair rotates per position).

Since frequency = 1/wavelength, taking the inverse 1.0 / (theta ** powers) gives us:

  • High frequencies for early dimension pairs (fast rotation, short-range patterns)
  • Low frequencies for later dimension pairs (slow rotation, long-range dependencies)

This creates the desired frequency spectrum where different dimension pairs rotate at different speeds to encode positional information across multiple scales.
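To make the wavelength idea concrete: pair i completes one full turn (2π radians) every 2π / freq_i positions. A sketch with head_dim = 16 and theta = 10000 as in the notebook:

```python
import math

import torch

d, base = 16, 10000.0  # head_dim and theta from the notebook
inv_freq = 1.0 / (base ** (torch.arange(0, d, 2)[: (d // 2)].float() / d))

# Positions needed for one full turn (2π radians) of each pair.
wavelengths = 2 * math.pi / inv_freq
for i, (f, w) in enumerate(zip(inv_freq.tolist(), wavelengths.tolist())):
    print(f"pair {i}: {f:.6f} rad/position, one full turn every {w:,.1f} positions")
```

Pair 0 turns every ~6.3 positions, while pair 7 needs roughly 19,869 positions, far longer than this notebook's context.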

Prompt: 7

Let's graph this too.

Assistant: 144

frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

plt.figure(figsize=(10, 6))
plt.plot(range(len(frequencies)), frequencies, 'ro-')
plt.xlabel('Dimension pair index')
plt.ylabel('Frequency')
plt.title('Exponential decay of frequencies')
plt.yscale('log')
plt.grid(True)
plt.show()

Code: 139 (4:57:22p)

frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

plt.figure(figsize=(10, 6))
plt.plot(range(len(frequencies)), frequencies, 'ro-')
plt.xlabel('Dimension pair index')
plt.ylabel('Frequency')
plt.title('Exponential decay of frequencies')
# plt.yscale('log')
plt.grid(True)
plt.show()

Output: 100

Prompt: 69

What do you mean 3.1623e-04 is a high frequency? I thought it was a small number, so it should be a low frequency, since it is equivalent to 0.00031623.

Assistant: 225

Right: 3.1623e-04 is 0.00031623, a small number, so it is a low frequency. To summarize:

  • High frequencies (like 1.0) rotate quickly → capture short-range positional relationships
  • Low frequencies (like 3.1623e-04 = 0.00031623) rotate slowly → capture long-range positional relationships

Think of it like a clock:

  • A second hand (high frequency) completes many rotations quickly - good for distinguishing nearby moments
  • An hour hand (low frequency) rotates slowly - good for distinguishing longer time periods

In RoPE, the high frequency dimensions help the model distinguish between adjacent tokens, while the low frequency dimensions help distinguish between tokens that are far apart in the sequence.

Prompt: 45

Each head dimension pair gets a different frequency, which means a different rate of rotation. So, the higher the dimension, the slower the rotation. Is that right?

Assistant: 240

Exactly right! Looking at the code:

torch.arange(0, dim, 2)[: (dim // 2)].float() / dim

This creates values like [0.0000, 0.1250, 0.2500, 0.3750, ...] that increase with dimension index.

When we do theta ** (increasing_values), we get larger denominators for higher dimensions, resulting in smaller frequencies.

So yes:

  • Lower dimension indices → higher frequencies → faster rotation
  • Higher dimension indices → lower frequencies → slower rotation

This creates a nice hierarchy where early dimensions capture fine-grained positional differences (adjacent tokens) and later dimensions capture coarse-grained positional differences (distant tokens).

Code: 19 (4:57:22p)

t = torch.arange(end, device=freqs.device)
t

Output: 33

tensor([0, 1, 2])

Code: 21 (4:57:22p)

freqs = torch.outer(t, freqs).float()
freqs

Output: 351

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
         1.0000e-03, 3.1623e-04],
        [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03,
         2.0000e-03, 6.3246e-04]])

Prompt: 6

Let's plot this.

Assistant: 109

plt.figure(figsize=(12, 8))
for i in range(freqs.shape[1]):
    plt.plot(freqs[:, i], label=f'dim_pair_{i}', marker='o')
plt.xlabel('Position')
plt.ylabel('Frequency * Position')
plt.title('Frequency scaling across positions')
plt.legend()
plt.grid(True)
plt.show()

Code: 103 (4:57:22p)

plt.figure(figsize=(12, 8))
for i in range(freqs.shape[1]):
    plt.plot(freqs[:, i], label=f'dim_pair_{i}', marker='o')
plt.xlabel('Position')
plt.ylabel('Frequency * Position')
plt.title('Frequency scaling across positions')
plt.legend()
plt.grid(True)
plt.show()

Output: 100

Prompt: 15

What is the purpose of plotting it this way?

Assistant: 229

This plot shows how each dimension pair's rotation angle changes linearly with position.

The key insight is that freqs contains the actual rotation angles that will be applied:

  • Position 0: all angles are 0 (no rotation)
  • Position 1: each dimension gets its base frequency as the angle
  • Position 2: each dimension gets 2× its base frequency as the angle

The linear scaling means:

  • High frequency dimensions (dim_pair_0) accumulate rotation angles quickly
  • Low frequency dimensions (dim_pair_7) accumulate rotation angles slowly

This creates the relative positional encoding - tokens at different positions will have different rotation angles, and the difference between any two positions depends on both their distance and the dimension pair's frequency.
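A quick numeric check of that last point, using the same frequency formula with 8 positions (n_pos = 8 is an arbitrary demo length): the angle gap between two positions is (m - n) × frequency, no matter where the two positions sit in the sequence.

```python
import torch

d, n_pos, base = 16, 8, 10000.0  # n_pos = 8 is an arbitrary demo length
inv_freq = 1.0 / (base ** (torch.arange(0, d, 2)[: (d // 2)].float() / d))
angles = torch.outer(torch.arange(n_pos).float(), inv_freq)  # (n_pos, d // 2)

# Angle gap between positions m and n, for every pair at once.
gap_a = angles[5] - angles[3]  # offset 2
gap_b = angles[7] - angles[5]  # offset 2, different absolute positions
print(torch.allclose(gap_a, gap_b), torch.allclose(gap_a, 2 * inv_freq))
```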

Prompt: 10

Let's also plot this by position.

Assistant: 117

plt.figure(figsize=(12, 8))
for i in range(freqs.shape[0]):
    plt.plot(freqs[i, :], label=f'pos_{i}', marker='o')
plt.xlabel('Dimension pair')
plt.ylabel('Frequency * Position')
plt.title('Rotation angles by dimension pair at each position')
plt.legend()
plt.grid(True)
plt.show()

Code: 111 (4:57:23p)

plt.figure(figsize=(12, 8))
for i in range(freqs.shape[0]):
    plt.plot(freqs[i, :], label=f'pos_{i}', marker='o')
plt.xlabel('Dimension pair')
plt.ylabel('Frequency * Position')
plt.title('Rotation angles by dimension pair at each position')
plt.legend()
plt.grid(True)
plt.show()

Output: 100

Prompt: 39

By using torch.outer, are we scaling the frequencies by the position? The farther away, the more the frequencies are scaled?

Assistant: 259

Exactly! torch.outer(t, freqs) multiplies each frequency by each position index. The product is a rotation angle in radians rather than a frequency.

Looking at your output:

  • Position 0: All frequencies × 0 = 0 (no rotation)
  • Position 1: All frequencies × 1 = the base frequencies
  • Position 2: All frequencies × 2 = twice the base frequencies

So yes, the farther the position, the larger the rotation angle. This means:

  • Token at position 0 gets no rotation
  • Token at position 1 gets rotated by the base frequencies
  • Token at position 2 gets rotated by 2× the base frequencies
  • Token at position n gets rotated by n× the base frequencies

This creates the relative positional encoding effect: for each dimension pair, the angle between a query at position m and a key at position n is (m - n) × that pair's frequency, so their dot product in attention depends on the distance between the tokens rather than on their absolute positions.
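torch.outer(t, freqs) gives the same result as broadcasting positions against frequencies, which some readers find easier to parse. A toy example with three made-up frequencies:

```python
import torch

inv_freq = torch.tensor([1.0, 0.1, 0.01])  # three made-up base frequencies
positions = torch.arange(4).float()

angles_outer = torch.outer(positions, inv_freq)        # shape (4, 3)
angles_bcast = positions[:, None] * inv_freq[None, :]  # same values via broadcasting
print(angles_outer.shape, torch.allclose(angles_outer, angles_bcast))
# angles_outer[p, i] == p * inv_freq[i]: the rotation angle (radians) of pair i at position p
```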

Code: 39 (4:57:23p)

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
freqs_cis

Output: 597

tensor([[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
          1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
          1.0000+0.0000e+00j,  1.0000+0.0000e+00j],
        [ 0.5403+8.4147e-01j,  0.9504+3.1098e-01j,  0.9950+9.9833e-02j,
          0.9995+3.1618e-02j,  0.9999+9.9998e-03j,  1.0000+3.1623e-03j,
          1.0000+1.0000e-03j,  1.0000+3.1623e-04j],
        [-0.4161+9.0930e-01j,  0.8066+5.9113e-01j,  0.9801+1.9867e-01j,
          0.9980+6.3203e-02j,  0.9998+1.9999e-02j,  1.0000+6.3245e-03j,
          1.0000+2.0000e-03j,  1.0000+6.3246e-04j]])

Code: 7 (4:57:23p)

freqs_cis.real

Output: 273

tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9504,  0.9950,  0.9995,  0.9999,  1.0000,  1.0000,  1.0000],
        [-0.4161,  0.8066,  0.9801,  0.9980,  0.9998,  1.0000,  1.0000,  1.0000]])

Code: 6 (4:57:23p)

torch.polar?

Output: 672

Docstring:
polar(abs, angle, *, out=None) -> Tensor

Constructs a complex tensor whose elements are Cartesian coordinates
corresponding to the polar coordinates with absolute value :attr:`abs` and angle
:attr:`angle`.

.. math::
    \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j

.. note::
    `torch.polar` is similar to
    `std::polar <https://en.cppreference.com/w/cpp/numeric/complex/polar>`_
    and does not compute the polar decomposition
    of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do.
    The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is
    infinite.


Args:
    abs (Tensor): The absolute value the complex tensor. Must be float or double.
    angle (Tensor): The angle of the complex tensor. Must be same dtype as
        :attr:`abs`.

Keyword args:
    out (Tensor): If the inputs are ``torch.float32``, must be
        ``torch.complex64``. If the inputs are ``torch.float64``, must be
        ``torch.complex128``.

Example::

    >>> import numpy as np
    >>> abs = torch.tensor([1, 2], dtype=torch.float64)
    >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64)
    >>> z = torch.polar(abs, angle)
    >>> z
    tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
Type:      builtin_function_or_method

Prompt: 16

Why are we using torch.polar here?

Assistant: 204

torch.polar converts the rotation angles (the frequencies scaled by position) into complex numbers representing rotations.

It takes:

  • abs = magnitude (all 1s, so unit circle)
  • angle = the scaled frequencies we calculated

This creates complex numbers of the form cos(angle) + i*sin(angle), which are rotation operators in the complex plane.

So freqs_cis contains the actual rotation amounts for each position and dimension. When we multiply query/key vectors by these complex numbers, we're rotating them by the corresponding angles.

The name "cis" comes from the mathematical notation cis(θ) = cos(θ) + i*sin(θ).
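A small check that torch.polar(torch.ones_like(angle), angle) is just cos(angle) + i·sin(angle), a unit-magnitude complex number. The angles here are arbitrary.

```python
import torch

angles = torch.tensor([0.0, 0.5, 1.0, 2.0])  # arbitrary angles in radians
via_polar = torch.polar(torch.ones_like(angles), angles)          # complex64
via_parts = torch.complex(torch.cos(angles), torch.sin(angles))  # cos + i*sin

print(torch.allclose(via_polar.real, via_parts.real),
      torch.allclose(via_polar.imag, via_parts.imag))
print(via_polar.abs())  # all (approximately) 1: unit complex numbers, i.e. pure rotations
```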

Prompt: 37

I can understand why we are adjusting frequencies based on positions, but why are we applying it to the head dimensions as well?

Assistant: 330

Great question! We apply different frequencies to different head dimensions because we want each dimension pair to encode positional information at a different "scale" or "resolution."

Think of it like this:

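  • High frequency dimensions (early dimensions) clearly separate adjacent positions, but their angle wraps around every 2π radians, so offsets that differ by about a full turn look alike: at 1 radian per position, positions 0 and 44 land at almost the same angle (44 radians is about 7 full turns)
  • Low frequency dimensions (later dimensions) barely separate positions 1 and 2, but their angle does not wrap around over long stretches, so they can tell distant positions such as 1 and 1000 apart without ambiguity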

This creates a hierarchy where:

  • Some dimensions are good at encoding fine-grained position differences (nearby tokens)
  • Other dimensions are good at encoding coarse-grained position differences (distant tokens)

By having this spectrum of frequencies across dimensions, the model can capture both local and global positional relationships simultaneously. It's like having multiple "clocks" running at different speeds - some for seconds, some for minutes, some for hours.

Without this variety, all dimensions would encode position at the same scale, which would be less expressive for capturing the full range of positional relationships in a sequence.
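A small numeric illustration of the wrap-around, assuming head_dim = 16 and theta = 10000 as in the notebook (so pair 0 turns 1 radian per position and pair 7 about 3.16e-4):

```python
import math

fast = 1.0                    # pair 0: 1 radian per position
slow = 10000.0 ** (-14 / 16)  # pair 7: about 3.16e-4 radians per position


def wrapped(angle: float) -> float:
    """Reduce an angle to [0, 2π): the rotation the pair actually applies."""
    return angle % (2 * math.pi)


for pos in (0, 1, 44):
    print(pos, round(wrapped(pos * fast), 3), round(wrapped(pos * slow), 5))
# Pair 0 puts positions 0 and 44 almost on top of each other (0 vs ~0.018),
# while pair 7 still orders them: 0 < 1 < 44.
```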

Prompt: 33

Okay. And we are dividing the head dimension by two and calculating half for cosine and half for sine, right?

Assistant: 271

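Not quite: it is not that half of the dimensions get cosine and the other half get sine. RoPE works on pairs of dimensions, and each pair is rotated together using both cos(θ) and sin(θ) of its angle θ: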

  • Dimension 0 & 1 share the same frequency
  • Dimension 2 & 3 share the same frequency
  • And so on...

So for head_dim=16, we only need 8 frequencies (for 8 pairs). The rotation is applied as:

[x0, x1] → [x0*cos(θ) - x1*sin(θ), x0*sin(θ) + x1*cos(θ)]

This is why we see torch.arange(0, dim, 2) - we're stepping by 2 to get frequencies for each pair, and [: (dim // 2)] ensures we get exactly half the dimensions worth of frequencies.

The complex representation with torch.polar elegantly handles this 2D rotation math for us.
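A sketch checking that the complex multiplication used in Llama's code performs exactly this 2D rotation on the adjacent pairs (x0, x1), (x2, x3), .... The input vector and the angles are random placeholders.

```python
import torch

torch.manual_seed(0)
x = torch.randn(16)           # one head vector, head_dim = 16
angles = torch.rand(8) * 3.0  # one placeholder angle per pair

# Llama-style: view adjacent pairs (x0, x1), (x2, x3), ... as complex numbers and multiply.
x_c = torch.view_as_complex(x.reshape(-1, 2))
rot = torch.polar(torch.ones_like(angles), angles)
via_complex = torch.view_as_real(x_c * rot).flatten()

# Explicit formula: [x0, x1] -> [x0*cos - x1*sin, x0*sin + x1*cos] for every pair.
x0, x1 = x[0::2], x[1::2]
c, s = torch.cos(angles), torch.sin(angles)
via_matrix = torch.stack([x0 * c - x1 * s, x0 * s + x1 * c], dim=-1).flatten()

print(torch.allclose(via_complex, via_matrix, atol=1e-6))
```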

Note: 7 + 928

reshape_for_broadcast
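The body of this cell is collapsed in the export, but apply_rotary_emb below calls reshape_for_broadcast. A minimal sketch of what it needs to do, modeled on the Llama 2 helper; treat it as a reconstruction, not a guaranteed line-for-line copy of the original.

```python
import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape freqs_cis of shape (seq_len, head_dim // 2) so it broadcasts
    against x of shape (batch, seq_len, n_heads, head_dim // 2)."""
    ndim = x.ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the sequence axis (1) and the last axis; every other axis becomes size 1.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Shapes matching this notebook: batch 1, 3 positions, 2 heads, 8 complex pairs.
x = torch.zeros(1, 3, 2, 8, dtype=torch.complex64)
f = torch.ones(3, 8, dtype=torch.complex64)
print(reshape_for_broadcast(f, x).shape)  # torch.Size([1, 3, 1, 8])
```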

Note: 7

apply_rotary_emb

Code: 502 (4:57:23p)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
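A self-contained end-to-end check. It repeats precompute_freqs_cis and apply_rotary_emb so it runs on its own, and it assumes a simplified reshape_for_broadcast (an assumption, since the original helper is collapsed in this export). It uses the notebook's seed and shapes, then checks three things: shapes are unchanged, rotations preserve vector lengths, and the query-key score depends only on the offset between positions.

```python
from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # (end, dim // 2), complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: (seq_len, dim // 2) -> (1, seq_len, 1, dim // 2) for x of (B, T, H, dim // 2)
    shape = [d if i in (1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same steps as the Llama 2 function: view pairs as complex, rotate, view back as real.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# Same seed and shapes as the notebook.
torch.manual_seed(123)
batch_size, context_len, num_heads, head_dim = 1, 3, 2, 16
queries = torch.randn(batch_size, context_len, num_heads, head_dim)
keys = torch.randn(batch_size, context_len, num_heads, head_dim)

q_rot, k_rot = apply_rotary_emb(queries, keys, precompute_freqs_cis(head_dim, context_len))
print(q_rot.shape)                                               # shape is unchanged
print(torch.allclose(q_rot.norm(dim=-1), queries.norm(dim=-1)))  # rotations keep lengths
print(torch.allclose(q_rot[:, 0], queries[:, 0]))                # position 0: angle 0

# Relative-position check: identical q and k content at every position.
n_pos = 8
q_seq = torch.randn(head_dim).expand(1, n_pos, 1, head_dim).contiguous()
k_seq = torch.randn(head_dim).expand(1, n_pos, 1, head_dim).contiguous()
q_r, k_r = apply_rotary_emb(q_seq, k_seq, precompute_freqs_cis(head_dim, n_pos))


def score(m: int, n: int) -> float:
    """Unscaled attention logit between the query at position m and the key at n."""
    return (q_r[0, m, 0] @ k_r[0, n, 0]).item()


print(score(5, 3), score(4, 2), score(7, 5))  # equal up to float error: offset 2 each
```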

Prompt: 10

Why is it called an embedding?

Assistant: 207

It's called "rotary positional embedding" because it's embedding (encoding) positional information into the query and key vectors.

Traditional positional embeddings add position information to token embeddings before attention. RoPE is different - it embeds positional information directly into the attention mechanism by rotating the query and key vectors based on their positions.

The rotation angles encode where each token is positioned in the sequence. When two tokens at different positions have their vectors rotated by different amounts, their dot product (attention score) naturally reflects their relative distance.

So "embedding" here means "encoding positional information into" rather than the typical meaning of converting discrete tokens to vectors. It's embedding position into the geometry of the vector space through rotations.

Note: 84

[Image: plot of GPT-2's absolute positional embeddings]

Prompt: 33

Displayed above is a plot of absolute positional embeddings in gpt2. What can you tell me from this?

Assistant: 1,524

Looking at this plot of GPT-2's absolute positional embeddings, I can tell you several key insights:

The plot shows three lines representing the first three principal components of GPT-2's positional embeddings. Research has found that GPT-2's 768-dimensional positional vectors mostly live in a low-rank linear subspace, with a three-dimensional subspace explaining 90% of the variance ^1 ^2 ^3.

When these positional embeddings are plotted in 3D PCA space, they form a helix structure ^4 ^5 ^6. The three lines in your plot are the projections of this helix onto each of the three main PCA dimensions.

Key observations from the plot:

  1. Smooth, continuous patterns: The embeddings don't jump randomly but follow smooth trajectories, suggesting the model learned structured positional representations.

  2. Different frequencies: Each PCA dimension oscillates at different rates - some change quickly (high frequency) while others change more gradually (low frequency).

  3. Helical structure: The dark blue end represents early positions in the sequence, while the red end represents later positions ^7 ^8 ^9, forming a spiral pattern in 3D space.

  4. Efficient encoding: This can be thought of as the neural network reserving 3 dimensions for positional information and 765 for semantic content ^10.

This helical structure is quite different from RoPE, which explicitly uses sinusoidal functions. GPT-2 learned this structure through training, discovering an efficient way to encode position that allows for both local and global positional relationships.

Prompt: 36

Displayed above is a plot of absolute positional embeddings in gpt2. Can I make a similar plot with RoPE?

Assistant: 847

Yes, you can create a similar visualization for RoPE! Unlike GPT-2's absolute positional embeddings that are added to tokens, RoPE creates rotational patterns that you can visualize.

Here are a few ways to plot RoPE:

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Rotation angles over positions for different dimensions
for i in range(min(8, freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(freqs_cis[:, i]), label=f'dim_{i}', marker='o')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position')
axes[0,0].legend()
axes[0,0].grid(True)

# Plot 2: Real and imaginary parts
axes[0,1].plot(freqs_cis[:, 0].real, label='Real (cos)', marker='o')
axes[0,1].plot(freqs_cis[:, 0].imag, label='Imag (sin)', marker='s')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)

# Plot 3: 2D rotation visualization
pair0_angles = torch.angle(freqs_cis[:, 0])  # new name so the scalar theta is not overwritten
x_pos = torch.cos(pair0_angles)
y_pos = torch.sin(pair0_angles)
axes[1,0].plot(x_pos, y_pos, 'o-')
for i in range(len(pair0_angles)):
    axes[1,0].annotate(f'pos_{i}', (x_pos[i], y_pos[i]))
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')

# Plot 4: Heatmap of all rotation angles
angles = torch.angle(freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])

plt.tight_layout()
plt.show()

Code: 772 (4:57:24p)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Rotation angles over positions for different dimensions
for i in range(min(8, freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(freqs_cis[:, i]), label=f'dim_{i}', marker='o')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position')
axes[0,0].legend()
axes[0,0].grid(True)

# Plot 2: Real and imaginary parts
axes[0,1].plot(freqs_cis[:, 0].real, label='Real (cos)', marker='o')
axes[0,1].plot(freqs_cis[:, 0].imag, label='Imag (sin)', marker='s')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)

# Plot 3: 2D rotation visualization
pair0_angles = torch.angle(freqs_cis[:, 0])  # new name so the scalar theta is not overwritten
x_pos = torch.cos(pair0_angles)
y_pos = torch.sin(pair0_angles)
axes[1,0].plot(x_pos, y_pos, 'o-')
for i in range(len(pair0_angles)):
    axes[1,0].annotate(f'pos_{i}', (x_pos[i], y_pos[i]))
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')

# Plot 4: Heatmap of all rotation angles
angles = torch.angle(freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])

plt.tight_layout()
plt.show()

Output: 102

Code: 10 (4:57:24p)

torch.angle(freqs_cis)

Output: 351

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
         1.0000e-03, 3.1623e-04],
        [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03,
         2.0000e-03, 6.3246e-04]])

Prompt: 63

Let's try plotting with a more practical example. How about we use the same sequence length and embedding size as GPT-2? Create new variables so we do not overwrite the ones we already have.

Assistant: 87

gpt2_seq_len = 1024
gpt2_head_dim = 64

gpt2_freqs_cis = precompute_freqs_cis(gpt2_head_dim, gpt2_seq_len)
gpt2_freqs_cis.shape

Prompt: 10

why is head_dim 64?

Assistant: 403

GPT-2 (the smallest, 124M-parameter model) has 768-dimensional hidden states and 12 attention heads ^1 ^2 ^3. The head dimension is calculated as:

head_dim = embedding_dim / num_heads = 768 / 12 = 64

So each attention head gets 64 dimensions to work with. This is a design choice that divides the full embedding dimension evenly across all attention heads, allowing each head to focus on different aspects of the relationships between tokens.

Code: 81 (4:57:24p)

gpt2_context_len = 1024
gpt2_head_dim = 64

gpt2_freqs_cis = precompute_freqs_cis(gpt2_head_dim, gpt2_context_len)
gpt2_freqs_cis.shape

Output: 31

torch.Size([1024, 32])

Prompt: 10

Okay, let's plot them now.

Assistant: 690

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

for i in range(min(8, gpt2_freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(gpt2_freqs_cis[:, i]), label=f'dim_{i}')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position (GPT-2 scale)')
axes[0,0].legend()
axes[0,0].grid(True)

axes[0,1].plot(gpt2_freqs_cis[:, 0].real, label='Real (cos)')
axes[0,1].plot(gpt2_freqs_cis[:, 0].imag, label='Imag (sin)')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)

pair0_angles = torch.angle(gpt2_freqs_cis[:100, 0])  # avoid overwriting the scalar theta
x_pos = torch.cos(pair0_angles)
y_pos = torch.sin(pair0_angles)
axes[1,0].plot(x_pos, y_pos, 'o-', markersize=2)
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern (first 100 positions)')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')

angles = torch.angle(gpt2_freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])

plt.tight_layout()
plt.show()

Code: 684 (4:57:25p)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

for i in range(min(8, gpt2_freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(gpt2_freqs_cis[:, i]), label=f'dim_{i}')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position (GPT-2 scale)')
axes[0,0].legend()
axes[0,0].grid(True)

axes[0,1].plot(gpt2_freqs_cis[:, 0].real, label='Real (cos)')
axes[0,1].plot(gpt2_freqs_cis[:, 0].imag, label='Imag (sin)')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)

pair0_angles = torch.angle(gpt2_freqs_cis[:100, 0])  # avoid overwriting the scalar theta
x_pos = torch.cos(pair0_angles)
y_pos = torch.sin(pair0_angles)
axes[1,0].plot(x_pos, y_pos, 'o-', markersize=2)
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern (first 100 positions)')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')

angles = torch.angle(gpt2_freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])

plt.tight_layout()
plt.show()

Output: 102

Prompt: 19

I don't understand what these plots mean. Can you explain?

Assistant: 459

Looking at these plots, here's what each one shows:

Top Left - Rotation angles by position: Each colored line is one dimension pair. torch.angle returns angles wrapped into the range [-π, π], so instead of growing in a straight line, each curve climbs and then drops by 2π, forming a sawtooth. High-frequency pairs (like dim_0, at 1 radian per position) wrap about every 6.3 positions and look like dense oscillations, while lower-frequency pairs wrap less often.

Top Right - First dimension pair (cos/sin): This shows the cosine and sine values for the fastest-rotating pair. You can see the rapid oscillation: at 1 radian per position, this pair completes about 163 full rotations across the 1024 positions.

Bottom Left - 2D rotation pattern: This shows the first 100 positions of the fastest pair as points on the unit circle. Each point is where that position "lands" after rotation. Consecutive positions are 1 radian (about 57°) apart, so the connecting lines jump across the circle instead of tracing a smooth path, and every point stays on the same circle because rotation does not change length.

Bottom Right - All rotation angles heatmap: This is the most informative. The x-axis is position (0-1023), the y-axis is dimension pair (0-31), and the colors represent the (wrapped) rotation angles. You can see:

  • Top rows (high frequency dims) have rapid color changes (fast rotation)
  • Bottom rows (low frequency dims) have gradual color changes (slow rotation)
  • The rainbow pattern shows how different dimensions encode position at different scales

Think of it like having 32 different "clocks" - some tick every second (high freq), others tick every hour (low freq). Together they can uniquely identify any position in the sequence.
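The "clocks" can be made concrete by computing each pair's frequency and how many positions it takes to complete one full turn. This is a minimal sketch, assuming head_dim = 64 and theta = 10000 (the values hard-coded in the frequency bar chart below); `pair_freq` and `period` are illustrative names.

```python
import math
import torch

d = 64            # assumed head_dim behind gpt2_freqs_cis (32 dimension pairs)
base = 10000.0    # RoPE theta

# Angular frequency of each dimension pair, in radians per position
pair_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))

# Positions needed for each pair to complete one full turn (2*pi radians)
period = 2 * math.pi / pair_freq

for i in [0, 1, 8, 16, 31]:
    print(f"pair {i:2d}: {pair_freq[i].item():.6f} rad/position, "
          f"one full turn every {period[i].item():,.1f} positions")
```

The fastest pair completes a turn about every 6.3 positions (2π), while the slowest needs tens of thousands of positions.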

Prompt: 33

Thanks for the explanation, but it is very hard to visualize what the top plots are doing. Can you help?

Assistant: 567

Let me create some simpler plots to help you understand what the top plots are showing:

Code: 534 (4:57:26p)

import matplotlib.pyplot as plt
import numpy as np
import torch

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

dim_pairs = [0, 1, 2, 3]
colors = ['red', 'blue', 'green', 'orange']

for i, dim_pair in enumerate(dim_pairs):
    # torch.angle wraps values into (-pi, pi]; np.unwrap recovers the accumulated angle
    angles = np.unwrap(torch.angle(gpt2_freqs_cis[:100, dim_pair]).numpy())
    ax1.plot(angles, label=f'dim_pair_{dim_pair}', color=colors[i], linewidth=2)

ax1.set_xlabel('Position')
ax1.set_ylabel('Rotation angle (radians)')
ax1.set_title('How rotation angles grow with position')
ax1.legend()
ax1.grid(True)

for i, dim_pair in enumerate(dim_pairs):
    # Base frequency of this pair (head_dim = 64, theta = 10000)
    freq = 1.0 / (10000 ** (torch.arange(0, 64, 2)[dim_pair].float() / 64))
    ax2.bar(i, freq.item(), color=colors[i], label=f'dim_pair_{dim_pair}')

ax2.set_xlabel('Dimension pair')
ax2.set_ylabel('Frequency (radians per position)')
ax2.set_title('Base frequencies for each dimension pair')
ax2.set_yscale('log')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

print("Key insights:")
print("- Left plot: rotation angles accumulate linearly with position; the slope is the pair's frequency")
print("- Right plot: base frequencies; higher dimension pairs rotate progressively slower")
print("- dim_pair_0 rotates 1 radian (~0.16 of a full turn) per position")
print("- dim_pair_3 rotates ~0.42 radians (~0.07 of a full turn) per position")

Output: 204

Key insights:
- Left plot: rotation angles accumulate linearly with position; the slope is the pair's frequency
- Right plot: base frequencies; higher dimension pairs rotate progressively slower
- dim_pair_0 rotates 1 radian (~0.16 of a full turn) per position
- dim_pair_3 rotates ~0.42 radians (~0.07 of a full turn) per position

Note: 33

Remember: queries and keys have shape (batch_size, context_len, num_heads, head_dim).

Code: 34 (4:57:26p)

xq = queries
xk = keys

xq.shape, xk.shape, freqs_cis.shape

Output: 72

(torch.Size([1, 3, 2, 16]), torch.Size([1, 3, 2, 16]), torch.Size([3, 8]))

Code: 9 (4:57:26p)

xq.shape[:-1]

Output: 34

torch.Size([1, 3, 2])

Code: 28 (4:57:26p)

xq.float().reshape(*xq.shape[:-1], -1, 2).shape

Output: 43

torch.Size([1, 3, 2, 8, 2])

Note: 42

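The last dimension now has size 2 and holds adjacent pairs of query elements, (x[2i], x[2i+1]); these are not cos and sin. view_as_complex treats each pair as the real and imaginary parts of one complex number, turning the 16-dimensional head into 8 complex numbers.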

Code: 45 (4:57:26p)

xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_.shape

Output: 39

torch.Size([1, 3, 2, 8])
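To see which elements end up paired, here is a toy sketch on a hand-made vector (not the notebook's queries):

```python
import torch

toy = torch.arange(8, dtype=torch.float32)  # a toy "head" holding 0, 1, ..., 7

# Group adjacent elements into (real, imag) pairs, then view each pair as one complex number
pairs = toy.reshape(-1, 2)            # [[0, 1], [2, 3], [4, 5], [6, 7]]
toy_c = torch.view_as_complex(pairs)  # [0+1j, 2+3j, 4+5j, 6+7j]

print(toy_c.real)  # tensor([0., 2., 4., 6.]): even-indexed elements become real parts
print(toy_c.imag)  # tensor([1., 3., 5., 7.]): odd-indexed elements become imaginary parts
```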

Code: 7 (4:57:26p)

torch.view_as_complex?

Output: 567

Docstring:
view_as_complex(input) -> Tensor

Returns a view of :attr:`input` as a complex tensor. For an input complex
tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a
new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last
dimension of the input tensor is expected to represent the real and imaginary
components of complex numbers.

.. warning::
    :func:`view_as_complex` is only supported for tensors with
    :class:`torch.dtype` ``torch.float64`` and ``torch.float32``.  The input is
    expected to have the last dimension of :attr:`size` 2. In addition, the
    tensor must have a `stride` of 1 for its last dimension. The strides of all
    other dimensions must be even numbers.

Args:
    input (Tensor): the input tensor.

Example::

    >>> x=torch.randn(4, 2)
    >>> x
    tensor([[ 1.6116, -0.5772],
            [-1.4606, -0.9120],
            [ 0.0786, -1.7497],
            [-0.6561, -1.6623]])
    >>> torch.view_as_complex(x)
    tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)])
Type:      builtin_function_or_method

Code: 45 (4:57:26p)

xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xk_.shape

Output: 39

torch.Size([1, 3, 2, 8])

Code: 33 (4:57:26p)

freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis.shape

Output: 39

torch.Size([1, 3, 1, 8])

Note: 3

Broadcasting

Code: 15 (4:57:26p)

(xq_ * freqs_cis).shape

Output: 39

torch.Size([1, 3, 2, 8])

Prompt: 10

How does the multiplication work here?

Assistant: 358

The multiplication xq_ * freqs_cis is complex number multiplication, which performs the rotation.

When you multiply two complex numbers:

(a + bi) * (cos(θ) + i*sin(θ)) = (a*cos(θ) - b*sin(θ)) + i*(a*sin(θ) + b*cos(θ))

This is exactly the 2D rotation formula! So each complex number in xq_ gets rotated by the corresponding angle in freqs_cis.

The broadcasting works because:

  • xq_ has shape [1, 3, 2, 8] (batch, seq_len, heads, head_dim_pairs)
  • freqs_cis has shape [1, 3, 1, 8] (1, seq_len, 1, head_dim_pairs)

The size-1 dimensions broadcast: the 1 in the heads position is repeated across both heads, and the leading 1 across the batch. So each position gets its own rotation angles and each dimension pair its own frequency, but all heads share the same positional encoding.
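The claim that complex multiplication is a 2D rotation can be checked numerically. This sketch rotates one arbitrary pair (a, b) by an arbitrary angle phi in two ways: by multiplying with cos(phi) + i*sin(phi), built with torch.polar as in Llama's precompute_freqs_cis, and by applying the 2x2 rotation matrix directly. The values of a, b and phi are made up.

```python
import math
import torch

a, b, phi = 0.7, -1.3, 0.4  # arbitrary pair and rotation angle

# Complex route: (a + bi) * (cos(phi) + i*sin(phi))
z = torch.complex(torch.tensor(a), torch.tensor(b))
rot = torch.polar(torch.tensor(1.0), torch.tensor(phi))
z_rot = z * rot
complex_result = torch.stack([z_rot.real, z_rot.imag])

# Matrix route: [[cos, -sin], [sin, cos]] @ [a, b]
R = torch.tensor([[math.cos(phi), -math.sin(phi)],
                  [math.sin(phi),  math.cos(phi)]])
matrix_result = R @ torch.tensor([a, b])

print(complex_result, matrix_result)       # same values up to float32 rounding
print(z.abs().item(), z_rot.abs().item())  # rotation preserves the length of the pair
```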

Note: 58

Apply the rotary transformation

rotated = torch.cat((-x2, x1), dim=-1)
x_rotated = (x * cos) + (rotated * sin)

Prompt: 19

From a tensor, how can I get the even indexed elements?

Assistant: 159

You can get even indexed elements using slicing with a step of 2, starting from index 0:

tensor[..., ::2]

This selects every 2nd element starting from index 0 (so indices 0, 2, 4, 6, etc.) along the last dimension.

For other dimensions, you can specify the dimension:

tensor[::2]  # even indices along first dimension
tensor[:, ::2]  # even indices along second dimension

Prompt: 6

How about odd indices?

Assistant: 88

You can get odd indexed elements by starting from index 1:

tensor[..., 1::2]

This selects every 2nd element starting from index 1 (so indices 1, 3, 5, 7, etc.) along the last dimension.

Prompt: 1,321

I have the following two tensors:

tensor([[-1.2295, -0.2887,  0.9687, -0.6554,  1.3035,  0.1295,  0.1323, -0.8838],
        [ 0.5342,  0.5383,  0.8112, -0.4330,  0.5602, -0.9264, -0.1925, -1.5116],
        [ 1.4182,  0.3159, -1.0963,  0.1530,  0.6830, -0.2398,  0.0732,  0.6918],
        [ 0.6567,  0.6113, -2.1482,  1.1618,  0.4768, -0.5982,  2.5378,  1.1881]])
tensor([[-0.6163,  1.1063, -1.7300, -0.1095, -0.7442, -0.8126,  0.2128,  0.2878],
        [-0.7322, -1.1714,  2.1736,  0.6815,  0.6597, -0.5675, -0.1627, -2.4516],
        [ 0.7278, -0.1520,  1.5938,  0.4412, -0.9143, -0.7768,  0.5236,  0.2211],
        [ 0.4192, -0.1826,  0.6669, -1.7079,  1.9377, -0.7297, -1.2567,  0.1236]])

I want to combine them like this:

tensor([[-1.2295, -0.6163, -0.2887,  1.1063,  0.9687, -1.7300, -0.6554, -0.1095,
          1.3035, -0.7442,  0.1295, -0.8126,  0.1323,  0.2128, -0.8838,  0.2878],
        [ 0.5342, -0.7322,  0.5383, -1.1714,  0.8112,  2.1736, -0.4330,  0.6815,
          0.5602,  0.6597, -0.9264, -0.5675, -0.1925, -0.1627, -1.5116, -2.4516],
        [ 1.4182,  0.7278,  0.3159, -0.1520, -1.0963,  1.5938,  0.1530,  0.4412,
          0.6830, -0.9143, -0.2398, -0.7768,  0.0732,  0.5236,  0.6918,  0.2211],
        [ 0.6567,  0.4192,  0.6113, -0.1826, -2.1482,  0.6669,  1.1618, -1.7079,
          0.4768,  1.9377, -0.5982, -0.7297,  2.5378, -1.2567,  1.1881,  0.1236]])

How do I do this?

Assistant: 97

You can use torch.stack and then flatten to interleave the tensors:

torch.stack([tensor1, tensor2], dim=-1).flatten(-2)

This stacks the tensors along the last dimension, creating pairs, then flattens the last two dimensions to interleave them.
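Even/odd slicing and the stack-then-flatten trick together give a way to apply Llama-style (adjacent-pair) rotary embeddings with real arithmetic only. The sketch below is illustrative: `rope_complex` and `rope_real` are hypothetical helper names, and the angles follow the same formula as precompute_freqs_cis with theta = 10000. It checks that both routes produce the same result.

```python
import torch

def rope_complex(x, angles):
    # x: (..., seq_len, dim); angles: (seq_len, dim // 2). Adjacent pairs as complex numbers.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_ * torch.polar(torch.ones_like(angles), angles)).flatten(-2)

def rope_real(x, angles):
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[..., ::2], x[..., 1::2]   # first and second member of each pair
    out_even = x_even * cos - x_odd * sin        # real part of the complex product
    out_odd = x_even * sin + x_odd * cos         # imaginary part
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2)  # re-interleave

torch.manual_seed(0)
n_pos, d = 5, 16
freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
toy_angles = torch.outer(torch.arange(n_pos).float(), freq)  # (n_pos, d // 2)
toy_x = torch.randn(n_pos, d)

print(torch.allclose(rope_complex(toy_x, toy_angles), rope_real(toy_x, toy_angles), atol=1e-6))  # True
```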

Code: 7 (4:57:26p)

torch.view_as_real?

Output: 454

Docstring:
view_as_real(input) -> Tensor

Returns a view of :attr:`input` as a real tensor. For an input complex tensor of
:attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new
real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2
represents the real and imaginary components of complex numbers.

.. warning::
    :func:`view_as_real` is only supported for tensors with ``complex dtypes``.

Args:
    input (Tensor): the input tensor.

Example::

    >>> x=torch.randn(4, dtype=torch.cfloat)
    >>> x
    tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)])
    >>> torch.view_as_real(x)
    tensor([[ 0.4737, -0.3839],
            [-0.2098, -0.6699],
            [ 0.3470, -0.9451],
            [-0.5174, -1.3136]])
Type:      builtin_function_or_method

Note: 21

torch.view_as_real goes back to real numbers, which adds a trailing dimension of size 2 (real and imaginary parts).

Code: 21 (4:57:26p)

torch.view_as_real(xq_ * freqs_cis).shape

Output: 43

torch.Size([1, 3, 2, 8, 2])

Note: 21

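Use flatten(3) to merge dimension 3 (the 8 pairs) and dimension 4 (real, imaginary) back into a single head_dim of 16. Dimensions are 0-indexed, so this flattens from the fourth dimension onward.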

Code: 27 (4:57:26p)

torch.view_as_real(xq_ * freqs_cis).flatten(3).shape

Output: 39

torch.Size([1, 3, 2, 16])
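As a sanity check of the reshape, view_as_complex, view_as_real, flatten(3) pipeline: without the multiplication by freqs_cis it should return the input unchanged, since every step is only a view or reshape. A sketch on a random tensor with the same shape as queries:

```python
import torch

q_demo = torch.randn(1, 3, 2, 16)  # (batch_size, context_len, num_heads, head_dim)

# Pack adjacent pairs into complex numbers: (1, 3, 2, 8)
q_c = torch.view_as_complex(q_demo.reshape(*q_demo.shape[:-1], -1, 2))

# Unpack to (1, 3, 2, 8, 2), then merge dims 3 and 4 back into head_dim: (1, 3, 2, 16)
q_back = torch.view_as_real(q_c).flatten(3)

print(q_back.shape, torch.equal(q_back, q_demo))  # torch.Size([1, 3, 2, 16]) True
```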

Code: 37 (4:57:26p)

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xq_out.shape

Output: 39

torch.Size([1, 3, 2, 16])

Code: 37 (4:57:26p)

xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
xk_out.shape

Output: 39

torch.Size([1, 3, 2, 16])

Note: 40

That's it. We have applied rotary embeddings to the queries and keys.

Let's run it all together. Here are the settings we used.

Code: 16 (4:57:26p)

batch_size, context_len, num_heads, head_dim

Output: 33

(1, 3, 2, 16)

Code: 30 (4:57:26p)

freq_cis = precompute_freqs_cis(head_dim, context_len)
freq_cis.shape

Output: 30

torch.Size([3, 8])

Code: 39 (4:57:26p)

xq, xk = apply_rotary_emb(queries, keys, freq_cis)
xq.shape, xk.shape

Output: 60

(torch.Size([1, 3, 2, 16]), torch.Size([1, 3, 2, 16]))

Code: 33 (4:57:26p)

assert torch.allclose(xq, xq_out) and torch.allclose(xk, xk_out)

Note: 3

They match.
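The introduction says RoPE encodes relative position. One way to check this, sketched below with a hypothetical `rotate` helper that uses the same complex formulation and frequency formula (theta = 10000, in float64 to keep rounding error small), is to rotate one query vector and one key vector to different absolute positions: the dot product only changes when the distance between the two positions changes.

```python
import torch

torch.manual_seed(0)
d = 16
freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2, dtype=torch.float64) / d))

def rotate(v, pos):
    # Rotate adjacent pairs of v by pos * freq (Llama-style layout) via complex multiplication
    v_ = torch.view_as_complex(v.reshape(-1, 2))
    rot = torch.polar(torch.ones_like(freq), pos * freq)
    return torch.view_as_real(v_ * rot).flatten()

q_vec = torch.randn(d, dtype=torch.float64)
k_vec = torch.randn(d, dtype=torch.float64)

s1 = rotate(q_vec, 5) @ rotate(k_vec, 2)      # query at 5, key at 2 (distance 3)
s2 = rotate(q_vec, 105) @ rotate(k_vec, 102)  # both shifted by 100 (distance still 3)
s3 = rotate(q_vec, 5) @ rotate(k_vec, 4)      # distance 1

print(s1.item(), s2.item(), s3.item())  # s1 and s2 agree up to rounding; s3 generally differs
```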

Note: 75

Raschka's implementation

From GPT to Llama

Note: 12

Comparing it to the Llama implementation.

Code: 327 (4:57:26p)

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

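    # Expand angles to match the head_dim (both options disabled here so the
    # result keeps Llama's (context_length, head_dim // 2) shape for comparison)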
    # angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)
    # angles = torch.repeat_interleave(angles, 2, dim=1)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin

Code: 16 (4:57:26p)

batch_size, context_len, num_heads, head_dim

Output: 33

(1, 3, 2, 16)

Code: 36 (4:57:27p)

cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
cos.shape, sin.shape

Output: 42

(torch.Size([3, 8]), torch.Size([3, 8]))

Note: 46

inv_freq and positions correspond to freqs and t in Llama's precompute_freqs_cis. To make the comparison clearer, we use a llama_ prefix for values computed the Llama way.

Code: 91 (4:57:27p)

inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
positions = torch.arange(context_len)
angles = positions[:, None] * inv_freq[None, :]
angles.shape

Output: 30

torch.Size([3, 8])

Code: 10 (4:57:27p)

theta_base = 10_000

Code: 72 (4:57:27p)

t, dim = torch.arange(context_len).float(), head_dim  # Llama's names for positions and head_dim
llama_freqs = torch.outer(t, 1.0 / (theta_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).float()
llama_freqs.shape

Output: 30

torch.Size([3, 8])

Note: 21

angles and llama_freqs are the same.

Code: 16 (4:57:27p)

assert torch.allclose(angles, llama_freqs)

Note: 102

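Instead of converting the angles into complex numbers, Raschka's implementation computes cos and sin directly. To cover the full head_dim, the angles then have to be expanded, and the two commented-out options do this differently: torch.cat([angles, angles], dim=1) pairs dimension i with dimension i + head_dim/2 (the x1/x2 split used by compute_rope below), while torch.repeat_interleave(angles, 2, dim=1) pairs adjacent dimensions 2i and 2i+1, which is the layout Llama's view_as_complex uses. Neither is the same as reshape_for_broadcast, which only inserts singleton dimensions for batch and heads. To compare with Llama's freqs_cis, we use repeat_interleave here.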
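The two expansion options in precompute_rope_params differ only in which dimensions end up sharing an angle. A toy sketch with four made-up angles for a single position (head_dim = 8):

```python
import torch

toy_angles = torch.tensor([[0.0, 1.0, 2.0, 3.0]])  # made-up angles: 1 position, 4 pairs (head_dim = 8)

half_split = torch.cat([toy_angles, toy_angles], dim=1)      # angle i reused at i and i + 4
interleaved = torch.repeat_interleave(toy_angles, 2, dim=1)  # angle i reused at 2i and 2i + 1

print(half_split)   # tensor([[0., 1., 2., 3., 0., 1., 2., 3.]])
print(interleaved)  # tensor([[0., 0., 1., 1., 2., 2., 3., 3.]])
```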

Code: 27 (4:57:27p)

angles = torch.repeat_interleave(angles, 2, dim=1)
angles.shape

Output: 30

torch.Size([3, 16])

Code: 28 (4:57:27p)

cos = torch.cos(angles)
sin = torch.sin(angles)
cos.shape, sin.shape

Output: 42

(torch.Size([3, 16]), torch.Size([3, 16]))

Note: 37

To compare them, we have to repeat llama_freqs_cis along the last dimension to match the shape of cos and sin.

Code: 81 (4:57:27p)

llama_freqs_cis = torch.polar(torch.ones_like(llama_freqs), llama_freqs)
llama_freqs_cis = torch.repeat_interleave(llama_freqs_cis, 2, dim=1)
llama_freqs_cis.shape

Output: 30

torch.Size([3, 16])

Code: 18 (4:57:27p)

torch.stack([cos, sin], dim=-1).shape

Output: 34

torch.Size([3, 16, 2])

Code: 25 (4:57:27p)

torch.view_as_complex(torch.stack([cos, sin], dim=-1)).shape

Output: 30

torch.Size([3, 16])

Code: 37 (4:57:27p)

assert torch.allclose(torch.view_as_complex(torch.stack([cos, sin], dim=-1)), llama_freqs_cis)

Note: 51

By converting cos and sin to complex numbers (cos θ + i·sin θ), we can see that they are the same as Llama's freqs_cis, which torch.polar builds from the same angles.

Next, let's look at how to apply the rotary embeddings.

Code: 331 (4:57:27p)

def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)

Note: 51

Raschka's compute_rope expects x with shape (batch_size, num_heads, context_len, head_dim), whereas Llama's apply_rotary_emb works on (batch_size, context_len, num_heads, head_dim). We need to transpose dimensions 1 and 2 to match.

Code: 37 (4:57:27p)

# queries and keys have shape: (batch_size, context_len, num_heads, head_dim)
queries.shape, keys.shape

Output: 60

(torch.Size([1, 3, 2, 16]), torch.Size([1, 3, 2, 16]))

Code: 45 (4:57:27p)

# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
cos.shape, sin.shape

Output: 42

(torch.Size([3, 8]), torch.Size([3, 8]))

Note: 7

Let's work on the queries first.

Code: 42 (4:57:27p)

# Convert to (batch_size, num_heads, context_len, head_dim)
x = queries.transpose(1, 2)
x.shape

Output: 39

torch.Size([1, 2, 3, 16])

Code: 22 (4:57:27p)

x1 = x[..., : head_dim // 2]
x1.shape

Output: 39

torch.Size([1, 2, 3, 8])

Code: 21 (4:57:27p)

x2 = x[..., head_dim // 2 :]
x2.shape

Output: 39

torch.Size([1, 2, 3, 8])

Code: 88 (4:57:27p)

cos = cos[:context_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, context_len, head_dim // 2)
sin = sin[:context_len, :].unsqueeze(0).unsqueeze(0)
cos.shape, sin.shape

Output: 60

(torch.Size([1, 1, 3, 8]), torch.Size([1, 1, 3, 8]))

Code: 28 (4:57:27p)

rotated = torch.cat((-x2, x1), dim=-1)
rotated.shape

Output: 39

torch.Size([1, 2, 3, 16])

Code: 31 (4:57:27p)

# cos/sin are half-width here, so rotate each pair (x1[i], x2[i]) by angle i explicitly
x_rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
x_rotated.shape

Output: 39

torch.Size([1, 2, 3, 16])

Code: 6 (4:57:27p)

xq_out.shape

Output: 39

torch.Size([1, 3, 2, 16])

Code: 28 (4:57:27p)

x_rotated.transpose(1, 2).shape

Output: 39

torch.Size([1, 3, 2, 16])

Code: 48 (4:57:28p)

assert torch.allclose(x_rotated.transpose(1, 2), xq_out)

Output: 117

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[154], line 1
----> 1 assert torch.allclose(x_rotated.transpose(1, 2), xq_out)

AssertionError: 
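Besides the rotation formula itself (each pair (a, b) must become (a·cos θ - b·sin θ, a·sin θ + b·cos θ)), the two implementations pair dimensions differently: compute_rope rotates dimension i together with i + head_dim/2, while Llama's apply_rotary_emb rotates dimensions 2i and 2i+1 together, so their outputs differ even on identical inputs. The sketch below uses hypothetical standalone helpers `rope_llama` and `rope_half_split`, assumes theta = 10000 for both, and checks that they agree once the channels are permuted (even channels first, then odd).

```python
import torch

torch.manual_seed(0)
n_pos, d = 3, 16
freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
ang = torch.outer(torch.arange(n_pos).float(), freq)  # (n_pos, d // 2)
c, s = torch.cos(ang), torch.sin(ang)

def rope_llama(x):
    # Pairs (2i, 2i+1), rotated by complex multiplication as in apply_rotary_emb
    x_ = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_ * torch.polar(torch.ones_like(ang), ang)).flatten(-2)

def rope_half_split(x):
    # Pairs (i, i + d/2), rotated with real arithmetic as in compute_rope
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)

x_demo = torch.randn(n_pos, d)
perm = torch.cat([torch.arange(0, d, 2), torch.arange(1, d, 2)])  # even channels first, then odd
inv_perm = torch.argsort(perm)

same_raw = torch.allclose(rope_half_split(x_demo), rope_llama(x_demo), atol=1e-6)
same_permuted = torch.allclose(rope_half_split(x_demo[..., perm])[..., inv_perm], rope_llama(x_demo), atol=1e-6)
print(same_raw, same_permuted)  # expected: False True
```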

Code: 16 (4:57:38p)

xq_.shape, cos.shape, x1.shape

Output: 81

(torch.Size([1, 3, 2, 8]), torch.Size([1, 1, 3, 8]), torch.Size([1, 2, 3, 8]))

Code: 19 (4:57:38p)

(cos * x1).transpose(1,2).shape

Output: 39

torch.Size([1, 3, 2, 8])

Code: 18 (4:57:38p)

(cos * x1).transpose(1,2)

Output: 507

tensor([[[[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196,
           -0.3792],
          [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753,
           -0.2010]],

         [[-0.7159,  0.1696, -2.1231,  1.0518, -0.3885, -0.9343, -0.4991,
           -1.0867],
          [ 0.4738,  1.5417, -1.4706,  1.1325, -1.2202,  1.3138,  1.0533,
            0.1388]],

         [[-0.2142,  0.8015, -0.2536, -1.0804, -0.0444,  1.6235, -2.3229,
            1.0878],
          [-0.1062, -0.4433,  0.9842,  0.8256, -0.3947,  0.4892, -0.2168,
           -1.7472]]]])

Code: 28 (4:57:38p)

xq.float().reshape(*xq.shape[:-1], -1, 2).shape

Output: 43

torch.Size([1, 3, 2, 8, 2])

Code: 43 (4:57:38p)

xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 0].transpose(1, 2)

Output: 504

tensor([[[[ 0.3374, -0.3035,  0.3486, -0.2196,  0.7671,  0.6984,  0.1794,
            0.4954],
          [-0.8660, -2.3552, -0.2933, -0.4645,  0.8650,  0.6272,  0.0984,
            0.2664],
          [-1.1178,  0.4312, -0.3660, -2.3870,  0.6576, -0.9482, -0.1529,
            0.4412]],

         [[-0.0770, -0.1690,  1.5810,  1.2753,  0.4965,  0.9666, -1.1589,
           -0.6315],
          [-0.8912, -1.7570, -1.3454,  1.0484,  2.2553, -0.2833, -0.6588,
            0.1838],
          [ 0.3935,  0.3210, -0.4841, -0.1060, -1.5807,  0.9077, -0.5937,
            0.6238]]]])

Code: 9 (4:57:38p)

x1.shape, x1

Output: 529

(torch.Size([1, 2, 3, 8]),
 tensor([[[[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196,
            -0.3792],
           [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991,
            -1.0867],
           [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,
             1.0878]],
 
          [[-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753,
            -0.2010],
           [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,
             0.1388],
           [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168,
            -1.7472]]]]))

Code: 64 (4:57:38p)

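# Repeat each angle along the last dim (dim=1 is the heads axis) to get Llama's adjacent-pair layout
cos_reshaped = torch.repeat_interleave(cos, 2, dim=-1)
sin_reshaped = torch.repeat_interleave(sin, 2, dim=-1)
cos_reshaped.shape

Output: 39

torch.Size([1, 1, 3, 16])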