Part II: Self-Attentive Sequential Recommendation

Transformers for Next-Item Prediction

Introduction

This notebook implements SASRec (Self-Attentive Sequential Recommendation), a transformer-based model for sequential recommendation.

Key Paper:

  • Kang & McAuley (2018) introduced SASRec, applying self-attention to sequential recommendation

What we’ll build:

  1. Sequential data preprocessing
  2. SASRec model architecture
  3. Training loop
  4. Evaluation and visualization
Show code
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
import torch.nn as nn
from IPython.display import Markdown, display
from plotnine import *
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from torch.utils.data import DataLoader

from recsys_genai.data_utils import SequenceDataset, get_user_sequences, load_movielens
from recsys_genai.metrics import ndcg_at_k, recall_at_k
from recsys_genai.sasrec import SASRec, train_sasrec
Show code
theme_set(
    theme_minimal()
    + theme(
        plot_title=element_text(weight="bold", size=14),
        axis_title=element_text(size=12),
        figure_size=(8, 6),
    )
)

torch.manual_seed(42)
np.random.seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cuda
Show code
class SASRecSettings(BaseSettings):
    """Configuration for SASRec training.

    Can be overridden with environment variables using SASREC_ prefix.
    Example: SASREC_NUM_EPOCHS=10
    """

    model_config = SettingsConfigDict(env_prefix="SASREC_")

    # Data preprocessing
    min_user_ratings: int = Field(default=20, description="Minimum number of ratings per user")
    rating_threshold: float = Field(default=4.0, description="Minimum rating for positive feedback")
    min_sequence_length: int = Field(default=10, description="Minimum sequence length")
    max_sequence_length: int = Field(default=50, description="Maximum sequence length")

    # Model architecture
    num_blocks: int = Field(default=2, description="Number of transformer blocks")
    num_heads: int = Field(default=2, description="Number of attention heads")
    hidden_size: int = Field(default=64, description="Hidden dimension size")
    dropout: float = Field(default=0.2, description="Dropout rate")

    # Training parameters
    batch_size: int = Field(default=128, description="Training batch size")
    num_epochs: int = Field(default=5, description="Number of training epochs")
    learning_rate: float = Field(default=0.001, description="Learning rate for optimizer")

    # Evaluation
    eval_k: int = Field(default=10, description="K for Recall@K and NDCG@K metrics")
    eval_sample_size: int = Field(default=1000, description="Number of test users to evaluate")


settings = SASRecSettings()

Prepare Sequential Data

Load and Filter Data

Show code
movies, ratings, tags, links = load_movielens("../data")

# Filter to users with enough ratings
user_counts = ratings.group_by("user_id").agg(pl.len().alias("count"))
active_users = user_counts.filter(pl.col("count") >= settings.min_user_ratings)["user_id"].to_list()

# Filter ratings
filtered_ratings = ratings.filter(
    pl.col("user_id").is_in(active_users)
    & (pl.col("rating") >= settings.rating_threshold)  # High ratings as implicit positive feedback
)

load_stats = f"""
**Loaded Data:**

- Active users (≥{settings.min_user_ratings} ratings): {len(active_users):,}
- Positive interactions: {len(filtered_ratings):,}
"""
display(Markdown(load_stats))

Loaded Data:

  • Active users (≥20 ratings): 204,443
  • Positive interactions: 16,182,634

Extract User Sequences

Show code
# Get sequences ordered by timestamp
sequences = get_user_sequences(
    filtered_ratings,
    max_len=settings.max_sequence_length,
    min_len=settings.min_sequence_length,
)

example_user = list(sequences.keys())[0]
sequence_stats = f"""
**User Sequences:**

- Number of sequences: {len(sequences)}
- Example sequence (user {example_user}): {sequences[example_user][:10]}
"""
display(Markdown(sequence_stats))

User Sequences:

  • Number of sequences: 194848
  • Example sequence (user 214082): [780, 6, 733, 104, 788, 1061, 344, 231, 296, 592]

Create Item Vocabulary

Show code
# Map movie IDs to indices (0 is reserved for padding)
all_items = set()
for seq in sequences.values():
    all_items.update(seq)

item_to_idx = {item_id: idx + 1 for idx, item_id in enumerate(sorted(all_items))}
idx_to_item = {idx: item_id for item_id, idx in item_to_idx.items()}
idx_to_item[0] = 0  # Padding

# Remap sequences to indices
indexed_sequences = {
    uid: [item_to_idx[item_id] for item_id in seq] for uid, seq in sequences.items()
}

vocab_stats = f"""
**Item Vocabulary:**

- Vocabulary size: {len(item_to_idx)} items
"""
display(Markdown(vocab_stats))

Item Vocabulary:

  • Vocabulary size: 37716 items

Train/Test Split

Show code
# Split sequences into train/test
# Test: last item of each sequence
# Train: all but last item

train_sequences = {}
test_ground_truth = {}

for uid, seq in indexed_sequences.items():
    if len(seq) < 2:
        continue

    train_sequences[uid] = seq[:-1]  # All but last
    test_ground_truth[uid] = {seq[-1]}  # Last item

split_stats = f"""
**Train/Test Split:**

- Train sequences: {len(train_sequences)}
- Test sequences: {len(test_ground_truth)}
"""
display(Markdown(split_stats))

Train/Test Split:

  • Train sequences: 194848
  • Test sequences: 194848

Create DataLoader

Show code
train_dataset = SequenceDataset(train_sequences, max_len=settings.max_sequence_length)
train_loader = DataLoader(train_dataset, batch_size=settings.batch_size, shuffle=True, num_workers=0)

# Inspect a batch
sample_batch = next(iter(train_loader))

dataloader_stats = f"""
**DataLoader:**

- Training batches: {len(train_loader)}
- Batch sequence shape: {sample_batch[0].shape}
- Batch target shape: {sample_batch[1].shape}
"""
display(Markdown(dataloader_stats))

DataLoader:

  • Training batches: 1523
  • Batch sequence shape: torch.Size([128, 50])
  • Batch target shape: torch.Size([128])

Build SASRec Model

Model Architecture

Show code
num_items = len(item_to_idx)

model = SASRec(
    num_items=num_items,
    max_len=settings.max_sequence_length,
    num_blocks=settings.num_blocks,
    num_heads=settings.num_heads,
    hidden_size=settings.hidden_size,
    dropout=settings.dropout
)

model_stats = f"""
**Model Architecture:**

- Model parameters: {sum(p.numel() for p in model.parameters()):,}
"""
display(Markdown(model_stats))
print(model)

Model Architecture:

  • Model parameters: 2,567,168
SASRec(
  (item_emb): Embedding(37717, 64, padding_idx=0)
  (pos_emb): Embedding(50, 64)
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.2, inplace=False)
        (3): Linear(in_features=256, out_features=64, bias=True)
        (4): Dropout(p=0.2, inplace=False)
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)

Train SASRec

Objective: Predict the next item given a sequence

\[ \mathcal{L} = -\log P(i_{t+1} \mid i_1, i_2, \ldots, i_t) \]
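The `train_sasrec` helper encapsulates this objective. Conceptually, each training step reduces to a cross-entropy loss over the item vocabulary — a minimal sketch, assuming the model emits logits of shape `(batch, num_items + 1)` for the next position (names here are illustrative, not the actual `recsys_genai` internals):

```python
import torch
import torch.nn.functional as F


def next_item_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over the item vocabulary: -log P(i_{t+1} | i_1, ..., i_t).

    logits: (batch, num_items + 1) next-item scores (index 0 = padding)
    targets: (batch,) index of the true next item
    """
    # ignore_index=0 skips padded targets so they contribute no gradient
    return F.cross_entropy(logits, targets, ignore_index=0)


# Toy example: 2 sequences, a 5-item vocabulary (+ padding index 0)
logits = torch.randn(2, 6)
targets = torch.tensor([3, 1])
loss = next_item_loss(logits, targets)
```

Maximizing \(P(i_{t+1} \mid \cdot)\) is exactly minimizing this negative log-likelihood, which is why the training logs below report a single scalar loss per epoch.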

Show code
# Train model
losses = train_sasrec(
    model=model,
    dataloader=train_loader,
    num_epochs=settings.num_epochs,
    lr=settings.learning_rate,
    device=device,
)
Epoch 1/20, Loss: 10.6246
Epoch 2/20, Loss: 8.0213
Epoch 3/20, Loss: 7.8421
Epoch 4/20, Loss: 7.6782
Epoch 5/20, Loss: 7.5332
Epoch 6/20, Loss: 7.4257
Epoch 7/20, Loss: 7.3573
Epoch 8/20, Loss: 7.3358
Epoch 9/20, Loss: 7.2718
Epoch 10/20, Loss: 7.2367
Epoch 11/20, Loss: 7.1992
Epoch 12/20, Loss: 7.1786
Epoch 13/20, Loss: 7.1220
Epoch 14/20, Loss: 7.0757
Epoch 15/20, Loss: 7.0378
Epoch 16/20, Loss: 6.9982
Epoch 17/20, Loss: 6.9525
Epoch 18/20, Loss: 6.9173
Epoch 19/20, Loss: 6.8819
Epoch 20/20, Loss: 6.8481

Plot Training Loss

Show code
losses_df = pl.DataFrame({"epoch": list(range(1, len(losses) + 1)), "loss": losses})

(
    ggplot(losses_df, aes(x="epoch", y="loss"))
    + geom_line(size=1)
    + geom_point(size=3)
    + labs(title="SASRec Training Loss", x="Epoch", y="Loss")
)

Generate Recommendations

Predict for Sample User

Let’s see what SASRec recommends for a randomly sampled user!

Show code
# Sample a random user from active_users who has training sequences
available_users = [uid for uid in active_users if uid in train_sequences]
sample_user_id = np.random.choice(available_users)

sample_seq = train_sequences[sample_user_id]

# Pad/truncate to max_len
if len(sample_seq) > settings.max_sequence_length:
    sample_seq = sample_seq[-settings.max_sequence_length:]
else:
    sample_seq = [0] * (settings.max_sequence_length - len(sample_seq)) + sample_seq

# Convert to tensor
seq_tensor = torch.tensor([sample_seq], dtype=torch.long).to(device)

# Predict
top_items, top_scores = model.predict_next(seq_tensor, k=10)

# Map back to movie IDs
recommendations = [idx_to_item[idx.item()] for idx in top_items[0]]

# Display
display(Markdown(f"**User {sample_user_id}'s Recent History:**"))
recent_items = [idx_to_item[idx] for idx in sample_seq[-5:] if idx > 0]
recent_movies = movies.filter(pl.col("movie_id").is_in(recent_items))
display(recent_movies.select(["title", "genres"]))

display(Markdown(f"**SASRec Recommendations for User {sample_user_id}:**"))
rec_movies = movies.filter(pl.col("movie_id").is_in(recommendations))
display(rec_movies.select(["title", "genres"]))

User 59146’s Recent History:

shape: (5, 2)
title genres
str list[str]
"Shawshank Redemption, The (199… ["Crime", "Drama"]
"Fargo (1996)" ["Comedy", "Crime", … "Thriller"]
"Dr. Strangelove or: How I Lear… ["Comedy", "War"]
"Psycho (1960)" ["Crime", "Horror"]
"Chinatown (1974)" ["Crime", "Film-Noir", … "Thriller"]

SASRec Recommendations for User 59146:

shape: (10, 2)
title genres
str list[str]
"Silence of the Lambs, The (199… ["Crime", "Horror", "Thriller"]
"Fargo (1996)" ["Comedy", "Crime", … "Thriller"]
"Dr. Strangelove or: How I Lear… ["Comedy", "War"]
"Godfather, The (1972)" ["Crime", "Drama"]
"Rear Window (1954)" ["Mystery", "Thriller"]
"Casablanca (1942)" ["Drama", "Romance"]
"One Flew Over the Cuckoo's Nes… ["Drama"]
"Godfather: Part II, The (1974)" ["Crime", "Drama"]
"Graduate, The (1967)" ["Comedy", "Drama", "Romance"]
"American Beauty (1999)" ["Drama", "Romance"]

Evaluate Model

Evaluate on test set using Recall@K and NDCG@K.

Show code
model.eval()

predictions = {}
sample_test_users = list(test_ground_truth.keys())[:settings.eval_sample_size]  # First N users, for speed

with torch.no_grad():
    for uid in sample_test_users:
        if uid not in train_sequences:
            continue

        seq = train_sequences[uid]

        # Pad/truncate
        if len(seq) > settings.max_sequence_length:
            seq = seq[-settings.max_sequence_length:]
        else:
            seq = [0] * (settings.max_sequence_length - len(seq)) + seq

        seq_tensor = torch.tensor([seq], dtype=torch.long).to(device)

        # Predict top-20
        top_items, _ = model.predict_next(seq_tensor, k=20)
        predictions[uid] = top_items[0].cpu().tolist()

# Calculate metrics
recall_scores = []
ndcg_scores = []

for uid in predictions:
    if uid in test_ground_truth:
        preds = predictions[uid]
        targets = test_ground_truth[uid]

        recall_scores.append(recall_at_k(preds, targets, k=settings.eval_k))
        ndcg_scores.append(ndcg_at_k(preds, targets, k=settings.eval_k))

eval_results = f"""
**Evaluation Results:**

- Recall@{settings.eval_k}: {np.mean(recall_scores):.4f}
- NDCG@{settings.eval_k}: {np.mean(ndcg_scores):.4f}
"""
display(Markdown(eval_results))

Evaluation Results:

  • Recall@10: 0.0870
  • NDCG@10: 0.0487

Results Interpretation

Recall@10 measures: “What fraction of relevant items appear in top-10?”

NDCG@10 measures: “How well are relevant items ranked?” (position-aware)
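The imported `recall_at_k` and `ndcg_at_k` could plausibly be implemented as below — a sketch for intuition only (the actual `recsys_genai.metrics` code may differ), using the same shapes as above: `preds` is a ranked list, `targets` a set of relevant items.

```python
import math


def recall_at_k(preds: list, targets: set, k: int) -> float:
    """Fraction of relevant items that appear in the top-k predictions."""
    hits = len(set(preds[:k]) & targets)
    return hits / len(targets) if targets else 0.0


def ndcg_at_k(preds: list, targets: set, k: int) -> float:
    """Position-aware: hits near the top of the list earn more credit."""
    dcg = sum(
        1.0 / math.log2(rank + 2)  # rank 0 -> 1/log2(2) = 1.0
        for rank, item in enumerate(preds[:k])
        if item in targets
    )
    # Ideal DCG: all relevant items ranked first
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(min(len(targets), k)))
    return dcg / idcg if idcg > 0 else 0.0
```

With a single relevant item (our leave-one-out setup), Recall@K is simply hit-or-miss, while NDCG@K discounts a hit at rank \(r\) by \(1/\log_2(r+2)\).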

Note

These metrics are typically lower than MF/EASE on very sparse data, but SASRec excels when:

  • Users have rich interaction histories
  • Temporal patterns matter
  • Short-term intent dominates

Analyze Attention Patterns

Let’s visualize what the model learned!

Sample User’s Movie Sequence

First, let’s look at the viewing history of a randomly sampled user.

Show code
from recsys_genai.notebook_utils import tmdb_images

# Sample a random user for attention visualization
viz_user_id = np.random.choice(available_users)
viz_seq = train_sequences[viz_user_id]

# Load posters dataset
posters_raw = pl.read_parquet("../data/shared/posters.parquet")

# Join to get movie info with posters (via links for tmdb_id)
viz_movie_ids = [idx_to_item[idx] for idx in viz_seq if idx > 0]
viz_movies = (
    pl.DataFrame({"movie_id": viz_movie_ids})
    .join(movies, on="movie_id", maintain_order="left")
    .join(links.select(["movie_id", "tmdb_id"]), on="movie_id", maintain_order="left")
    .join(posters_raw, on="tmdb_id", maintain_order="left")
)

# Display user sequence info
display(
    Markdown(f"""
**Sampled User ID:** {viz_user_id}

**Sequence Length:** {len(viz_seq)} items

**Unique Movies:** {len(viz_movie_ids)} (excluding padding)
""")
)

recent_movies = viz_movies.tail(20)

# Display posters for movies that have them
poster_paths = [p for p in recent_movies["poster_path"].to_list() if p is not None]
display(Markdown("**Movie Posters:**"))
display(tmdb_images(poster_paths))

Sampled User ID: 124051

Sequence Length: 49 items

Unique Movies: 49 (excluding padding)

Movie Posters:

Extract Attention Weights

Now let’s extract attention weights from the model for this user’s sequence.

Show code
if len(viz_seq) > 20:
    viz_seq = viz_seq[-20:]  # Use last 20 for visualization

seq_padded = [0] * (20 - len(viz_seq)) + viz_seq
seq_tensor = torch.tensor([seq_padded], dtype=torch.long).to(device)

model.eval()
with torch.no_grad():
    _, attn_weights_list = model.forward_with_attention(seq_tensor)

# Get movie titles for labels
movie_ids = [idx_to_item[idx] for idx in seq_padded]
valid_ids = [mid for mid in movie_ids if mid > 0]

if valid_ids:
    movie_titles = movies.filter(pl.col("movie_id").is_in(valid_ids))
    title_map = {row["movie_id"]: row["title"][:20] for row in movie_titles.to_dicts()}
    labels = [title_map.get(mid, "PAD") for mid in movie_ids]
else:
    labels = ["PAD" if mid == 0 else str(mid) for mid in movie_ids]

# Plot attention weights from all blocks and all heads
for block_idx, block_attn in enumerate(attn_weights_list):
    # block_attn shape: (batch_size, num_heads, seq_len, seq_len)
    num_heads = block_attn.shape[1]

    display(Markdown(f"---\n## Block {block_idx + 1}\n---"))

    for head_idx in range(num_heads):
        attn_matrix = block_attn[0, head_idx].cpu().numpy()  # First sample, specific head

        # Prepare data for plotnine
        attn_data = []
        for i in range(attn_matrix.shape[0]):
            for j in range(attn_matrix.shape[1]):
                attn_data.append({"query_pos": i, "key_pos": j, "attention": attn_matrix[i, j]})

        attn_df = pl.DataFrame(attn_data)

        # Create heatmap with plotnine
        plot = (
            ggplot(attn_df, aes(x="key_pos", y="query_pos", fill="attention"))
            + geom_tile()
            + scale_fill_gradient(low="#440154", high="#FDE724", name="Attention\nWeight")
            + scale_x_continuous(breaks=list(range(20)), labels=labels)
            + scale_y_reverse(breaks=list(range(20)), labels=labels)
            + labs(
                title=f"Block {block_idx + 1}, Head {head_idx + 1} - User {viz_user_id}",
                x="Key Position",
                y="Query Position",
            )
            + theme(
                axis_text_x=element_text(rotation=90, hjust=1, size=8),
                axis_text_y=element_text(size=8),
            )
        )

        display(plot)

Interpretation:

  • Brighter cells = stronger attention
  • Each row shows what a query position attends to
  • Causal mask enforces no future peeking
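The causal mask itself is just a boolean upper-triangular matrix passed to the attention layer. A minimal sketch in PyTorch (the actual `SASRec` implementation may construct it differently), using the `nn.MultiheadAttention` convention that `True` means "not allowed to attend":

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask where True = 'do not attend' (a future position).

    Suitable as attn_mask for nn.MultiheadAttention: query position i
    may only attend to key positions j <= i.
    """
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)


mask = causal_mask(4)
# Row 0 can only see position 0; row 3 sees everything up to itself:
# mask[0] -> [False, True, True, True]
# mask[3] -> [False, False, False, False]
```

This is why every heatmap above is zero above the diagonal: those query/key pairs are masked out before the softmax.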

Key Takeaways

  1. SASRec models temporal dynamics - Order matters!
  2. Self-attention learns which past items are relevant for prediction
  3. Causal masking ensures no future leakage during training

Connection to LLMs: SASRec uses the same decoder-only transformer architecture as GPT!

  • GPT predicts next word
  • SASRec predicts next item
  • Both use causal self-attention

References

Kang, W.-C., & McAuley, J. (2018). Self-attentive sequential recommendation. 2018 IEEE International Conference on Data Mining (ICDM), 197–206. https://doi.org/10.1109/ICDM.2018.00035