Part IV.a: Slate Generation with REINFORCE

Encoder-Decoder Transformer for Diversity-Aware Recommendations

Introduction

In this notebook, we’ll build a slate generation model using an encoder-decoder transformer trained with REINFORCE (Williams, 1992), inspired by Bello et al. (2018).

Key Concepts:

  1. Problem: Generate diverse slates of 5 items similar to a seed item
  2. Objective: Maximize relevance to seed while encouraging diversity
  3. Method: Encoder-decoder transformer with policy gradient training (REINFORCE)
  4. Similarities: Use EASE item-item similarities instead of cosine similarities on embeddings

Reward Function:

\[ R(S \mid \text{seed}) = \frac{1}{|S|} \sum_{i \in S} \text{sim}_{\text{EASE}}(\text{seed}, i) - \lambda \cdot \frac{1}{|S|(|S|-1)} \sum_{i,j \in S, i \neq j} \text{sim}_{\text{EASE}}(i, j) \]

Where:

  • First term: average EASE similarity to seed (relevance)
  • Second term: average pairwise EASE similarity (diversity penalty)
  • \(\lambda\): diversity weight (e.g., 0.5)
  • \(\text{sim}_{\text{EASE}}(i, j) = B_{ij}\): the learned EASE item-item weight from matrix \(B\). \(B\) is not symmetric, so the argument order matters; relevance is read from the seed's row, \(B_{\text{seed}, i}\), as in the code below
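As a sanity check of the formula, here is a minimal NumPy sketch that evaluates \(R\) on a made-up 4x4 matrix. The matrix values and the helper name `slate_reward` are illustrative assumptions, not part of the notebook; relevance reads the seed's row of \(B\), which is what `compute_reward` does later on.

```python
import numpy as np

# Made-up 4x4 EASE-style weight matrix: zero diagonal, not symmetric.
B = np.array([
    [0.0, 0.4, 0.2, 0.1],
    [0.3, 0.0, 0.5, 0.0],
    [0.1, 0.6, 0.0, 0.2],
    [0.2, 0.0, 0.1, 0.0],
])


def slate_reward(B, seed, slate, lam=0.5):
    """Hypothetical helper: relevance minus lam * mean off-diagonal pairwise similarity."""
    slate = np.asarray(slate)
    k = len(slate)
    relevance = B[seed, slate].mean()        # B[seed, i] for each slate item
    pairwise = B[np.ix_(slate, slate)]       # k x k block of B
    off_diag = pairwise.sum() - np.trace(pairwise)
    diversity_penalty = off_diag / (k * (k - 1))
    return relevance - lam * diversity_penalty


r = slate_reward(B, seed=0, slate=[1, 2, 3])
# relevance = 0.7 / 3, penalty = 1.4 / 6, so r = 0.2333 - 0.5 * 0.2333
print(round(r, 4))  # prints 0.1167
```

Here relevance and average pairwise similarity happen to be equal, so with \(\lambda = 0.5\) the penalty removes exactly half of the relevance.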

Prerequisites: Run 01b_foundations.qmd to generate the EASE similarity matrix!

from pathlib import Path

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Markdown, display
from plotnine import *
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from recsys_genai.data_utils import load_movielens
from recsys_genai.notebook_utils import tmdb_images
theme_set(
    theme_minimal()
    + theme(
        plot_title=element_text(weight="bold", size=14),
        axis_title=element_text(size=12),
        figure_size=(8, 6),
    )
)

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda

Load Data and Embeddings

We’ll load the MovieLens data, the movie posters, and the EASE item-item similarity matrix computed in Part I.b (01b_foundations.qmd).

# Load MovieLens data
movies, ratings, tags, links = load_movielens("../data")
posters_raw = pl.read_parquet("../data/shared/posters.parquet")

# Create a movie_id to poster mapping by joining through links
posters = (
    links.select(["movie_id", "tmdb_id"])
    .join(posters_raw, on="tmdb_id", how="inner")
    .select(["movie_id", "poster_path"])
)

data_stats = f"""
**Data Loaded:**

- Movies: {len(movies):,}
- Ratings: {len(ratings):,}
"""
display(Markdown(data_stats))

Data Loaded:

  • Movies: 86,537
  • Ratings: 33,832,162
# Load EASE matrix from Part I.b
generated_data_dir = Path("..") / "data" / "generated"
ease_file = generated_data_dir / "ease_matrix.npz"

# Check if EASE matrix exists
if not ease_file.exists():
    raise FileNotFoundError(
        f"EASE matrix not found. Please run notebook 01b_foundations.qmd first!\n"
        f"Expected file:\n"
        f"  - {ease_file}\n"
        f"\n"
        f"Make sure to run 01b_foundations.qmd to generate the EASE matrix."
    )

# Load EASE matrix and movie IDs
data = np.load(ease_file)
ease_B = data["B"]  # Item-item similarity matrix
item_ids = data["movie_ids"]

# Create item_to_idx mapping
item_to_idx = {int(movie_id): idx for idx, movie_id in enumerate(item_ids)}

# Convert EASE matrix to PyTorch tensor
ease_B_tensor = torch.FloatTensor(ease_B).to(device)

ease_stats = f"""
**Loaded EASE matrix:**

- Shape: {ease_B.shape}
- Number of items: {len(item_to_idx):,}

This is the item-item similarity matrix learned by EASE.
"""
display(Markdown(ease_stats))

Loaded EASE matrix:

  • Shape: (5000, 5000)
  • Number of items: 5,000

This is the item-item similarity matrix learned by EASE.
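For reference, EASE (Steck, 2019) has a closed-form solution: with \(P = (X^\top X + \lambda_{\text{EASE}} I)^{-1}\), the weights are \(B_{ij} = -P_{ij}/P_{jj}\) off the diagonal and \(B_{ii} = 0\). The sketch below assumes 01b_foundations.qmd follows this standard formulation (its regularization strength is not shown in this notebook); `ease_weights` and the toy interaction matrix are hypothetical. It also shows two properties used later: the diagonal is zero, and \(B\) is generally not symmetric because each column is scaled by a different \(P_{jj}\).

```python
import numpy as np


def ease_weights(X, lam=500.0):
    """Hypothetical helper: closed-form EASE, B = I - P diag(1/diag(P))."""
    G = X.T @ X + lam * np.eye(X.shape[1])  # regularized item-item Gram matrix
    P = np.linalg.inv(G)
    B = -P / np.diag(P)                     # divide column j by P[j, j]
    np.fill_diagonal(B, 0.0)                # zero-diagonal constraint
    return B


# Toy binary user-item matrix: 40 users, 6 items (made-up data).
X = (np.random.default_rng(0).random((40, 6)) < 0.3).astype(float)
B = ease_weights(X, lam=10.0)
print(B.shape, np.allclose(np.diag(B), 0.0))
```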

Slate Generation Model

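Now we’ll implement an encoder-decoder transformer for slate generation. The encoder reads the seed together with its candidates; the decoder produces one slate position at a time, and each candidate's score is the dot product between the decoder output and that candidate's encoding, a pointer-network-style selection in the spirit of Seq2Slate.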
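A single selection step reduces to a dot product between the decoder output for the current position and each encoded candidate, followed by masking and a softmax; `forward` (via `torch.bmm`) and `generate_slate` do this in batched PyTorch. Below is a minimal NumPy sketch of one step, with random vectors standing in for encoder and decoder outputs; `pointer_step` is a hypothetical helper.

```python
import numpy as np


def pointer_step(decoder_state, candidate_encodings, already_selected, temperature=1.0):
    """Hypothetical helper: score candidates by dot product, mask picks, softmax."""
    scores = candidate_encodings @ decoder_state       # [num_candidates]
    scores = scores / temperature
    scores[list(already_selected)] = -np.inf           # never pick an item twice
    scores = scores - scores.max()                     # numerical stability
    probs = np.exp(scores)
    return probs / probs.sum()


rng = np.random.default_rng(0)
cands = rng.normal(size=(6, 4))  # 6 encoded candidates, dim 4 (stand-ins)
h = rng.normal(size=4)           # decoder output for the current position (stand-in)
p = pointer_step(h, cands, already_selected={1, 4})
print(p.round(3))
```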

class SlateGeneratorTransformer(nn.Module):
    """Encoder-Decoder Transformer for Slate Generation.

    Encoder: Process seed item + candidate items (using trainable embeddings)
    Decoder: Autoregressively generate slate positions conditioned on seed
    """

    def __init__(
        self,
        num_items: int,
        embedding_dim: int = 64,
        num_heads: int = 4,
        num_encoder_layers: int = 2,
        num_decoder_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        max_slate_size: int = 5,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.max_slate_size = max_slate_size

        # Trainable item embeddings (replacing content embeddings)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        nn.init.normal_(self.item_embedding.weight, mean=0.0, std=0.1)

        # Seed embedding to mark the seed item in encoder input
        # This gets added to the seed item embedding to distinguish it from candidates
        self.seed_embedding = nn.Parameter(torch.zeros(embedding_dim))
        nn.init.normal_(self.seed_embedding, mean=0.0, std=0.1)

        # Positional encoding for decoder
        self.pos_embedding = nn.Embedding(max_slate_size, embedding_dim)

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Output projection (not used in forward(): candidate scores come from
        # dot products between decoder outputs and encoded candidates)
        self.output_proj = nn.Linear(embedding_dim, 1)

    def forward(self, seed_indices, candidate_indices, slate_indices=None, slate_mask=None):
        """
        Args:
            seed_indices: [batch_size] - seed item indices
            candidate_indices: [batch_size, num_candidates] - item indices
            slate_indices: [batch_size, slate_size] - partially generated slate item indices
            slate_mask: [batch_size, slate_size] - mask for autoregressive generation

        Returns:
            logits: [batch_size, slate_size, num_candidates] - scores for each position
        """
        batch_size = candidate_indices.shape[0]
        num_candidates = candidate_indices.shape[1]

        # Get embeddings for seed and candidates
        seed_embs = self.item_embedding(seed_indices).unsqueeze(1)  # [batch, 1, dim]
        candidate_embs = self.item_embedding(candidate_indices)  # [batch, num_cand, dim]

        # Add seed embedding to distinguish seed from candidates
        seed_embs = seed_embs + self.seed_embedding.unsqueeze(0).unsqueeze(0)  # [batch, 1, dim]

        # Concatenate seed with candidates for encoder input
        # This conditions the model on what we're recommending for
        encoder_input = torch.cat([seed_embs, candidate_embs], dim=1)  # [batch, 1+num_cand, dim]

        # Encode seed + candidates
        encoded = self.encoder(encoder_input)  # [batch, 1+num_cand, dim]

        # Separate seed encoding from candidate encodings
        seed_encoded = encoded[:, 0:1, :]  # [batch, 1, dim]
        candidates_encoded = encoded[:, 1:, :]  # [batch, num_cand, dim]

        if slate_indices is None:
            # Initialize with seed embedding for first position
            slate_size = 1
            slate_embs = seed_encoded  # Use seed as initial context
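        else:
            slate_size = slate_indices.shape[1]
            # Embeddings of the items selected so far. The seed token is not prepended,
            # so once t > 0 position 0 holds the first selected item rather than the seed.
            slate_embs = self.item_embedding(slate_indices)  # [batch, slate_size, dim]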

        # Add positional encoding to slate
        positions = torch.arange(slate_size, device=candidate_indices.device)
        slate_embs = slate_embs + self.pos_embedding(positions).unsqueeze(0)

        # Decode (using full encoded sequence as memory, including seed)
        decoded = self.decoder(
            tgt=slate_embs,
            memory=encoded,  # Full context including seed
            tgt_mask=slate_mask,
        )  # [batch, slate_size, dim]

        # Project to candidate space (not including seed in output)
        # We'll compute attention scores to candidates only
        scores = torch.bmm(
            decoded, candidates_encoded.transpose(1, 2)
        )  # [batch, slate_size, num_cand]

        return scores

    def generate_slate(self, seed_indices, candidate_indices, slate_size=5, temperature=1.0):
        """Generate a slate autoregressively.

        Args:
            seed_indices: [batch_size] - seed item indices
            candidate_indices: [batch_size, num_candidates] - item indices of candidates
            slate_size: number of items to generate
            temperature: sampling temperature

        Returns:
            selected_indices: [batch_size, slate_size] - indices into candidates
            log_probs: [batch_size, slate_size] - log probabilities of selections
        """
        batch_size = candidate_indices.shape[0]
        num_candidates = candidate_indices.shape[1]

        selected_indices = []
        log_probs = []

        slate_item_indices_list = []

        for t in range(slate_size):
            if t == 0:
                slate_item_indices = None
            else:
                slate_item_indices = torch.stack(slate_item_indices_list, dim=1)  # [batch, t]

            # Score all candidates for the next position (the seed is encoded with them)
            scores = self.forward(
                seed_indices, candidate_indices, slate_item_indices
            )  # [batch, max(t, 1), num_cand]
            logits = scores[:, -1, :]  # [batch, num_cand] - scores for position t

            # Mask already selected items so no candidate appears twice in a slate
            if t > 0:
                already_selected = torch.stack(selected_indices, dim=1)  # [batch, t]
                mask = torch.zeros_like(logits).scatter(1, already_selected, -1e9)
                logits = logits + mask

            # Sample from the temperature-scaled categorical distribution
            dist = torch.distributions.Categorical(logits=logits / temperature)
            selected = dist.sample()  # [batch]

            selected_indices.append(selected)
            log_probs.append(dist.log_prob(selected))

            # Get item indices of selected items for next step
            selected_item_indices = candidate_indices[torch.arange(batch_size), selected]  # [batch]
            slate_item_indices_list.append(selected_item_indices)

        selected_indices = torch.stack(selected_indices, dim=1)  # [batch, slate_size]
        log_probs = torch.stack(log_probs, dim=1)  # [batch, slate_size]

        return selected_indices, log_probs


# Initialize model
num_items = len(item_to_idx)
model = SlateGeneratorTransformer(
    num_items=num_items,
    embedding_dim=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    max_slate_size=5,
).to(device)

model_stats = f"""
**Model Architecture:**

- Total parameters: {sum(p.numel() for p in model.parameters()):,}
"""
display(Markdown(model_stats))
print(model)

Model Architecture:

  • Total parameters: 553,921
SlateGeneratorTransformer(
  (item_embedding): Embedding(5000, 64)
  (pos_embedding): Embedding(5, 64)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_proj): Linear(in_features=64, out_features=1, bias=True)
)
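Because `generate_slate` samples positions one at a time with earlier picks masked out, the probability of a whole slate is the product of the per-step probabilities, so its log-probability is a sum over positions (the training loop later computes it with `log_probs.sum(dim=1)`). The NumPy sketch below uses fixed scores for simplicity; in the model the scores are recomputed by the decoder at every step. `sample_slate` is a hypothetical helper.

```python
import numpy as np


def sample_slate(scores, k, rng):
    """Hypothetical helper: sample k distinct indices; return them and the summed log-prob."""
    scores = np.asarray(scores, dtype=float)
    chosen, log_prob = [], 0.0
    for _ in range(k):
        masked = scores.copy()
        masked[chosen] = -np.inf                  # exclude items already in the slate
        probs = np.exp(masked - masked.max())
        probs /= probs.sum()
        i = int(rng.choice(len(probs), p=probs))
        chosen.append(i)
        log_prob += np.log(probs[i])              # log pi(S) = sum_t log pi(s_t | s_<t)
    return chosen, log_prob


slate, lp = sample_slate(np.zeros(4), 4, np.random.default_rng(0))
print(slate, lp)
```

With uniform scores over 4 items and k = 4, every ordering has probability 1/4 · 1/3 · 1/2 · 1 = 1/24, so `lp` equals -log 24 whichever order is drawn.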

Reward Function

The reward balances relevance to the seed item and diversity within the slate.
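The `compute_reward` function below loops over the batch in Python. The same reward can be computed in one shot with advanced indexing; here is a NumPy sketch (both function names are hypothetical) checked against a loop that mirrors the original. The indexing carries over to PyTorch tensors, except that `np.trace(..., axis1=1, axis2=2)` would become `torch.diagonal(pairwise, dim1=1, dim2=2).sum(-1)`.

```python
import numpy as np


def compute_reward_vectorized(slate, seed, B, lambda_diversity=0.5):
    """Hypothetical batched reward: slate [batch, k] ints, seed [batch] ints."""
    k = slate.shape[1]
    relevance = B[seed[:, None], slate].mean(axis=1)           # [batch], B[seed, i]
    pairwise = B[slate[:, :, None], slate[:, None, :]]         # [batch, k, k]
    off_diag = pairwise.sum(axis=(1, 2)) - np.trace(pairwise, axis1=1, axis2=2)
    return relevance - lambda_diversity * off_diag / (k * (k - 1))


def compute_reward_loop(slate, seed, B, lambda_diversity=0.5):
    """Per-example loop mirroring compute_reward, for comparison."""
    out = []
    for s, items in zip(seed, slate):
        k = len(items)
        rel = B[s, items].mean()
        pw = B[np.ix_(items, items)] * (1 - np.eye(k))         # mask slate diagonal
        out.append(rel - lambda_diversity * pw.sum() / (k * (k - 1)))
    return np.array(out)


rng = np.random.default_rng(0)
B = rng.normal(size=(50, 50))
slate = rng.integers(0, 50, size=(8, 5))
seed = rng.integers(0, 50, size=8)
print(compute_reward_vectorized(slate, seed, B).round(4))
```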

def compute_reward(slate_indices, seed_idx, ease_B, lambda_diversity=0.5):
    """Compute reward for a slate given a seed item using EASE similarities.

    Args:
        slate_indices: [batch_size, slate_size] - item indices in slate
        seed_idx: [batch_size] - seed item indices
        ease_B: [num_items, num_items] - EASE similarity matrix
        lambda_diversity: weight for diversity penalty

    Returns:
        rewards: [batch_size] - scalar reward for each slate
    """
    batch_size, slate_size = slate_indices.shape

    # Relevance: average EASE similarity to seed
    # Get similarities between each slate item and seed
    relevance_scores = []
    for b in range(batch_size):
        seed_item = seed_idx[b]
        slate_items = slate_indices[b]
        # Get EASE similarities from seed to each slate item
        similarities = ease_B[seed_item, slate_items]  # [slate_size]
        relevance_scores.append(similarities.mean())

    avg_relevance = torch.stack(relevance_scores)  # [batch]

    # Diversity: average pairwise EASE similarity (lower is more diverse)
    diversity_scores = []
    for b in range(batch_size):
        slate_items = slate_indices[b]
        # Get pairwise EASE similarities within the slate
        pairwise_sim = ease_B[slate_items][:, slate_items]  # [slate_size, slate_size]

        # Mask diagonal (self-similarity)
        mask = torch.eye(slate_size, device=pairwise_sim.device)
        pairwise_sim = pairwise_sim * (1 - mask)

        # Average pairwise similarity (excluding diagonal)
        avg_pairwise = pairwise_sim.sum() / (slate_size * (slate_size - 1))
        diversity_scores.append(avg_pairwise)

    avg_diversity = torch.stack(diversity_scores)  # [batch]

    # Combined reward
    reward = avg_relevance - lambda_diversity * avg_diversity

    return reward


# Smoke-test the reward function (note that this toy slate includes the seed, index 0)
test_slate_indices = torch.tensor([[0, 1, 2, 3, 4]], device=device)  # [1, 5]
test_seed_idx = torch.tensor([0], device=device)  # [1]

# Use default lambda_diversity for testing
test_reward = compute_reward(test_slate_indices, test_seed_idx, ease_B_tensor, lambda_diversity=0.5)

test_stats = f"""
**Test Reward:**

- Test reward (λ=0.5): {test_reward.item():.4f}
"""
display(Markdown(test_stats))

Test Reward:

  • Test reward (λ=0.5): -0.0008

Training with REINFORCE

We’ll use policy gradient training with the REINFORCE algorithm (Williams, 1992) to train the model to maximize the expected reward.
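Concretely, the loss in the training loop is a Monte Carlo estimate of the REINFORCE gradient with a baseline \(b\). Since the slate is generated autoregressively, its log-probability decomposes over positions, which is why the code sums `log_probs` over the slate before weighting by the advantage:

\[ \nabla_\theta J(\theta) = \mathbb{E}_{S \sim \pi_\theta(\cdot \mid \text{seed})}\left[\big(R(S \mid \text{seed}) - b\big)\, \nabla_\theta \log \pi_\theta(S \mid \text{seed})\right], \qquad \log \pi_\theta(S \mid \text{seed}) = \sum_{t=1}^{|S|} \log \pi_\theta(s_t \mid s_{<t}, \text{seed}) \]

\[ \mathcal{L}(\theta) = -\frac{1}{N} \sum_{n=1}^{N} \big(R_n - \bar{R}\big) \sum_{t=1}^{|S|} \log \pi_\theta\big(s^{(n)}_t \mid s^{(n)}_{<t}, \text{seed}_n\big), \qquad \bar{R} = \frac{1}{N} \sum_{n=1}^{N} R_n \]

Any baseline that does not depend on the sampled slate leaves the gradient unbiased. The batch mean \(\bar{R}\) does include each sample's own reward, but with independent samples this only scales the expected gradient by \((N-1)/N\), which is negligible at the batch sizes used here.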

class TrainingSettings(BaseSettings):
    """Configuration for slate generation training.

    Can be overridden with environment variables using SLATE_GEN_ prefix.
    Example: SLATE_GEN_NUM_EPOCHS=200
    """

    model_config = SettingsConfigDict(env_prefix="SLATE_GEN_")

    slate_size: int = Field(default=5, description="Number of items in generated slate")
    num_candidates: int = Field(default=100, description="Number of candidate items per seed")
    batch_size: int = Field(default=32, description="Training batch size")
    num_epochs: int = Field(default=100, description="Number of training epochs")
    learning_rate: float = Field(default=1e-4, description="Adam optimizer learning rate")
    lambda_diversity: float = Field(default=0.5, description="Diversity penalty weight")
    temperature: float = Field(default=1.0, description="Sampling temperature")
    num_train_seeds: int = Field(
        default=1000, description="Number of random seed items for training"
    )
    gradient_clip_norm: float = Field(default=1.0, description="Gradient clipping norm")


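# Load settings (can be overridden by environment variables).
# The configuration reported below (batch size 256, 500 epochs, 4,096 seeds) differs
# from the defaults, so those values came from SLATE_GEN_* environment variables.
settings = TrainingSettings()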

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=settings.learning_rate)

# Prepare training data: sample random seeds
num_items = len(item_to_idx)
seed_indices = np.random.choice(num_items, size=settings.num_train_seeds, replace=True)

# Display training configuration
config_lines = []
for field_name, field_info in TrainingSettings.model_fields.items():
    value = getattr(settings, field_name)
    desc = field_info.description
    config_lines.append(f"- {desc}: {value}")

display(
    Markdown(f"""
**Training Configuration:**

{chr(10).join(config_lines)}
""")
)

Training Configuration:

  • Number of items in generated slate: 5
  • Number of candidate items per seed: 100
  • Training batch size: 256
  • Number of training epochs: 500
  • Adam optimizer learning rate: 0.0001
  • Diversity penalty weight: 0.5
  • Sampling temperature: 1.0
  • Number of random seed items for training: 4096
  • Gradient clipping norm: 1.0
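The next cell builds each batch by calling `get_candidate_set` once per seed. A batched alternative masks each seed's own entry and takes the top-k of every row at once; below is a NumPy sketch (`get_candidate_sets_batched` is a hypothetical name). In PyTorch the equivalent would be cloning `ease_B_tensor[seeds]`, setting each row's seed entry to `-inf`, and calling `torch.topk(rows, k, dim=1)`.

```python
import numpy as np


def get_candidate_sets_batched(seeds, B, num_candidates=100):
    """Hypothetical helper: top-k candidates per seed, excluding the seed itself."""
    rows = B[seeds].astype(float)                        # [batch, num_items], a copy
    rows[np.arange(len(seeds)), seeds] = -np.inf         # never return the seed
    top = np.argpartition(-rows, num_candidates, axis=1)[:, :num_candidates]
    # argpartition leaves the top block unordered; sort it by descending similarity
    order = np.argsort(-np.take_along_axis(rows, top, axis=1), axis=1)
    return np.take_along_axis(top, order, axis=1)


rng = np.random.default_rng(1)
B = rng.normal(size=(30, 30))
seeds = np.array([0, 5, 7])
cands = get_candidate_sets_batched(seeds, B, num_candidates=5)
print(cands)
```

Like `get_candidate_set`, this returns candidates in descending order of similarity; it only requires `num_candidates` to be smaller than the number of items.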
def get_candidate_set(seed_idx, ease_B, num_candidates=100):
    """Get candidate items for a seed (top-k by EASE similarity)."""
    # Get EASE similarities from seed item to all items
    similarities = ease_B[seed_idx]  # [num_items]

    # Get top-k (excluding seed itself)
    top_k = torch.topk(similarities, k=num_candidates + 1, largest=True)
    candidate_indices = top_k.indices[top_k.indices != seed_idx][:num_candidates]

    return candidate_indices


# Training loop
model.train()
rewards_history = []
loss_history = []

for epoch in range(settings.num_epochs):
    epoch_rewards = []
    epoch_losses = []

    # Sample batches
    np.random.shuffle(seed_indices)

    for i in range(0, len(seed_indices), settings.batch_size):
        batch_seeds = seed_indices[i : i + settings.batch_size]
        if len(batch_seeds) < settings.batch_size:
            continue

        # Prepare batch
        batch_candidates = []
        batch_seed_idx = []

        for seed_idx in batch_seeds:
            # Get candidates using EASE similarities
            cand_indices = get_candidate_set(seed_idx, ease_B_tensor, settings.num_candidates)
            batch_candidates.append(cand_indices)
            batch_seed_idx.append(seed_idx)

        candidate_indices = torch.stack(batch_candidates)  # [batch, num_cand]
        seed_idx_tensor = torch.tensor(batch_seed_idx, device=device)  # [batch]

        # Sample one slate per seed from the current policy
        selected_indices, log_probs = model.generate_slate(
            seed_idx_tensor,
            candidate_indices,
            slate_size=settings.slate_size,
            temperature=settings.temperature,
        )

        # Get item indices of selected items
        batch_idx = (
            torch.arange(settings.batch_size, device=device)
            .unsqueeze(1)
            .expand(-1, settings.slate_size)
        )
        slate_item_indices = candidate_indices[batch_idx, selected_indices]  # [batch, slate_size]

        # Compute rewards using EASE similarities
        rewards = compute_reward(
            slate_item_indices,
            seed_idx_tensor,
            ease_B_tensor,
            lambda_diversity=settings.lambda_diversity,
        )

        # REINFORCE loss: -log_prob * (reward - baseline)
        baseline = rewards.mean()
        advantages = rewards - baseline
        loss = -(log_probs.sum(dim=1) * advantages).mean()

        # Optimize
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), settings.gradient_clip_norm)
        optimizer.step()

        epoch_rewards.append(rewards.mean().item())
        epoch_losses.append(loss.item())

    avg_reward = np.mean(epoch_rewards)
    avg_loss = np.mean(epoch_losses)
    rewards_history.append(avg_reward)
    loss_history.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch + 1}/{settings.num_epochs} - Avg Reward: {avg_reward:.4f}, Loss: {avg_loss:.4f}"
        )

print("\nTraining complete!")
Epoch 10/500 - Avg Reward: 0.0148, Loss: -0.0006
Epoch 20/500 - Avg Reward: 0.0151, Loss: -0.0011
Epoch 30/500 - Avg Reward: 0.0154, Loss: -0.0012
Epoch 40/500 - Avg Reward: 0.0155, Loss: -0.0012
Epoch 50/500 - Avg Reward: 0.0157, Loss: -0.0013
Epoch 60/500 - Avg Reward: 0.0159, Loss: -0.0010
Epoch 70/500 - Avg Reward: 0.0160, Loss: -0.0012
Epoch 80/500 - Avg Reward: 0.0161, Loss: -0.0011
Epoch 90/500 - Avg Reward: 0.0162, Loss: -0.0011
Epoch 100/500 - Avg Reward: 0.0162, Loss: -0.0012
Epoch 110/500 - Avg Reward: 0.0163, Loss: -0.0014
Epoch 120/500 - Avg Reward: 0.0163, Loss: -0.0013
Epoch 130/500 - Avg Reward: 0.0163, Loss: -0.0012
Epoch 140/500 - Avg Reward: 0.0164, Loss: -0.0013
Epoch 150/500 - Avg Reward: 0.0165, Loss: -0.0011
Epoch 160/500 - Avg Reward: 0.0165, Loss: -0.0012
Epoch 170/500 - Avg Reward: 0.0165, Loss: -0.0013
Epoch 180/500 - Avg Reward: 0.0166, Loss: -0.0009
Epoch 190/500 - Avg Reward: 0.0166, Loss: -0.0010
Epoch 200/500 - Avg Reward: 0.0168, Loss: -0.0009
Epoch 210/500 - Avg Reward: 0.0168, Loss: -0.0008
Epoch 220/500 - Avg Reward: 0.0168, Loss: -0.0010
Epoch 230/500 - Avg Reward: 0.0169, Loss: -0.0012
Epoch 240/500 - Avg Reward: 0.0169, Loss: -0.0012
Epoch 250/500 - Avg Reward: 0.0169, Loss: -0.0009
Epoch 260/500 - Avg Reward: 0.0170, Loss: -0.0012
Epoch 270/500 - Avg Reward: 0.0169, Loss: -0.0009
Epoch 280/500 - Avg Reward: 0.0170, Loss: -0.0010
Epoch 290/500 - Avg Reward: 0.0171, Loss: -0.0010
Epoch 300/500 - Avg Reward: 0.0170, Loss: -0.0010
Epoch 310/500 - Avg Reward: 0.0171, Loss: -0.0011
Epoch 320/500 - Avg Reward: 0.0171, Loss: -0.0010
Epoch 330/500 - Avg Reward: 0.0172, Loss: -0.0011
Epoch 340/500 - Avg Reward: 0.0171, Loss: -0.0009
Epoch 350/500 - Avg Reward: 0.0172, Loss: -0.0012
Epoch 360/500 - Avg Reward: 0.0173, Loss: -0.0014
Epoch 370/500 - Avg Reward: 0.0173, Loss: -0.0010
Epoch 380/500 - Avg Reward: 0.0174, Loss: -0.0010
Epoch 390/500 - Avg Reward: 0.0173, Loss: -0.0013
Epoch 400/500 - Avg Reward: 0.0174, Loss: -0.0011
Epoch 410/500 - Avg Reward: 0.0175, Loss: -0.0012
Epoch 420/500 - Avg Reward: 0.0175, Loss: -0.0011
Epoch 430/500 - Avg Reward: 0.0175, Loss: -0.0014
Epoch 440/500 - Avg Reward: 0.0175, Loss: -0.0012
Epoch 450/500 - Avg Reward: 0.0176, Loss: -0.0010
Epoch 460/500 - Avg Reward: 0.0175, Loss: -0.0011
Epoch 470/500 - Avg Reward: 0.0177, Loss: -0.0010
Epoch 480/500 - Avg Reward: 0.0177, Loss: -0.0013
Epoch 490/500 - Avg Reward: 0.0177, Loss: -0.0010
Epoch 500/500 - Avg Reward: 0.0177, Loss: -0.0009

Training complete!
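To see the estimator from the loop in isolation, here is a toy NumPy sketch of REINFORCE with a batch-mean baseline on a softmax policy over three fixed arms with made-up rewards. It illustrates the update rule only and is not a model of the slate task; `reinforce_bandit` is hypothetical.

```python
import numpy as np


def reinforce_bandit(rewards, steps=300, batch=32, lr=0.5, seed=0):
    """Hypothetical toy: REINFORCE with a batch-mean baseline on a softmax policy."""
    rng = np.random.default_rng(seed)
    rewards = np.asarray(rewards, dtype=float)
    n = len(rewards)
    theta = np.zeros(n)                                     # policy logits
    for _ in range(steps):
        probs = np.exp(theta - theta.max())
        probs /= probs.sum()
        actions = rng.choice(n, size=batch, p=probs)
        adv = rewards[actions] - rewards[actions].mean()    # reward minus batch-mean baseline
        grad_logp = np.eye(n)[actions] - probs              # d log softmax(a) / d theta
        theta += lr * (adv[:, None] * grad_logp).mean(axis=0)  # gradient ascent
    probs = np.exp(theta - theta.max())
    return probs / probs.sum()


final = reinforce_bandit([0.0, 0.5, 1.0])
print(final.round(3))
```

The per-sample term `onehot(a) - probs` is the gradient of the log-softmax with respect to the logits, the same quantity autograd computes for each position in the slate model.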

Training Progress

Let’s visualize how the model improved during training.

# Prepare data for plotting
training_data = pl.DataFrame(
    {
        "epoch": list(range(1, len(rewards_history) + 1)),
        "reward": rewards_history,
        "loss": loss_history,
    }
)

(
    ggplot(training_data, aes(x="epoch", y="reward"))
    + geom_line(color="#2E86AB", size=1.5)
    + geom_point(color="#2E86AB", size=2)
    + labs(title="Average Reward During Training", x="Epoch", y="Average Reward")
)

(
    ggplot(training_data, aes(x="epoch", y="loss"))
    + geom_line(color="#A23B72", size=1.5)
    + geom_point(color="#A23B72", size=2)
    + labs(title="REINFORCE Loss During Training", x="Epoch", y="Loss")
)

Generate Slates for Example Movies

Let’s test the model on a few example movies and compare it with a greedy baseline that takes the 5 items most similar to the seed under EASE.
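The evaluation below still samples (at `temperature=0.5`), so each REINFORCE slate is a random draw from the policy rather than its single most likely slate. Dividing logits by a temperature below 1 sharpens the distribution toward the top-scoring candidate, and the limit as the temperature goes to 0 is greedy argmax decoding. A small NumPy sketch of the effect (`softmax_with_temperature` is a hypothetical helper):

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Hypothetical helper: softmax of logits / temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()                      # numerical stability
    p = np.exp(z)
    return p / p.sum()


logits = np.array([2.0, 1.0, 0.5, 0.0])
for T in (1.0, 0.5, 0.1):
    print(T, softmax_with_temperature(logits, T).round(3))
```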

model.eval()

idx_to_item = {i: mid for mid, i in item_to_idx.items()}


def generate_slate_for_movie(movie_id, model, ease_B, movies_df, num_candidates=100, slate_size=5):
    """Generate a slate for a seed movie using the REINFORCE model."""
    if movie_id not in item_to_idx:
        return None

    seed_idx = item_to_idx[movie_id]

    # Get candidates using EASE similarities
    cand_indices = get_candidate_set(seed_idx, ease_B, num_candidates)
    cand_indices_batch = cand_indices.unsqueeze(0)  # [1, num_cand]
    seed_idx_batch = torch.tensor([seed_idx], device=device)  # [1]

    # Sample a slate (the seed conditions the encoder)
    with torch.no_grad():
        selected_indices, log_probs = model.generate_slate(
            seed_idx_batch,
            cand_indices_batch,
            slate_size=slate_size,
            temperature=0.5,  # Lower temperature for more deterministic selection
        )

    # Get selected item indices
    selected = cand_indices[selected_indices[0]].cpu().numpy()
    selected_movie_ids = [idx_to_item[int(idx)] for idx in selected]

    # Compute reward using EASE similarities
    slate_item_indices = torch.from_numpy(selected).unsqueeze(0).to(device)  # [1, slate_size]
    seed_idx_tensor = torch.tensor([seed_idx], device=device)  # [1]
    reward = compute_reward(
        slate_item_indices, seed_idx_tensor, ease_B, lambda_diversity=settings.lambda_diversity
    )

    return {
        "selected_ids": selected_movie_ids,
        "reward": reward.item(),
        "log_probs": log_probs[0].cpu().numpy(),
    }


def greedy_baseline(seed_idx, ease_B, slate_size=5):
    """Greedy baseline: select top-k most similar items using EASE."""
    # Get EASE similarities from seed to all items
    similarities = ease_B[seed_idx]

    # Get top-k (excluding seed)
    top_k = torch.topk(similarities, k=slate_size + 1, largest=True)
    selected = top_k.indices[top_k.indices != seed_idx][:slate_size]

    return selected


def compute_avg_similarity_to_seed(slate_indices, seed_idx, ease_B):
    """Compute average EASE similarity between slate items and seed."""
    similarities = ease_B[seed_idx, slate_indices]
    return similarities.mean().item()
# Example seed movies
EXAMPLE_MOVIE_IDS = [1, 260, 1196]  # Toy Story, Star Wars, Star Wars V

for movie_id in EXAMPLE_MOVIE_IDS:
    seed_info = movies.filter(pl.col("movie_id") == movie_id)
    if len(seed_info) == 0 or movie_id not in item_to_idx:
        continue

    seed_idx = item_to_idx[movie_id]
    seed_title = seed_info["title"][0]

    # Generate REINFORCE model slate
    model_result = generate_slate_for_movie(movie_id, model, ease_B_tensor, movies)
    if model_result is None:
        continue

    # Convert to numpy array first, then to tensor to avoid warning
    model_slate_idx_np = np.array(
        [item_to_idx[mid] for mid in model_result["selected_ids"]], dtype=np.int64
    )
    model_slate_idx = torch.from_numpy(model_slate_idx_np).to(device)
    model_slate_indices = model_slate_idx.unsqueeze(0)  # [1, slate_size]
    seed_idx_tensor = torch.tensor([seed_idx], device=device)  # [1]
    model_reward = compute_reward(
        model_slate_indices, seed_idx_tensor, ease_B_tensor, settings.lambda_diversity
    ).item()
    model_avg_sim = compute_avg_similarity_to_seed(model_slate_idx, seed_idx, ease_B_tensor)

    # Generate greedy baseline slate
    baseline_slate_idx = greedy_baseline(seed_idx, ease_B_tensor, slate_size=settings.slate_size)
    baseline_slate_indices = baseline_slate_idx.unsqueeze(0)  # [1, slate_size]
    baseline_reward = compute_reward(
        baseline_slate_indices, seed_idx_tensor, ease_B_tensor, settings.lambda_diversity
    ).item()
    baseline_avg_sim = compute_avg_similarity_to_seed(baseline_slate_idx, seed_idx, ease_B_tensor)

    # Display seed movie header
    display(Markdown(f"### 🎬 Seed Movie: {seed_title}\n"))

    # Get and display seed poster
    seed_poster = posters.filter(pl.col("movie_id") == movie_id)
    if len(seed_poster) > 0 and seed_poster["poster_path"][0]:
        display(tmdb_images([seed_poster["poster_path"][0]]))

    # Display metrics comparison
    display(
        Markdown(f"""
| Metric | REINFORCE Model | Greedy Baseline |
|--------|---------------:|----------------:|
| **Reward** | {model_reward:.4f} | {baseline_reward:.4f} |
| **Avg Similarity to Seed** | {model_avg_sim:.4f} | {baseline_avg_sim:.4f} |
| **Improvement** | - | {model_reward - baseline_reward:+.4f} |
""")
    )

    # Display REINFORCE model slate
    display(Markdown("\n#### 🤖 REINFORCE Model Slate\n"))

    # Collect posters for model slate
    model_poster_paths = []
    for mid in model_result["selected_ids"]:
        movie_poster = posters.filter(pl.col("movie_id") == mid)
        if len(movie_poster) > 0 and movie_poster["poster_path"][0]:
            model_poster_paths.append(movie_poster["poster_path"][0])

    if model_poster_paths:
        display(tmdb_images(model_poster_paths))

    # Display greedy baseline slate
    display(Markdown("\n#### 📊 Greedy Baseline Slate\n"))

    # Collect posters for baseline slate
    baseline_poster_paths = []
    baseline_movie_ids = [idx_to_item[int(idx)] for idx in baseline_slate_idx.cpu().numpy()]

    for mid in baseline_movie_ids:
        movie_poster = posters.filter(pl.col("movie_id") == mid)
        if len(movie_poster) > 0 and movie_poster["poster_path"][0]:
            baseline_poster_paths.append(movie_poster["poster_path"][0])

    if baseline_poster_paths:
        display(tmdb_images(baseline_poster_paths))

    display(Markdown("\n---\n"))

🎬 Seed Movie: Toy Story (1995)

| Metric | REINFORCE Model | Greedy Baseline |
|--------|----------------:|----------------:|
| Reward | 0.0336 | 0.0639 |
| Avg Similarity to Seed | 0.0350 | 0.0708 |
| Improvement | - | -0.0303 |

🤖 REINFORCE Model Slate

📊 Greedy Baseline Slate


🎬 Seed Movie: Star Wars: Episode IV - A New Hope (1977)

| Metric | REINFORCE Model | Greedy Baseline |
|--------|----------------:|----------------:|
| Reward | 0.0292 | 0.1206 |
| Avg Similarity to Seed | 0.0310 | 0.1429 |
| Improvement | - | -0.0914 |

🤖 REINFORCE Model Slate

📊 Greedy Baseline Slate


🎬 Seed Movie: Star Wars: Episode V - The Empire Strikes Back (1980)

| Metric | REINFORCE Model | Greedy Baseline |
|--------|----------------:|----------------:|
| Reward | 0.0810 | 0.1455 |
| Avg Similarity to Seed | 0.0788 | 0.1691 |
| Improvement | - | -0.0645 |

🤖 REINFORCE Model Slate

📊 Greedy Baseline Slate


Summary

Model Architecture:

  • Encoder-Decoder Transformer
  • Autoregressive slate generation

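Reward Function:

  • Relevance to seed (average EASE similarity)
  • Intra-slate similarity penalized with weight λ
  • Balances relevance and variety

Results:

  • Average training reward rose modestly, from 0.0148 at epoch 10 to 0.0177 at epoch 500
  • On the three example seeds, the REINFORCE slates score lower than the greedy top-5 EASE baseline (reward differences of -0.0303, -0.0914, and -0.0645)
  • The model's slates incur a smaller diversity penalty, but their much lower average similarity to the seed outweighs that gain

Key Insights:

  1. Sequential generation allows modeling item interactions
  2. Policy gradients optimize a non-differentiable, slate-level reward directly
  3. The diversity penalty discourages redundant recommendations, at the cost of some relevance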

References

Bello, I., Kulkarni, S., Jain, S., Boutilier, C., Chi, E., Eban, E., Luo, X., Mackey, A., & Meshi, O. (2018). Seq2Slate: Re-ranking and slate optimization with RNNs. arXiv Preprint arXiv:1810.02019. https://arxiv.org/abs/1810.02019
Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4), 229–256. https://doi.org/10.1007/BF00992696