RoPE (Rotary Positional Embedding) encodes relative positions. GPT-2 used absolute positional embeddings, which were added to the token embeddings. RoPE is different: it is applied to the queries and keys of the attention mechanism.
RoPE uses rotation matrices to encode positional information. The rotation matrices are applied to the query and key vectors before computing the attention scores. This allows the model to learn relative positional information between tokens.
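To see why rotations give relative positions, here is a minimal sketch (not the Llama code; plain 2-D vectors and a made-up angle step): the dot product between a rotated query and key depends only on the offset between their positions.

import torch

def rot(v, angle):
    # rotate a 2-D vector by the given angle
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack([v[0] * c - v[1] * s, v[0] * s + v[1] * c])

q, k = torch.tensor([1.0, 0.0]), torch.tensor([0.5, 0.5])
theta = torch.tensor(0.3)  # arbitrary angle step for illustration
for m, n in [(2, 1), (7, 6), (10, 5), (9, 4)]:
    score = rot(q, m * theta) @ rot(k, n * theta)
    print(m - n, round(score.item(), 4))
# position pairs with the same offset m - n produce identical scores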
Let's start by implementing the RoPE algorithm.
import torch

# Small dimensions used throughout this notebook (head_dim=16 matches the discussion below)
batch_size, context_len, num_heads, head_dim = 1, 3, 2, 16

torch.manual_seed(123)
queries = torch.randn(batch_size, context_len, num_heads, head_dim)
keys = torch.randn(batch_size, context_len, num_heads, head_dim)
queries.shape, keys.shape
Below is the code from the Llama 2 implementation. Let's try to understand it.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
dim = head_dim
end = context_len
theta = 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs
The [: (dim // 2)] slice ensures we only take the first half of the dimensions.
Looking at your example:
- dim = 16 (head_dim)
- torch.arange(0, dim, 2) gives [0, 2, 4, 6, 8, 10, 12, 14] - that's 8 elements
- dim // 2 = 8
- [: (dim // 2)] takes the first 8 elements, which is all of them in this case

The reason for this slice is that RoPE works in pairs - each pair of dimensions gets rotated together. So for a 16-dimensional vector, you have 8 pairs: (0,1), (2,3), (4,5), etc.
The torch.arange(0, dim, 2) generates the even indices [0, 2, 4, 6, 8, 10, 12, 14], and the slice ensures you get exactly dim//2 frequency values - one for each rotation pair. Each frequency will be applied to rotate one pair of dimensions.
The division by dim creates normalized frequency positions that range from 0 to just under 1.
This normalization ensures that when you raise theta (10000.0) to these normalized powers, you get a geometric progression of frequencies that decays exponentially. The lowest dimension pairs rotate quickly (good for distinguishing adjacent tokens), while higher dimension pairs rotate slowly (good for capturing long-range dependencies).
Without the /dim normalization, the frequency pattern would change dramatically based on the head size, making the model behavior inconsistent across different architectures.
We raise theta to these normalized powers to create a geometric sequence of wavelengths that get exponentially longer.
Taking the reciprocal 1.0 / (theta ** ...) turns these growing wavelengths into frequencies that decay exponentially. This creates a spectrum where each dimension pair operates at a distinctly different timescale, allowing the model to encode positional information across multiple ranges simultaneously.
Can you create a simple plot with (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))?
import matplotlib.pyplot as plt
powers = (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_powers = (theta ** powers)
plt.figure(figsize=(10, 6))
plt.plot(range(len(theta_powers)), theta_powers, 'bo-')
plt.xlabel('Dimension pair index')
plt.ylabel('theta^(power)')
plt.title('Exponential growth of theta powers')
plt.yscale('log')
plt.grid(True)
plt.show()
# Same plot again, on a linear scale this time
powers = (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_powers = (theta ** powers)
plt.figure(figsize=(10, 6))
plt.plot(range(len(theta_powers)), theta_powers, 'bo-')
plt.xlabel('Dimension pair index')
plt.ylabel('theta^(power)')
plt.title('Exponential growth of theta powers')
# plt.yscale('log')
plt.grid(True)
plt.show()
We take the inverse to convert wavelengths into frequencies.
The theta ** powers gives us exponentially growing wavelengths (how far apart the rotations are). But for RoPE, we need frequencies (how fast the rotations happen).
Since frequency = 1/wavelength, taking the inverse 1.0 / (theta ** powers) gives us the desired frequency spectrum, where different dimension pairs rotate at different speeds to encode positional information across multiple scales.
frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
plt.figure(figsize=(10, 6))
plt.plot(range(len(frequencies)), frequencies, 'ro-')
plt.xlabel('Dimension pair index')
plt.ylabel('Frequency')
plt.title('Exponential decay of frequencies')
plt.yscale('log')
plt.grid(True)
plt.show()
# Same plot again, on a linear scale this time
frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
plt.figure(figsize=(10, 6))
plt.plot(range(len(frequencies)), frequencies, 'ro-')
plt.xlabel('Dimension pair index')
plt.ylabel('Frequency')
plt.title('Exponential decay of frequencies')
# plt.yscale('log')
plt.grid(True)
plt.show()
What do you mean 3.1623e-04 is a high frequency? I thought it was a small number, so it should be a low frequency. Because it is equivalent to 0.00031623.
You're absolutely right! I mixed that up. Let me correct that: 1.0 is the high frequency (fast rotation), and 3.1623e-04 is the low frequency (slow rotation).

Think of it like a clock: the fast-moving second hand distinguishes nearby moments in time, while the slow-moving hour hand distinguishes moments that are far apart.

In RoPE, the high frequency dimensions help the model distinguish between adjacent tokens, while the low frequency dimensions help distinguish between tokens that are far apart in the sequence.
Thanks for catching that error!
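To ground this in numbers, here is a quick check (a sketch using the 1-D freqs vector computed above) of how many positions each extreme pair needs for one full 2π rotation:

import math

print(2 * math.pi / freqs[0].item())   # fastest pair: ~6.3 positions per full rotation
print(2 * math.pi / freqs[-1].item())  # slowest pair: ~19870 positions per full rotation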
Each head dimension gets a different frequency, which means a different rate of rotation. So, the higher the dimension, the slower the rotation. Is that right?
Exactly right! Looking at the code:
torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
This creates values like [0.0000, 0.1250, 0.2500, 0.3750, ...] that increase with dimension index.
When we do theta ** (increasing_values), we get larger denominators for higher dimensions, resulting in smaller frequencies.
So yes: higher dimension pairs rotate more slowly. This creates a nice hierarchy where early dimensions capture fine-grained positional differences (adjacent tokens) and later dimensions capture coarse-grained positional differences (distant tokens).
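Continuing through the function body, the next step scales each frequency by every position index (these are the lines from precompute_freqs_cis, run here so the output below is reproducible):

t = torch.arange(end)                  # position indices 0, 1, ..., end-1
freqs = torch.outer(t, freqs).float()  # rotation angle for each (position, dim pair)
freqs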
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
         1.0000e-03, 3.1623e-04],
        [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03,
         2.0000e-03, 6.3246e-04]])

plt.figure(figsize=(12, 8))
for i in range(freqs.shape[1]):
    plt.plot(freqs[:, i], label=f'dim_pair_{i}', marker='o')
plt.xlabel('Position')
plt.ylabel('Frequency * Position')
plt.title('Frequency scaling across positions')
plt.legend()
plt.grid(True)
plt.show()
This plot shows how each dimension pair's rotation angle changes linearly with position.
The key insight is that freqs now contains the actual rotation angles that will be applied: each row is a position, each column a dimension pair. Because the angle is frequency × position, each pair's angle grows linearly with position. This creates the relative positional encoding - tokens at different positions will have different rotation angles, and the difference between any two positions depends on both their distance and the dimension pair's frequency.
plt.figure(figsize=(12, 8))
for i in range(freqs.shape[0]):
    plt.plot(freqs[i, :], label=f'pos_{i}', marker='o')
plt.xlabel('Dimension pair')
plt.ylabel('Frequency * Position')
plt.title('Rotation angles by dimension pair at each position')
plt.legend()
plt.grid(True)
plt.show()
By using torch.outer, are we scaling the frequencies by the position? The farther away, the more the frequencies are scaled?
Exactly! torch.outer(t, freqs) multiplies each frequency by each position index. Looking at your output: row 0 (position 0) is all zeros, row 1 is the base frequencies, and row 2 is the base frequencies doubled.

So yes, the farther the position, the more the frequencies are scaled up. This creates the relative positional encoding effect - tokens that are further apart will have larger rotation differences between them, making their dot products in attention naturally encode distance information.
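The final line of precompute_freqs_cis turns these angles into unit-magnitude complex numbers (run here to produce the output below):

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # cos(angle) + i*sin(angle)
freqs_cis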
tensor([[ 1.0000+0.0000e+00j, 1.0000+0.0000e+00j, 1.0000+0.0000e+00j,
1.0000+0.0000e+00j, 1.0000+0.0000e+00j, 1.0000+0.0000e+00j,
1.0000+0.0000e+00j, 1.0000+0.0000e+00j],
[ 0.5403+8.4147e-01j, 0.9504+3.1098e-01j, 0.9950+9.9833e-02j,
0.9995+3.1618e-02j, 0.9999+9.9998e-03j, 1.0000+3.1623e-03j,
1.0000+1.0000e-03j, 1.0000+3.1623e-04j],
[-0.4161+9.0930e-01j, 0.8066+5.9113e-01j, 0.9801+1.9867e-01j,
0.9980+6.3203e-02j, 0.9998+1.9999e-02j, 1.0000+6.3245e-03j,
1.0000+2.0000e-03j,  1.0000+6.3246e-04j]])

Docstring:
polar(abs, angle, *, out=None) -> Tensor
Constructs a complex tensor whose elements are Cartesian coordinates
corresponding to the polar coordinates with absolute value :attr:`abs` and angle
:attr:`angle`.
.. math::
\text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j
.. note::
`torch.polar` is similar to
`std::polar <https://en.cppreference.com/w/cpp/numeric/complex/polar>`_
and does not compute the polar decomposition
of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do.
The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is
infinite.
Args:
abs (Tensor): The absolute value the complex tensor. Must be float or double.
angle (Tensor): The angle of the complex tensor. Must be same dtype as
:attr:`abs`.
Keyword args:
out (Tensor): If the inputs are ``torch.float32``, must be
``torch.complex64``. If the inputs are ``torch.float64``, must be
``torch.complex128``.
Example::
>>> import numpy as np
>>> abs = torch.tensor([1, 2], dtype=torch.float64)
>>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64)
>>> z = torch.polar(abs, angle)
>>> z
tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128)
Type: builtin_function_or_method

torch.polar converts the frequencies into complex numbers representing rotations.
It takes an absolute value (torch.ones_like(freqs), so every rotation has unit magnitude) and an angle (the freqs tensor of rotation angles). This creates complex numbers of the form cos(angle) + i*sin(angle), which are rotation operators in the complex plane.
So freqs_cis contains the actual rotation amounts for each position and dimension. When we multiply query/key vectors by these complex numbers, we're rotating them by the corresponding angles.
The name "cis" comes from the mathematical notation cis(θ) = cos(θ) + i*sin(θ).
I can understand why we are adjusting frequencies based on positions, but why are we applying it to the head dimensions as well?
Great question! We apply different frequencies to different head dimensions because we want each dimension pair to encode positional information at a different "scale" or "resolution."

This creates a hierarchy where fast-rotating pairs resolve fine-grained, local position differences and slow-rotating pairs resolve coarse-grained, global ones. By having this spectrum of frequencies across dimensions, the model can capture both local and global positional relationships simultaneously. It's like having multiple "clocks" running at different speeds - some for seconds, some for minutes, some for hours.

Without this variety, all dimensions would encode position at the same scale, which would be less expressive for capturing the full range of positional relationships in a sequence.
Okay. And we are dividing the head dimension by two and calculating half for cosine and half for sine, right?
Yes, exactly! RoPE works on pairs of dimensions, and each pair gets rotated together. The rotation is applied as:
[x0, x1] → [x0*cos(θ) - x1*sin(θ), x0*sin(θ) + x1*cos(θ)]
This is why we see torch.arange(0, dim, 2) - we're stepping by 2 to get frequencies for each pair, and [: (dim // 2)] ensures we get exactly half the dimensions worth of frequencies.
The complex representation with torch.polar elegantly handles this 2D rotation math for us.
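Here is a small check (a sketch, using plain Python complex numbers) that complex multiplication reproduces exactly the pair-rotation formula above:

import math

x0, x1, th = 0.8, -0.3, 0.5
manual = (x0 * math.cos(th) - x1 * math.sin(th),
          x0 * math.sin(th) + x1 * math.cos(th))
z = complex(x0, x1) * complex(math.cos(th), math.sin(th))
print(manual)            # (0.8459..., 0.1203...)
print((z.real, z.imag))  # identical values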
from typing import Tuple

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
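The function calls reshape_for_broadcast, which isn't shown above. For completeness, here is that helper as it appears in the Llama 2 repository; it reshapes freqs_cis from (seq_len, head_dim // 2) to (1, seq_len, 1, head_dim // 2) so it broadcasts over the batch and head dimensions:

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # keep seq_len (dim 1) and the pair dim (last); set all others to 1
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)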
It's called "rotary positional embedding" because it's embedding (encoding) positional information into the query and key vectors.
Traditional positional embeddings add position information to token embeddings before attention. RoPE is different - it embeds positional information directly into the attention mechanism by rotating the query and key vectors based on their positions.
The rotation angles encode where each token is positioned in the sequence. When two tokens at different positions have their vectors rotated by different amounts, their dot product (attention score) naturally reflects their relative distance.
So "embedding" here means "encoding positional information into" rather than the typical meaning of converting discrete tokens to vectors. It's embedding position into the geometry of the vector space through rotations.
Displayed above is a plot of absolute positional embeddings in gpt2. What can you tell me from this?
Looking at this plot of GPT-2's absolute positional embeddings, I can tell you several key insights:
The plot shows three lines representing the first three principal components of GPT-2's positional embeddings. Research has found that GPT-2's 768-dimensional positional vectors mostly live in a low-rank linear subspace, with a three-dimensional subspace explaining 90% of the variance ^1 ^2 ^3.
When these positional embeddings are plotted in 3D PCA space, they form a helix structure ^4 ^5 ^6. The three lines in your plot are the projections of this helix onto each of the three main PCA dimensions.
Key observations from the plot:
Smooth, continuous patterns: The embeddings don't jump randomly but follow smooth trajectories, suggesting the model learned structured positional representations.
Different frequencies: Each PCA dimension oscillates at different rates - some change quickly (high frequency) while others change more gradually (low frequency).
Helical structure: The dark blue end represents early positions in the sequence, while the red end represents later positions ^7 ^8 ^9, forming a spiral pattern in 3D space.
Efficient encoding: This can be thought of as the neural network reserving 3 dimensions for positional information and 765 for semantic content ^10.
This helical structure is quite different from RoPE, which explicitly uses sinusoidal functions. GPT-2 learned this structure through training, discovering an efficient way to encode position that allows for both local and global positional relationships.
Displayed above is a plot of absolute positional embeddings in gpt2. Can I make a similar plot with RoPE?
Yes, you can create a similar visualization for RoPE! Unlike GPT-2's absolute positional embeddings that are added to tokens, RoPE creates rotational patterns that you can visualize.
Here are a few ways to plot RoPE:
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# Plot 1: Rotation angles over positions for different dimensions
for i in range(min(8, freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(freqs_cis[:, i]), label=f'dim_{i}', marker='o')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position')
axes[0,0].legend()
axes[0,0].grid(True)
# Plot 2: Real and imaginary parts
axes[0,1].plot(freqs_cis[:, 0].real, label='Real (cos)', marker='o')
axes[0,1].plot(freqs_cis[:, 0].imag, label='Imag (sin)', marker='s')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)
# Plot 3: 2D rotation visualization
theta = torch.angle(freqs_cis[:, 0])
x_pos = torch.cos(theta)
y_pos = torch.sin(theta)
axes[1,0].plot(x_pos, y_pos, 'o-')
for i, pos in enumerate(range(len(theta))):
    axes[1,0].annotate(f'pos_{i}', (x_pos[i], y_pos[i]))
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')
# Plot 4: Heatmap of all rotation angles
angles = torch.angle(freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])
plt.tight_layout()
plt.show()
Let's try the plotting with a more practical example. How about we use the same sequence length and embedding size as GPT-2? Create variables different from the ones we tried so we do not overwrite them.
GPT-2 has 768-dimensional hidden states and 12 attention heads ^1 ^2 ^3. The head dimension is calculated as:
head_dim = embedding_dim / num_heads = 768 / 12 = 64
So each attention head gets 64 dimensions to work with. This is a design choice that divides the full embedding dimension evenly across all attention heads, allowing each head to focus on different aspects of the relationships between tokens.
gpt2_context_len = 1024
gpt2_head_dim = 64
gpt2_freqs_cis = precompute_freqs_cis(gpt2_head_dim, gpt2_context_len)
gpt2_freqs_cis.shape
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for i in range(min(8, gpt2_freqs_cis.shape[1])):
    axes[0,0].plot(torch.angle(gpt2_freqs_cis[:, i]), label=f'dim_{i}')
axes[0,0].set_xlabel('Position')
axes[0,0].set_ylabel('Rotation angle (radians)')
axes[0,0].set_title('RoPE: Rotation angles by position (GPT-2 scale)')
axes[0,0].legend()
axes[0,0].grid(True)
axes[0,1].plot(gpt2_freqs_cis[:, 0].real, label='Real (cos)')
axes[0,1].plot(gpt2_freqs_cis[:, 0].imag, label='Imag (sin)')
axes[0,1].set_xlabel('Position')
axes[0,1].set_ylabel('Value')
axes[0,1].set_title('RoPE: First dimension pair (cos/sin)')
axes[0,1].legend()
axes[0,1].grid(True)
theta = torch.angle(gpt2_freqs_cis[:100, 0])
x_pos = torch.cos(theta)
y_pos = torch.sin(theta)
axes[1,0].plot(x_pos, y_pos, 'o-', markersize=2)
axes[1,0].set_xlabel('cos(θ)')
axes[1,0].set_ylabel('sin(θ)')
axes[1,0].set_title('RoPE: 2D rotation pattern (first 100 positions)')
axes[1,0].grid(True)
axes[1,0].set_aspect('equal')
angles = torch.angle(gpt2_freqs_cis)
im = axes[1,1].imshow(angles.T, aspect='auto', cmap='hsv')
axes[1,1].set_xlabel('Position')
axes[1,1].set_ylabel('Dimension pair')
axes[1,1].set_title('RoPE: All rotation angles')
plt.colorbar(im, ax=axes[1,1])
plt.tight_layout()
plt.show()
Looking at these plots, here's what each one shows:
Top Left - Rotation angles by position: Each colored line shows how the rotation angle increases for different dimension pairs as position increases. Higher frequency dimensions (like dim_0) rotate much faster and create dense oscillations, while lower frequency dimensions rotate more slowly.
Top Right - First dimension pair (cos/sin): This shows the cosine and sine values for the fastest-rotating dimension. You can see the rapid oscillation - this dimension completes many full rotations across the 1024 positions.
Bottom Left - 2D rotation pattern: This shows the first 100 positions plotted as points on a unit circle. Each point represents where that position "lands" after rotation. The spiral pattern shows how consecutive positions are rotated by small increments, creating a smooth progression around the circle.
Bottom Right - All rotation angles heatmap: This is the most informative. The x-axis is position (0-1023), y-axis is dimension pair (0-31). The colors represent rotation angles. You can see rapid color cycling in the low dimension pairs (fast rotation) and broad, slowly changing bands in the high dimension pairs (slow rotation).

Think of it like having 32 different "clocks" - some tick every second (high freq), others tick every hour (low freq). Together they can uniquely identify any position in the sequence.
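To make the clock analogy concrete, here is a quick calculation (a sketch using the same formula at GPT-2 scale) of how many positions each of the 32 dimension pairs needs for one full rotation:

import math

gpt2_inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
periods = 2 * math.pi / gpt2_inv_freq  # positions per full 2*pi rotation
print(periods[:4])   # fastest pairs: a full turn every ~6-15 positions
print(periods[-4:])  # slowest pairs: a full turn every tens of thousands of positions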
Thanks for the explanation, but it is very hard to visualize what top plots are doing. Can you help?
Let me create a couple of focused plots to help you understand what the top plots are showing:
import matplotlib.pyplot as plt
import numpy as np

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
dim_pairs = [0, 1, 2, 3]
colors = ['red', 'blue', 'green', 'orange']

# Left: rotation angle over the first 100 positions for a few dimension pairs
for i, dim_pair in enumerate(dim_pairs):
    angles = torch.angle(gpt2_freqs_cis[:100, dim_pair])
    ax1.plot(angles, label=f'dim_pair_{dim_pair}', color=colors[i], linewidth=2)
ax1.set_xlabel('Position')
ax1.set_ylabel('Rotation angle (radians)')
ax1.set_title('How rotation angles grow with position')
ax1.legend()
ax1.grid(True)

# Right: the base frequency (radians advanced per position) of each pair
for i, dim_pair in enumerate(dim_pairs):
    freq = 1.0 / (10000 ** (torch.arange(0, 64, 2)[dim_pair].float() / 64))
    ax2.bar(i, freq, color=colors[i], label=f'dim_pair_{dim_pair}')
ax2.set_xlabel('Dimension pair')
ax2.set_ylabel('Frequency (radians per position)')
ax2.set_title('Base frequencies for each dimension pair')
ax2.set_yscale('log')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

print("Key insights:")
print("- Left plot: Shows how rotation angles accumulate linearly with position")
print("- Right plot: Shows base frequencies - higher dims rotate much slower")
print("- dim_pair_0 advances ~1 radian every position")
print("- dim_pair_3 advances ~0.42 radians every position")
# Walk through apply_rotary_emb step by step, using our queries/keys as xq/xk
xq, xk = queries, keys
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_.shape
Docstring:
view_as_complex(input) -> Tensor
Returns a view of :attr:`input` as a complex tensor. For an input complex
tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a
new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last
dimension of the input tensor is expected to represent the real and imaginary
components of complex numbers.
.. warning::
:func:`view_as_complex` is only supported for tensors with
:class:`torch.dtype` ``torch.float64`` and ``torch.float32``. The input is
expected to have the last dimension of :attr:`size` 2. In addition, the
tensor must have a `stride` of 1 for its last dimension. The strides of all
other dimensions must be even numbers.
Args:
input (Tensor): the input tensor.
Example::
>>> x=torch.randn(4, 2)
>>> x
tensor([[ 1.6116, -0.5772],
[-1.4606, -0.9120],
[ 0.0786, -1.7497],
[-0.6561, -1.6623]])
>>> torch.view_as_complex(x)
tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)])
Type: builtin_function_or_method

xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xk_.shape
The multiplication xq_ * freqs_cis is complex number multiplication, which performs the rotation.
When you multiply two complex numbers:
(a + bi) * (cos(θ) + i*sin(θ)) = (a*cos(θ) - b*sin(θ)) + i*(a*sin(θ) + b*cos(θ))
This is exactly the 2D rotation formula! So each complex number in xq_ gets rotated by the corresponding angle in freqs_cis.
The broadcasting works because:
- xq_ has shape [1, 3, 2, 8] (batch, seq_len, heads, head_dim_pairs)
- freqs_cis has shape [1, 3, 1, 8] (batch, seq_len, 1, head_dim_pairs)

The 1 in the heads dimension broadcasts across all 2 heads, so each head gets the same rotational encoding for the same position and dimension pair.
Each position gets its own rotation angles, and each dimension pair gets its own frequency, but all heads share the same positional encoding.
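A quick shape check of that broadcast, using the tensors from this notebook and the reshape_for_broadcast helper shown earlier:

fc = reshape_for_broadcast(freqs_cis, xq_)
print(xq_.shape, fc.shape)  # torch.Size([1, 3, 2, 8]) torch.Size([1, 3, 1, 8])
print((xq_ * fc).shape)     # broadcasts back to torch.Size([1, 3, 2, 8])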
You can get even indexed elements using slicing with a step of 2, starting from index 0:
tensor[..., ::2]
This selects every 2nd element starting from index 0 (so indices 0, 2, 4, 6, etc.) along the last dimension.
For other dimensions, you can specify the dimension:
tensor[::2] # even indices along first dimension
tensor[:, ::2] # even indices along second dimension
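A tiny demonstration:

x = torch.arange(10)
print(x[::2])   # tensor([0, 2, 4, 6, 8])  - even indices
print(x[1::2])  # tensor([1, 3, 5, 7, 9])  - odd indices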
I have two following tensors:
tensor([[-1.2295, -0.2887, 0.9687, -0.6554, 1.3035, 0.1295, 0.1323, -0.8838],
[ 0.5342, 0.5383, 0.8112, -0.4330, 0.5602, -0.9264, -0.1925, -1.5116],
[ 1.4182, 0.3159, -1.0963, 0.1530, 0.6830, -0.2398, 0.0732, 0.6918],
[ 0.6567, 0.6113, -2.1482, 1.1618, 0.4768, -0.5982, 2.5378, 1.1881]])
tensor([[-0.6163, 1.1063, -1.7300, -0.1095, -0.7442, -0.8126, 0.2128, 0.2878],
[-0.7322, -1.1714, 2.1736, 0.6815, 0.6597, -0.5675, -0.1627, -2.4516],
[ 0.7278, -0.1520, 1.5938, 0.4412, -0.9143, -0.7768, 0.5236, 0.2211],
[ 0.4192, -0.1826, 0.6669, -1.7079, 1.9377, -0.7297, -1.2567, 0.1236]])
I want to combine them like this:
tensor([[-1.2295, -0.6163, -0.2887, 1.1063, 0.9687, -1.7300, -0.6554, -0.1095,
1.3035, -0.7442, 0.1295, -0.8126, 0.1323, 0.2128, -0.8838, 0.2878],
[ 0.5342, -0.7322, 0.5383, -1.1714, 0.8112, 2.1736, -0.4330, 0.6815,
0.5602, 0.6597, -0.9264, -0.5675, -0.1925, -0.1627, -1.5116, -2.4516],
[ 1.4182, 0.7278, 0.3159, -0.1520, -1.0963, 1.5938, 0.1530, 0.4412,
0.6830, -0.9143, -0.2398, -0.7768, 0.0732, 0.5236, 0.6918, 0.2211],
[ 0.6567, 0.4192, 0.6113, -0.1826, -2.1482, 0.6669, 1.1618, -1.7079,
0.4768, 1.9377, -0.5982, -0.7297, 2.5378, -1.2567, 1.1881, 0.1236]])
How do I do this?
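One way to interleave them (a sketch; a and b stand for the two tensors above): stack along a new last dimension, then flatten the resulting pairs. The complex-view route below is the same trick apply_rotary_emb uses, in reverse.

a = torch.randn(4, 8)
b = torch.randn(4, 8)

# stack pairs (a[i,j], b[i,j]) along a new last dim, then flatten: shape (4, 16)
interleaved = torch.stack((a, b), dim=-1).flatten(start_dim=-2)

# equivalent: treat a as real parts and b as imaginary parts
interleaved2 = torch.view_as_real(torch.complex(a, b)).flatten(start_dim=-2)
print(torch.equal(interleaved, interleaved2))  # True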
Docstring:
view_as_real(input) -> Tensor
Returns a view of :attr:`input` as a real tensor. For an input complex tensor of
:attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new
real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2
represents the real and imaginary components of complex numbers.
.. warning::
:func:`view_as_real` is only supported for tensors with ``complex dtypes``.
Args:
input (Tensor): the input tensor.
Example::
>>> x=torch.randn(4, dtype=torch.cfloat)
>>> x
tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)])
>>> torch.view_as_real(x)
tensor([[ 0.4737, -0.3839],
[-0.2098, -0.6699],
[ 0.3470, -0.9451],
[-0.5174, -1.3136]])
Type: builtin_function_or_method

From GPT to Llama
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    # angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)
    # angles = torch.repeat_interleave(angles, 2, dim=1)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
cos.shape, sin.shape
inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))
positions = torch.arange(context_len)
angles = positions[:, None] * inv_freq[None, :]
angles.shape
theta_base = 10_000
llama_freqs = torch.outer(t, 1.0 / (theta_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).float()
llama_freqs.shape
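The two formulations compute the same angles, which we can verify directly:

# positions[:, None] * inv_freq[None, :] and torch.outer(t, inv_freq) are the same computation
torch.allclose(angles, llama_freqs)  # True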
Instead of transforming angles/frequencies into complex space, Raschka's implementation directly calculates cos and sin. Raschka's implementation uses torch.repeat_interleave to expand the angles, which plays the role of reshape_for_broadcast in the Llama code. However, to compare with Llama's cis, we will skip the repeating part to make it easier to compare.
llama_freqs_cis = torch.polar(torch.ones_like(llama_freqs), llama_freqs)
llama_freqs_cis = torch.repeat_interleave(llama_freqs_cis, 2, dim=1)
llama_freqs_cis.shape
def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)
# queries and keys have shape: (batch_size, context_len, num_heads, head_dim)
queries.shape, keys.shape
# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)
cos.shape, sin.shape
# Convert to (batch_size, num_heads, context_len, head_dim)
x = queries.transpose(1, 2)
x.shape
cos = cos[:context_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, context_len, head_dim)
sin = sin[:context_len, :].unsqueeze(0).unsqueeze(0)
cos.shape, sin.shape
assert torch.allclose(torch.cat([(x1 * cos), (x2 * sin)], dim=-1).transpose(1,2), xq_out)
tensor([[[[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196,
-0.3792],
[-0.0770, -1.0205, -0.1690, 0.9178, 1.5810, 1.3010, 1.2753,
-0.2010]],
[[-0.7159, 0.1696, -2.1231, 1.0518, -0.3885, -0.9343, -0.4991,
-1.0867],
[ 0.4738, 1.5417, -1.4706, 1.1325, -1.2202, 1.3138, 1.0533,
0.1388]],
[[-0.2142, 0.8015, -0.2536, -1.0804, -0.0444, 1.6235, -2.3229,
1.0878],
[-0.1062, -0.4433, 0.9842, 0.8256, -0.3947, 0.4892, -0.2168,
-1.7472]]]])

tensor([[[[ 0.3374, -0.3035,  0.3486, -0.2196,  0.7671,  0.6984,  0.1794,
0.4954],
[-0.8660, -2.3552, -0.2933, -0.4645, 0.8650, 0.6272, 0.0984,
0.2664],
[-1.1178, 0.4312, -0.3660, -2.3870, 0.6576, -0.9482, -0.1529,
0.4412]],
[[-0.0770, -0.1690, 1.5810, 1.2753, 0.4965, 0.9666, -1.1589,
-0.6315],
[-0.8912, -1.7570, -1.3454, 1.0484, 2.2553, -0.2833, -0.6588,
0.1838],
[ 0.3935, 0.3210, -0.4841, -0.1060, -1.5807, 0.9077, -0.5937,
0.6238]]]])

(torch.Size([1, 2, 3, 8]),
tensor([[[[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196,
-0.3792],
[-1.3250, 0.1784, -2.1338, 1.0524, -0.3885, -0.9343, -0.4991,
-1.0867],
[ 0.5146, 0.9938, -0.2587, -1.0826, -0.0444, 1.6236, -2.3229,
1.0878]],
[[-0.0770, -1.0205, -0.1690, 0.9178, 1.5810, 1.3010, 1.2753,
-0.2010],
[ 0.8768, 1.6221, -1.4779, 1.1331, -1.2203, 1.3139, 1.0533,
0.1388],
[ 0.2553, -0.5496, 1.0042, 0.8272, -0.3948, 0.4892, -0.2168,
-1.7472]]]]))