Train Your Large Model on Multiple GPUs with Tensor Parallelism
import dataclasses
import datetime
import os
import datasets
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from torch import Tensor
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
loss_parallel,
parallelize_module,
)
from torch.utils.data.distributed import DistributedSampler
# Set default to bfloat16
torch.set_default_dtype(torch.bfloat16)
print("NCCL version:", torch.cuda.nccl.version())
# Build the model
@dataclasses.dataclass
class LlamaConfig:
"""Define Llama model hyperparameters."""
vocab_size: int = 50000 # Size of the tokenizer vocabulary
max_position_embeddings: int = 2048 # Maximum sequence length
hidden_size: int = 768 # Dimension of hidden layers
intermediate_size: int = 4*768 # Dimension of MLP's hidden layer
num_hidden_layers: int = 12 # Number of transformer layers
num_attention_heads: int = 12 # Number of attention heads
num_key_value_heads: int = 3 # Number of key-value heads for GQA
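# A quick worked example of how these defaults interact: each attention head has
# head_dim = hidden_size / num_attention_heads = 768 / 12 = 64, and with grouped-query
# attention the 12 query heads share 3 key-value heads, i.e. 4 query heads per KV head.
# The MLP expands 768 -> 3072 (4 * 768) and projects back down to 768.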
class RotaryPositionEncoding(nn.Module):
"""Rotary position encoding."""
def __init__(self, dim: int, max_position_embeddings: int) -> None:
"""Initialize the RotaryPositionEncoding module.
Args:
dim: The hidden dimension of the input tensor to which RoPE is applied
max_position_embeddings: The maximum sequence length of the input tensor
"""
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # save cosine and sine matrices as buffers, not parameters
        cos, sin = self._compute_cos_sin()
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
    def _compute_cos_sin(self) -> tuple[Tensor, Tensor]:
        # compute the matrix of angles n * theta_i for every position n and frequency theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, self.dim, 2) / self.dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(self.max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        return sinusoid_inp.cos(), sinusoid_inp.sin()
    def reset_parameters(self) -> None:
        # recompute cos/sin after the model is materialized from the meta device: to_empty()
        # leaves buffers uninitialized, and reset_all_weights() below calls this method
        cos, sin = self._compute_cos_sin()
        self.cos.copy_(cos.to(self.cos.device, self.cos.dtype))
        self.sin.copy_(sin.to(self.sin.device, self.sin.dtype))
def forward(self, x: Tensor) -> Tensor:
"""Apply RoPE to tensor x.
Args:
x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)
Returns:
Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
"""
batch_size, seq_len, num_heads, head_dim = x.shape
device = x.device
dtype = x.dtype
# transform the cosine and sine matrices to 4D tensor and the same dtype as x
cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
# apply RoPE to x
x1, x2 = x.chunk(2, dim=-1)
rotated = torch.cat((-x2, x1), dim=-1)
output = (x * cos) + (rotated * sin)
return output
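# The rotation applied above, written out for position n and the feature pair (x_i, x_{i+d/2}):
#   out_i       = x_i * cos(n * theta_i) - x_{i+d/2} * sin(n * theta_i)
#   out_{i+d/2} = x_{i+d/2} * cos(n * theta_i) + x_i * sin(n * theta_i)
# which is exactly x * cos + rotated * sin with rotated = cat(-x2, x1).
# A hypothetical shape check (not part of the training script) could look like:
#   rope = RotaryPositionEncoding(dim=64, max_position_embeddings=128)
#   q = torch.randn(1, 16, 4, 64)  # (batch, seq, heads, head_dim)
#   assert rope(q).shape == q.shape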
class LlamaAttention(nn.Module):
"""Grouped-query attention with rotary embeddings."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_kv_heads = config.num_key_value_heads # GQA: H_kv < H_q
# hidden_size must be divisible by num_heads
assert (self.head_dim * self.num_heads) == self.hidden_size
# Linear layers for Q, K, V projections
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
bs, seq_len, dim = hidden_states.size()
# Project inputs to Q, K, V
query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
# Apply rotary position embeddings
query_states = rope(query_states)
key_states = rope(key_states)
# Transpose tensors from BSHD to BHSD dimension for scaled_dot_product_attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Use PyTorch's optimized attention implementation
# setting is_causal=True is incompatible with setting explicit attention mask
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=0.0,
enable_gqa=True,
)
# Transpose output tensor from BHSD to BSHD dimension, reshape to 3D, and then project output
attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
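# With enable_gqa=True, scaled_dot_product_attention broadcasts the num_kv_heads key/value
# heads over the num_heads query heads internally. Conceptually this is the same as manually
# repeating K and V before a standard multi-head attention call, e.g. (a sketch, not used above):
#   key_states = key_states.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
#   value_states = value_states.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)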
class LlamaMLP(nn.Module):
"""Feed-forward network with SwiGLU activation."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
# Two parallel projections for SwiGLU
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SiLU, the activation used inside SwiGLU
# Project back to hidden size
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x: Tensor) -> Tensor:
# SwiGLU activation: multiply gate and up-projected inputs
gate = self.act_fn(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
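# SwiGLU in one line: down_proj(silu(gate_proj(x)) * up_proj(x)). The elementwise product of
# the gate and up branches is what distinguishes it from a plain two-layer MLP.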
class LlamaDecoderLayer(nn.Module):
"""Single transformer layer for a Llama model."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
self.self_attn = LlamaAttention(config)
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
self.mlp = LlamaMLP(config)
def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
# First residual block: Self-attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
hidden_states = attn_outputs + residual
# Second residual block: MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states) + residual
return hidden_states
class LlamaModel(nn.Module):
"""The full Llama model without any pretraining heads."""
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.rotary_emb = RotaryPositionEncoding(
config.hidden_size // config.num_attention_heads,
config.max_position_embeddings,
)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)
def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
# Convert input token IDs to embeddings
hidden_states = self.embed_tokens(input_ids)
# Process through all transformer layers, then the final norm layer
for layer in self.layers:
hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
hidden_states = self.norm(hidden_states)
# Return the final hidden states
return hidden_states
class LlamaForPretraining(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.base_model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
hidden_states = self.base_model(input_ids, attn_mask)
return self.lm_head(hidden_states)
def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
"""Create a causal mask for self-attention.
Args:
batch: Batch of sequences, shape (batch_size, seq_len)
dtype: Data type of the mask
Returns:
Causal mask of shape (seq_len, seq_len)
"""
batch_size, seq_len = batch.shape
mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
.triu(diagonal=1)
return mask
def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
"""Create a padding mask for a batch of sequences for self-attention.
Args:
batch: Batch of sequences, shape (batch_size, seq_len)
padding_token_id: ID of the padding token
dtype: Data type of the mask
Returns:
Padding mask of shape (batch_size, 1, seq_len, seq_len)
"""
padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
.masked_fill(batch == padding_token_id, float("-inf"))
mask = padded[:,:,None] + padded[:,None,:]
return mask[:, None, :, :]
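# The two masks broadcast together in the training loop: the causal mask is (seq_len, seq_len)
# and the padding mask is (batch_size, 1, seq_len, seq_len), so their sum has shape
# (batch_size, 1, seq_len, seq_len) and broadcasts over the head dimension inside
# scaled_dot_product_attention. For example, a batch of 2 sequences of length 512 yields a
# combined mask of shape (2, 1, 512, 512) with -inf wherever attention is disallowed.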
# Dataset that yields fixed-length, padded token sequences for next-token prediction
class PretrainingDataset(torch.utils.data.Dataset):
def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
seq_length: int):
self.dataset = dataset
self.tokenizer = tokenizer
self.seq_length = seq_length
self.bot = tokenizer.token_to_id("[BOT]")
self.eot = tokenizer.token_to_id("[EOT]")
self.pad = tokenizer.token_to_id("[PAD]")
def __len__(self):
return len(self.dataset)
def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
"""Get a sequence of token ids from the dataset. [BOT] and [EOT] tokens
are added. Clipped and padded to the sequence length.
"""
seq = self.dataset[index]["text"]
tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
# pad to target sequence length
toklen = len(tokens)
if toklen < self.seq_length+1:
pad_length = self.seq_length+1 - toklen
tokens += [self.pad] * pad_length
# return the sequence
x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64)
return x, y
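# Each item is a next-token-prediction pair: for tokens [BOT], t1, t2, ..., [EOT], [PAD], ...,
# x holds positions 0..seq_length-1 and y holds positions 1..seq_length, so y[i] is the target
# for the prediction made at position i; [PAD] targets are ignored by the loss later on.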
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
dist.barrier()
load(
{"model": model, "optimizer": optimizer},
checkpoint_id="checkpoint-dist",
planner=DefaultLoadPlanner(allow_partial_load=True), # ignore keys for RoPE buffer
)
scheduler.load_state_dict(
torch.load("checkpoint-dist/lrscheduler.pt", map_location=device),
)
dist.barrier()
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
dist.barrier()
save(
{"model": model, "optimizer": optimizer},
checkpoint_id="checkpoint-dist",
)
if dist.get_rank() == 0:
torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
dist.barrier()
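# torch.distributed.checkpoint writes a sharded checkpoint (per-rank shard files plus metadata)
# into checkpoint-dist/, so every rank must participate in both save() and load(). If a single
# consolidated file is ever needed, e.g. for inference, a sketch using PyTorch's format
# utilities would be (run offline, not during training):
#   from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
#   dcp_to_torch_save("checkpoint-dist", "consolidated.pt")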
# Load the tokenizer and dataset
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")
# Initialize the distributed environment
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=60))
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)  # bind this process to its own GPU before any collectives run
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")
# Initialize the mesh for tensor parallelism
n_tensor_parallel = 2
assert world_size % n_tensor_parallel == 0, "Expect world size to be divisible by number of tensor parallel GPUs"
mesh = dist.device_mesh.init_device_mesh(
"cuda",
(world_size // n_tensor_parallel, n_tensor_parallel),
mesh_dim_names=("dp", "tp"),
)
print(f"({rank}) Mesh: {mesh}, DP size: {mesh['dp'].size()}, TP size: {mesh['tp'].size()}, DP local rank: {mesh['dp'].get_local_rank()}, TP local rank: {mesh['tp'].get_local_rank()}")
# Create pretraining model on meta device, on all ranks
with torch.device("meta"):
model_config = LlamaConfig()
model = LlamaForPretraining(model_config)
# Set up tensor parallelism on each transformer block in the base model
tp_plan = {
"input_layernorm": SequenceParallel(),
"self_attn": PrepareModuleInput(
input_layouts=Shard(dim=1), # only one position arg will be used
desired_input_layouts=Replicate(),
),
    # Q/K projection outputs go through RoPE, so they need to be replicated
    # Q/K/V outputs also feed the grouped-query attention kernel, which needs full heads on every rank
"self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
"self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
"self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
"self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
"post_attention_layernorm": SequenceParallel(),
"mlp": PrepareModuleInput(
input_layouts=Shard(dim=1),
desired_input_layouts=Replicate(),
),
"mlp.gate_proj": ColwiseParallel(),
"mlp.up_proj": ColwiseParallel(),
"mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
for layer in model.base_model.layers:
parallelize_module(layer, mesh["tp"], tp_plan)
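# How the plan splits one decoder layer across the TP group: ColwiseParallel shards a linear
# layer's output features and RowwiseParallel shards its input features. In the MLP, gate_proj
# and up_proj are column-sharded and down_proj is row-sharded, so with TP=2 each rank holds a
# 3072/2 = 1536-wide slice of the intermediate dimension and the activation between them never
# has to be gathered. In attention, the q/k/v outputs are gathered back to Replicate because
# RoPE and the grouped-query attention kernel need full head dimensions on every rank. With
# output_layouts=Shard(1), the o_proj and down_proj outputs are reduce-scattered over the
# sequence dimension, which is the layout the SequenceParallel RMSNorm layers expect;
# PrepareModuleInput converts those sequence-sharded activations back to replicated form
# before attention and the MLP run.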
# Set up tensor parallelism on the embedding and output norm layers in the base model
# and the prediction head in the top-level model
tp_plan = {
"base_model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"base_model.norm": SequenceParallel(),
"lm_head": ColwiseParallel(
input_layouts=Shard(1),
# output_layouts=Replicate(), # only if not using loss parallel
use_local_output=False, # Keep DTensor output for loss parallel
),
}
parallelize_module(model, mesh["tp"], tp_plan)
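# At the model boundaries: the embedding table is sharded over the vocabulary (RowwiseParallel
# for nn.Embedding shards dim 0) and its output is turned into sequence-sharded activations
# for the first SequenceParallel layer, while lm_head keeps its output as a DTensor sharded
# over the vocabulary so that loss_parallel() below can compute the cross-entropy without ever
# gathering the full 50000-dim logits on a single rank.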
# Convert tensor-parallelized model to FSDP2, must shard every component
# shard across the "dp" dimension of the mesh
for layer in model.base_model.layers:
fully_shard(layer, mesh=mesh["dp"])
fully_shard(model.base_model, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])
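# The parallelism is now 2D: parallelize_module split every weight across the "tp" mesh
# dimension, and fully_shard splits those shards again across the "dp" dimension, storing the
# parameters as DTensors on the 2D mesh. Wrapping each decoder layer separately lets FSDP
# all-gather one layer's parameters at a time during forward and backward instead of
# materializing the whole model at once.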
def reset_all_weights(model: nn.Module) -> None:
"""Initialize all weights of the model after moving it away from meta device."""
@torch.no_grad()
def weight_reset(m: nn.Module):
reset_parameters = getattr(m, "reset_parameters", None)
if callable(reset_parameters):
m.reset_parameters()
# Applies fn recursively to model itself and all of model.children()
model.apply(fn=weight_reset)
torch.manual_seed(42)
model.to_empty(device=device)
reset_all_weights(model)
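# Initialization flow for a meta-device model: to_empty() allocates real but uninitialized
# storage on the GPU for the sharded parameters and buffers, and reset_all_weights() then runs
# every module's reset_parameters() (including the RoPE buffer recomputation above) to fill
# them with actual values. The fixed seed keeps the initialization reproducible across runs.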
assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"
# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64 // mesh["dp"].size()
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
model.train()
# DataLoader, optimizer, scheduler, and loss function
# Sampler is needed to shard the dataset across the data-parallel ranks
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(
dataset, shuffle=False, drop_last=True,
num_replicas=mesh["dp"].size(),
rank=mesh["dp"].get_local_rank(),
)
dataloader = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True, # optional
shuffle=False,
num_workers=2,
prefetch_factor=2,
)
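# Note on batch sizes: batch_size above is per data-parallel rank, so one optimizer step still
# covers 64 sequences in total (assuming the DP size divides 64). The DistributedSampler gives
# each DP rank a disjoint slice of the dataset, while the two ranks of a TP group share a DP
# rank and therefore read identical batches, as tensor parallelism requires.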
num_training_steps = len(dataloader) * epochs
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
)
warmup_scheduler = lr_scheduler.LinearLR(
optimizer,
start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=num_training_steps - num_warmup_steps,
eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
milestones=[num_warmup_steps],
)
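# Learning-rate schedule: linear warmup from 0.1 * learning_rate to learning_rate over the
# first 1000 steps, then cosine decay down to 0 over the remaining training steps.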
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)
# if the checkpoint-dist directory exists, restore the model, optimizer, and LR scheduler from it
if os.path.exists("checkpoint-dist"):
load_checkpoint(model, optimizer, scheduler)
# start training
print(f"({rank}) Starting training")
for epoch in range(epochs):
pbar = tqdm.tqdm(dataloader, desc=f"({rank}) Epoch {epoch+1}/{epochs}")
for batch_id, batch in enumerate(pbar):
if batch_id % 1000 == 0:
save_checkpoint(model, optimizer, scheduler)
        # Explicitly unshard (all-gather) the root module's parameters before the forward pass
model.unshard()
# Get batched data, move from CPU to GPU
input_ids, target_ids = batch
input_ids = input_ids.to(device)
target_ids = target_ids.to(device)
        # create attention mask: causal mask + padding mask, matching the model's bfloat16 dtype
        attn_mask = create_causal_mask(input_ids, dtype=torch.get_default_dtype()) + \
            create_padding_mask(input_ids, PAD_TOKEN_ID, dtype=torch.get_default_dtype())
# Extract output from model
logits = model(input_ids, attn_mask)
optimizer.zero_grad()
with loss_parallel():
# Compute loss: cross-entropy between logits and target, ignoring padding tokens
loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
# Backward with loss on DTensor
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
pbar.set_postfix(loss=loss.item())
pbar.close()
# Save the final checkpoint (model, optimizer, and LR scheduler)
save_checkpoint(model, optimizer, scheduler)
# Clean up the distributed environment
dist.destroy_process_group()