class SlateGeneratorTransformer(nn.Module):
"""Encoder-Decoder Transformer for Slate Generation.
Encoder: Process seed item + candidate items (using trainable embeddings)
Decoder: Autoregressively generate slate positions conditioned on seed
"""
def __init__(
self,
num_items: int,
embedding_dim: int = 64,
num_heads: int = 4,
num_encoder_layers: int = 2,
num_decoder_layers: int = 2,
dim_feedforward: int = 256,
dropout: float = 0.1,
max_slate_size: int = 5,
):
super().__init__()
self.embedding_dim = embedding_dim
self.max_slate_size = max_slate_size
# Trainable item embeddings (replacing content embeddings)
self.item_embedding = nn.Embedding(num_items, embedding_dim)
nn.init.normal_(self.item_embedding.weight, mean=0.0, std=0.1)
# Seed embedding to mark the seed item in encoder input
# This gets added to the seed item embedding to distinguish it from candidates
self.seed_embedding = nn.Parameter(torch.zeros(embedding_dim))
nn.init.normal_(self.seed_embedding, mean=0.0, std=0.1)
        # Learned positional embeddings for the decoder's slate positions
        self.pos_embedding = nn.Embedding(max_slate_size, embedding_dim)
# Transformer layers
encoder_layer = nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=num_heads,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
decoder_layer = nn.TransformerDecoderLayer(
d_model=embedding_dim,
nhead=num_heads,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True,
)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        # Candidate scores are computed in forward() as dot products between decoder
        # outputs and encoded candidates, so no separate output projection is needed
def forward(self, seed_indices, candidate_indices, slate_indices=None, slate_mask=None):
"""
Args:
seed_indices: [batch_size] - seed item indices
candidate_indices: [batch_size, num_candidates] - item indices
slate_indices: [batch_size, slate_size] - partially generated slate item indices
slate_mask: [batch_size, slate_size] - mask for autoregressive generation
Returns:
logits: [batch_size, slate_size, num_candidates] - scores for each position
"""
batch_size = candidate_indices.shape[0]
num_candidates = candidate_indices.shape[1]
# Get embeddings for seed and candidates
seed_embs = self.item_embedding(seed_indices).unsqueeze(1) # [batch, 1, dim]
candidate_embs = self.item_embedding(candidate_indices) # [batch, num_cand, dim]
# Add seed embedding to distinguish seed from candidates
seed_embs = seed_embs + self.seed_embedding.unsqueeze(0).unsqueeze(0) # [batch, 1, dim]
# Concatenate seed with candidates for encoder input
# This conditions the model on what we're recommending for
encoder_input = torch.cat([seed_embs, candidate_embs], dim=1) # [batch, 1+num_cand, dim]
# Encode seed + candidates
encoded = self.encoder(encoder_input) # [batch, 1+num_cand, dim]
# Separate seed encoding from candidate encodings
seed_encoded = encoded[:, 0:1, :] # [batch, 1, dim]
candidates_encoded = encoded[:, 1:, :] # [batch, num_cand, dim]
        if slate_indices is None:
            # First decoding step: use the encoded seed as the initial decoder input
            slate_size = 1
            slate_embs = seed_encoded  # [batch, 1, dim]
else:
slate_size = slate_indices.shape[1]
# Get embeddings for slate items
slate_embs = self.item_embedding(slate_indices) # [batch, slate_size, dim]
# Add positional encoding to slate
positions = torch.arange(slate_size, device=candidate_indices.device)
slate_embs = slate_embs + self.pos_embedding(positions).unsqueeze(0)
# Decode (using full encoded sequence as memory, including seed)
decoded = self.decoder(
tgt=slate_embs,
memory=encoded, # Full context including seed
tgt_mask=slate_mask,
) # [batch, slate_size, dim]
        # Score candidates only (the seed is excluded from the output space):
        # dot product between each decoder output and each encoded candidate
scores = torch.bmm(
decoded, candidates_encoded.transpose(1, 2)
) # [batch, slate_size, num_cand]
return scores
def generate_slate(self, seed_indices, candidate_indices, slate_size=5, temperature=1.0):
"""Generate a slate autoregressively.
Args:
seed_indices: [batch_size] - seed item indices
candidate_indices: [batch_size, num_candidates] - item indices of candidates
slate_size: number of items to generate
temperature: sampling temperature
Returns:
            selected_indices: [batch_size, slate_size] - positions into candidate_indices (not raw item IDs)
            log_probs: [batch_size, slate_size] - log probabilities of the sampled selections
"""
batch_size = candidate_indices.shape[0]
num_candidates = candidate_indices.shape[1]
selected_indices = []
log_probs = []
slate_item_indices_list = []
for t in range(slate_size):
if t == 0:
slate_item_indices = None
else:
slate_item_indices = torch.stack(slate_item_indices_list, dim=1) # [batch, t]
            # Score candidates for the next position (the encoder memory includes the seed)
            scores = self(
                seed_indices, candidate_indices, slate_item_indices
            )  # [batch, t or 1, num_cand]
logits = scores[:, -1, :] # [batch, num_cand] - scores for position t
            # Mask already-selected candidates so they cannot be picked again
            if t > 0:
                prev = torch.stack(selected_indices, dim=1)  # [batch, t]
                mask = torch.zeros_like(logits)
                mask.scatter_(1, prev, -1e9)
                logits = logits + mask
# Sample from categorical distribution
probs = F.softmax(logits / temperature, dim=-1)
dist = torch.distributions.Categorical(probs)
selected = dist.sample() # [batch]
selected_indices.append(selected)
log_probs.append(dist.log_prob(selected))
# Get item indices of selected items for next step
selected_item_indices = candidate_indices[torch.arange(batch_size), selected] # [batch]
slate_item_indices_list.append(selected_item_indices)
selected_indices = torch.stack(selected_indices, dim=1) # [batch, slate_size]
log_probs = torch.stack(log_probs, dim=1) # [batch, slate_size]
return selected_indices, log_probs
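
# A minimal, illustrative sketch (an assumption, not used by the code below) of how a
# causal target mask could be built for teacher-forced training with this model: the
# [slate_size, slate_size] upper-triangular -inf mask would be passed as `slate_mask`
# so each slate position can only attend to earlier positions.
def make_causal_slate_mask(slate_size: int, device=None) -> torch.Tensor:
    """Upper-triangular -inf mask usable as `slate_mask` / `tgt_mask`."""
    return torch.triu(
        torch.full((slate_size, slate_size), float("-inf"), device=device), diagonal=1
    )
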
# Initialize model
num_items = len(item_to_idx)
model = SlateGeneratorTransformer(
num_items=num_items,
embedding_dim=64,
num_heads=4,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=256,
dropout=0.1,
max_slate_size=5,
).to(device)
model_stats = f"""
**Model Architecture:**
- Total parameters: {sum(p.numel() for p in model.parameters()):,}
"""
display(Markdown(model_stats))
print(model)
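
# A minimal smoke-test sketch of autoregressive slate generation with the model above.
# The batch size, candidate count, and random indices are illustrative assumptions;
# `device` and `num_items` are reused from the cells above.
with torch.no_grad():
    demo_seeds = torch.randint(0, num_items, (2,), device=device)          # [batch]
    demo_candidates = torch.randint(0, num_items, (2, 20), device=device)  # [batch, num_cand]
    demo_slate, demo_log_probs = model.generate_slate(
        demo_seeds, demo_candidates, slate_size=5, temperature=1.0
    )
    print("slate positions:", demo_slate.shape)   # [2, 5], indices into demo_candidates
    print("log probs:", demo_log_probs.shape)     # [2, 5]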