← Back to News
General

Train Your Large Model on Multiple GPUs with Tensor Parallelism

✍️ thecrossroadtimes.com
📅 December 31, 2025
⏱️ Dec 31, 2025
👁️ 29 views

import dataclassesimport datetimeimport os import datasetsimport tokenizersimport torchimport torch.distributed as distimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim.lr_scheduler as lr_schedulerimport tqdmfrom torch import Tensorfrom torch.distributed.checkpoint import load, savefrom torch.distributed.checkpoint.default_planner import DefaultLoadPlannerfrom torch.distributed.fsdp import FSDPModule, fully_shardfrom torch.distributed.tensor import Replicate, Shardfrom torch.distributed.tensor.parallel import (    ColwiseParallel,    PrepareModuleInput,    RowwiseParallel,    SequenceParallel,    loss_parallel,    parallelize_module,)from torch.utils.data.distributed import DistributedSampler # Set default to bfloat16torch.set_default_dtype(torch.bfloat16)print("NCCL version:", torch.cuda.nccl.version()) # Build the model@dataclasses.dataclassclass LlamaConfig:    """Define Llama…


import dataclasses

import datetime

import os

 

import datasets

import tokenizers

import torch

import torch.distributed as dist

import torch.nn as nn

import torch.nn.functional as F

import torch.optim.lr_scheduler as lr_scheduler

import tqdm

from torch import Tensor

from torch.distributed.checkpoint import load, save

from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

from torch.distributed.fsdp import FSDPModule, fully_shard

from torch.distributed.tensor import Replicate, Shard

from torch.distributed.tensor.parallel import (

    ColwiseParallel,

    PrepareModuleInput,

    RowwiseParallel,

    SequenceParallel,

    loss_parallel,

    parallelize_module,

)

from torch.utils.data.distributed import DistributedSampler

 

# Make bfloat16 the default floating-point dtype: every float tensor created
# from here on without an explicit dtype (model parameters included) is bf16.
torch.set_default_dtype(torch.bfloat16)

# Diagnostic only: report the NCCL version this PyTorch build ships with.
nccl_version = torch.cuda.nccl.version()
print("NCCL version:", nccl_version)

 

# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters.

    The head layout is validated on construction:

    * ``num_attention_heads`` and ``num_key_value_heads`` must be positive;
    * ``hidden_size`` must split evenly into ``num_attention_heads`` heads;
    * the resulting head dimension must be even (RoPE rotates the two halves
      of each head vector);
    * ``num_attention_heads`` must be a multiple of ``num_key_value_heads``
      (grouped-query attention shares each key/value head among an equal
      number of query heads).

    Raises:
        ValueError: if any of the above constraints is violated.
    """

    vocab_size: int = 50000  # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768  # Dimension of hidden layers
    intermediate_size: int = 4*768  # Dimension of MLP's hidden layer
    num_hidden_layers: int = 12  # Number of transformer layers
    num_attention_heads: int = 12  # Number of attention heads
    num_key_value_heads: int = 3  # Number of key-value heads for GQA

    def __post_init__(self) -> None:
        # Check positivity first so the modulo checks below cannot divide by zero.
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"head counts must be positive, got num_attention_heads={self.num_attention_heads}, "
                f"num_key_value_heads={self.num_key_value_heads}"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size={self.hidden_size} is not divisible by "
                f"num_attention_heads={self.num_attention_heads}"
            )
        head_dim = self.hidden_size // self.num_attention_heads
        if head_dim % 2 != 0:
            raise ValueError(f"head dimension must be even for RoPE, got {head_dim}")
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads={self.num_attention_heads} is not a multiple of "
                f"num_key_value_heads={self.num_key_value_heads}"
            )

 

 

class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding (RoPE), "rotate half" formulation.

    Cosine and sine tables for positions ``0 .. max_position_embeddings - 1``
    are precomputed and stored as buffers.

    * The tables are always computed in float32, independent of
      ``torch.get_default_dtype()``. A bfloat16 default cannot represent
      ``position * frequency`` accurately. ``forward`` casts them to the
      input's dtype.
    * ``reset_parameters`` recomputes the tables in place. A module built on
      the ``meta`` device and materialised with ``Module.to_empty()`` gets
      uninitialised buffer storage; any code that re-initialises modules via
      ``reset_parameters`` therefore restores valid tables too.
    * The buffers are non-persistent. They are derived data, so they are
      neither written to nor read from state dicts or checkpoints.
    """

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied.
                Must be even, because the last axis is split into two halves.
            max_position_embeddings: The maximum sequence length of the input tensor

        Raises:
            ValueError: if ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # No explicit device: inside a `with torch.device(...)` context (e.g. meta)
        # the tables are created on that device.
        cos, sin = self._compute_tables()
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _compute_tables(self, device: "torch.device | None" = None) -> tuple[Tensor, Tensor]:
        """Return float32 ``(cos, sin)`` tables of shape (max_position_embeddings, dim)."""
        N = 10_000.0
        # theta_i = N^(-2i/dim), duplicated so both halves of the head use the same frequencies
        inv_freq = 1.0 / (N ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(self.max_position_embeddings, dtype=torch.float32, device=device)
        # entry [p, i] is p * theta_i (the matrix of n\theta_i)
        sinusoid_inp = torch.outer(position, inv_freq)
        return sinusoid_inp.cos(), sinusoid_inp.sin()

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Recompute the cos/sin buffers in place, on whatever device they live on."""
        cos, sin = self._compute_tables(device=self.cos.device)
        self.cos.copy_(cos)
        self.sin.copy_(sin)

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Raises:
            ValueError: if seq_length exceeds max_position_embeddings.
        """
        _, seq_len, _, _ = x.shape
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_position_embeddings={self.max_position_embeddings}"
            )
        # transform the cosine and sine matrices to 4D tensors with the same device/dtype as x
        cos = self.cos.to(x.device, x.dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(x.device, x.dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x: x * cos + rotate_half(x) * sin
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)

 

 

class LlamaAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    Queries are split into ``num_attention_heads`` heads, keys and values into
    the (smaller) ``num_key_value_heads`` heads.
    ``F.scaled_dot_product_attention(..., enable_gqa=True)`` shares each
    key/value head among several query heads. RoPE is applied to queries and
    keys only.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

        # hidden_size must split evenly into num_heads heads
        assert (self.head_dim * self.num_heads) == self.hidden_size

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        # bias-free projections, created in the order q, k, v, o
        self.q_proj = nn.Linear(self.hidden_size, q_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        """Attend over ``hidden_states`` of shape (batch, seq, hidden_size).

        ``attn_mask`` is handed unchanged to ``F.scaled_dot_product_attention``,
        so it must broadcast against the (batch, heads, seq, seq) score matrix.
        Returns a tensor of shape (batch, seq, hidden_size).
        """
        bs, seq_len, _ = hidden_states.size()

        def heads(proj: nn.Linear, n_heads: int) -> Tensor:
            # (bs, seq, n_heads * head_dim) -> (bs, seq, n_heads, head_dim)
            return proj(hidden_states).view(bs, seq_len, n_heads, self.head_dim)

        q = heads(self.q_proj, self.num_heads)
        k = heads(self.k_proj, self.num_kv_heads)
        v = heads(self.v_proj, self.num_kv_heads)
        q, k = rope(q), rope(k)

        # SDPA works on (batch, heads, seq, head_dim). is_causal=True cannot be
        # combined with an explicit attn_mask, so any causal masking must come
        # in through attn_mask.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, enable_gqa=True,
        )

        # back to (batch, seq, heads, head_dim), merge the heads, project out
        merged = context.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        return self.o_proj(merged)

 

 

class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        # gate and up projections run side by side; their product feeds down_proj
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.act_fn = F.silu  # SiLU gate of the SwiGLU construction
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)

 

 

class LlamaDecoderLayer(nn.Module):
    """One pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # attention sub-block; the un-normalised input is the residual
        attn_out = self.self_attn(self.input_layernorm(hidden_states), rope=rope, attn_mask=attn_mask)
        hidden_states = attn_out + hidden_states

        # MLP sub-block with its own residual connection
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_out + hidden_states

 

 

class LlamaModel(nn.Module):
    """Llama decoder stack without a prediction head.

    Token embedding -> ``num_hidden_layers`` decoder layers sharing one RoPE
    module -> final RMSNorm. ``forward`` returns the normalised hidden states.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # one RoPE table shared by every layer
        self.rotary_emb = RotaryPositionEncoding(head_dim, config.max_position_embeddings)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        h = self.embed_tokens(input_ids)
        for block in self.layers:
            h = block(h, rope=self.rotary_emb, attn_mask=attn_mask)
        return self.norm(h)

 

 

class LlamaForPretraining(nn.Module):
    """Llama base model followed by a bias-free linear language-modelling head.

    ``forward`` maps the base model's hidden states to unnormalised logits
    whose last dimension has size ``vocab_size``.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        return self.lm_head(self.base_model(input_ids, attn_mask))

 

 

def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create an additive causal mask for self-attention.

    Entry ``[i, j]`` is ``0`` when key position ``j`` may be attended from
    query position ``i`` (``j <= i``) and ``-inf`` otherwise.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len), on ``batch``'s device
    """
    _, seq_len = batch.shape  # also rejects non-2D input
    # The chain stays on one logical line: a trailing-backslash continuation
    # breaks as soon as a blank line follows it.
    full = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype)
    # keep -inf strictly above the diagonal (future keys); zero on and below it
    return full.triu(diagonal=1)

 

 

def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create an additive padding mask for a batch of sequences for self-attention.

    Entry ``[b, 0, i, j]`` is ``-inf`` when either query position ``i`` or key
    position ``j`` of sequence ``b`` holds ``padding_token_id``, else ``0``.
    Rows belonging to padding queries are therefore entirely ``-inf``.
    NOTE(review): confirm the attention kernel in use yields finite output for
    fully masked rows.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        dtype: Data type of the mask

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    is_pad = batch == padding_token_id
    # per-position penalty: 0 for real tokens, -inf for padding (no backslash
    # continuation, which would break if a blank line followed it)
    penalty = torch.zeros_like(batch, dtype=dtype).masked_fill(is_pad, float("-inf"))
    # outer sum: query penalty down the rows, key penalty across the columns
    mask = penalty[:, :, None] + penalty[:, None, :]
    # add a singleton head axis so the mask broadcasts over attention heads
    return mask[:, None, :, :]

 

 

# Map-style dataset yielding fixed-length (input, target) token windows
class PretrainingDataset(torch.utils.data.Dataset):
    """Fixed-length next-token-prediction examples built from a text dataset.

    Item ``index`` works as follows:

    * ``dataset[index]["text"]`` is tokenized and framed as
      ``[BOT] + ids + [EOT]``.
    * The result is right-padded with ``[PAD]`` to ``seq_length + 1`` tokens;
      longer sequences are clipped.
    * The input is the first ``seq_length`` tokens, and the target is the same
      window shifted one position to the right.
    """

    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                 seq_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        # special-token ids, looked up once
        self.bot, self.eot, self.pad = (
            tokenizer.token_to_id(name) for name in ("[BOT]", "[EOT]", "[PAD]")
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Return int64 tensors ``(x, y)`` where ``y[i]`` is the token following ``x[i]``."""
        text = self.dataset[index]["text"]
        window = self.seq_length + 1
        ids = [self.bot, *self.tokenizer.encode(text).ids, self.eot]
        missing = window - len(ids)
        if missing > 0:
            ids.extend([self.pad] * missing)
        x = torch.tensor(ids[:self.seq_length], dtype=torch.int64)
        y = torch.tensor(ids[1:window], dtype=torch.int64)
        return x, y

 

 

def load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: lr_scheduler.SequentialLR,
    checkpoint_dir: str = "checkpoint-dist",
) -> None:
    """Restore model, optimizer and LR-scheduler state written by ``save_checkpoint``.

    Collective: every rank of the default process group must call it, because
    it contains barriers and a distributed checkpoint (DCP) load.

    The model/optimizer state dicts come from
    ``torch.distributed.checkpoint.state_dict.get_state_dict`` instead of the
    bare ``optimizer.state_dict()``:

    * A freshly built optimizer has no per-parameter state, so its plain
      ``state_dict()`` offers DCP no tensors to load the AdamW moments into.
      With partial loading enabled, the moments would be skipped silently.
    * ``get_state_dict`` initialises that state and keys it by parameter
      name, and ``set_state_dict`` applies the loaded values afterwards.

    Args:
        model: The model to restore in place.
        optimizer: Optimizer over ``model``'s parameters, restored in place.
        scheduler: LR scheduler, restored in place.
        checkpoint_dir: Directory written by ``save_checkpoint``. Every rank
            reads it, so it presumably lives on a shared filesystem — confirm
            for multi-node runs.
    """
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    dist.barrier()
    model_state, optim_state = get_state_dict(model, optimizer)
    state_dict = {"model": model_state, "optimizer": optim_state}
    # DCP loads in place into the tensors of ``state_dict``. allow_partial_load
    # tolerates state-dict keys that are missing from the checkpoint (intended
    # for the RoPE buffer keys); those entries keep their current values.
    load(
        state_dict,
        checkpoint_id=checkpoint_dir,
        planner=DefaultLoadPlanner(allow_partial_load=True),
    )
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optimizer"],
    )
    # Scheduler state is loaded onto the CPU, so this function no longer
    # depends on a global ``device``. weights_only=True refuses to unpickle
    # arbitrary objects.
    scheduler.load_state_dict(
        torch.load(
            os.path.join(checkpoint_dir, "lrscheduler.pt"),
            map_location="cpu",
            weights_only=True,
        ),
    )
    dist.barrier()


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: lr_scheduler.SequentialLR,
    checkpoint_dir: str = "checkpoint-dist",
) -> None:
    """Write model, optimizer and LR-scheduler state to ``checkpoint_dir``.

    Collective: every rank must call it.

    * Model/optimizer state is written with DCP, using the name-keyed state
      dicts from ``get_state_dict`` (the layout ``load_checkpoint`` expects).
      Per the DCP state_dict docs, ``get_state_dict`` initialises optimizer
      state if the optimizer has not stepped yet.
    * Rank 0 alone writes the scheduler state to
      ``<checkpoint_dir>/lrscheduler.pt``.

    Args:
        model: The model whose state is saved.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler whose state is saved.
        checkpoint_dir: Output directory (created by the DCP writer).
    """
    from torch.distributed.checkpoint.state_dict import get_state_dict

    dist.barrier()
    model_state, optim_state = get_state_dict(model, optimizer)
    save(
        {"model": model_state, "optimizer": optim_state},
        checkpoint_id=checkpoint_dir,
    )
    if dist.get_rank() == 0:
        torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "lrscheduler.pt"))
    dist.barrier()

 

 

# Load the tokenizer and dataset
# (every process runs this, before the process group is initialised)
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Initialize the distributed environment
# NCCL backend with a 60-second timeout for collectives.
# NOTE(review): 60 s is short — a slow checkpoint save on one rank could keep
# the others waiting at a barrier longer than this; confirm it is enough.
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=60))
# LOCAL_RANK must be present in the environment (presumably set by the
# launcher, e.g. torchrun — assumption); it picks this process's GPU.
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
# global rank and number of processes in the default group
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")

 

# Initialize the mesh for tensor parallelism
# 2-D device mesh of shape (world_size // n_tensor_parallel, n_tensor_parallel):
# "dp" indexes data-parallel replicas (the dimension FSDP shards over below),
# "tp" indexes the ranks that jointly hold one tensor-parallel model copy.
n_tensor_parallel = 2
assert world_size % n_tensor_parallel == 0, "Expect world size to be divisible by number of tensor parallel GPUs"
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (world_size // n_tensor_parallel, n_tensor_parallel),
    mesh_dim_names=("dp", "tp"),
)
print(f"({rank}) Mesh: {mesh}, DP size: {mesh['dp'].size()}, TP size: {mesh['tp'].size()}, DP local rank: {mesh['dp'].get_local_rank()}, TP local rank: {mesh['tp'].get_local_rank()}")

 

# Create pretraining model on meta device, on all ranks
# (meta tensors record shape/dtype only; real storage is allocated later by
# model.to_empty())
with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

# Set up tensor parallelism on each transformer block in the base model.
# Between sub-layers the activations are sharded along the sequence axis
# (Shard(1)), and the RMSNorms run sequence-parallel on that layout.
# Each sub-layer's input is converted to Replicate before its projections.
# Column-parallel layers split their weight along output features,
# row-parallel layers along input features, and the row-parallel output
# projections hand their result back as Shard(1).
tp_plan = {
    "input_layernorm": SequenceParallel(),
    "self_attn": PrepareModuleInput(
        input_layouts=Shard(dim=1),  # only one position arg will be used
        desired_input_layouts=Replicate(),
    ),
    # Q/K projections output will be used with RoPE, need to be replicated
    # Q/K/V output will be used with GQA, also need to be replicated
    # As a result every TP rank holds all heads and computes the full attention
    # itself; only the projection weights are split across "tp".
    "self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
    "post_attention_layernorm": SequenceParallel(),
    "mlp": PrepareModuleInput(
        input_layouts=Shard(dim=1),
        desired_input_layouts=Replicate(),
    ),
    # gate/up use the ColwiseParallel defaults (output sharded on the feature
    # axis); down_proj uses the RowwiseParallel default input layout for that
    # sharded input and returns Shard(1)
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
for layer in model.base_model.layers:
    parallelize_module(layer, mesh["tp"], tp_plan)

# Set up tensor parallelism on the embedding and output norm layers in the base model
# and the prediction head in the top-level model
tp_plan = {
    # token ids arrive replicated on every TP rank; the embedding output
    # leaves sequence-sharded to match the layout used between blocks
    "base_model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "base_model.norm": SequenceParallel(),
    "lm_head": ColwiseParallel(
        input_layouts=Shard(1),
        # output_layouts=Replicate(), # only if not using loss parallel
        use_local_output=False,  # Keep DTensor output for loss parallel
    ),
}
parallelize_module(model, mesh["tp"], tp_plan)

# Convert tensor-parallelized model to FSDP2, must shard every component
# shard across the "dp" dimension of the mesh
# Each decoder layer becomes its own FSDP unit. base_model then covers its
# remaining parameters (embedding and final norm), and the root covers the
# rest (lm_head).
for layer in model.base_model.layers:
    fully_shard(layer, mesh=mesh["dp"])
fully_shard(model.base_model, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])

 

def reset_all_weights(model: nn.Module) -> None:
    """Re-initialise ``model`` and all of its submodules in place.

    Every module in the tree that defines a callable ``reset_parameters`` has
    it invoked under ``torch.no_grad()``; modules without one are left as they
    are. Meant for use after moving a meta-device model to real storage.
    """

    def _reinit(module: nn.Module) -> None:
        hook = getattr(module, "reset_parameters", None)
        if not callable(hook):
            return
        with torch.no_grad():
            hook()

    # nn.Module.apply calls _reinit on every submodule as well as on model itself
    model.apply(_reinit)

 

# Materialise the meta-device model on this rank's GPU and initialise it.
# to_empty() allocates uninitialised storage for every parameter and buffer.
# reset_all_weights() then calls reset_parameters() on each module that
# defines one; anything owned by a module without reset_parameters keeps
# whatever to_empty() left in it.
# The fixed seed is set right before initialisation — presumably so the
# initial weights are reproducible and consistent across ranks; TODO confirm.
torch.manual_seed(42)
model.to_empty(device=device)
reset_all_weights(model)
assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"

# Training parameters
epochs = 3
learning_rate = 1e-3
# 64 sequences per step in total, divided (integer division) among the
# data-parallel replicas
batch_size = 64 // mesh["dp"].size()
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
model.train()

 

# DataLoader, optimizer, scheduler, and loss function
# Sampler is needed to shard the dataset across world size.
# Sharding is per data-parallel replica, not per rank: ranks sharing a "dp"
# index (the members of one tensor-parallel group) receive identical batches,
# matching the Replicate input layout of the TP embedding.
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(
    dataset, shuffle=False, drop_last=True,
    num_replicas=mesh["dp"].size(),
    rank=mesh["dp"].get_local_rank(),
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,  # optional
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
# the scheduler is stepped once per batch, so the schedule spans all batches of all epochs
num_training_steps = len(dataloader) * epochs

optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
)
# linear warmup from 0.1 * learning_rate up to learning_rate over num_warmup_steps steps ...
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
# ... then cosine decay down to 0 over the remaining steps
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps],
)
# targets equal to the padding id contribute nothing to the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# if checkpoint-dist dir exists, load the checkpoint to model and optimizer
# (and the LR scheduler).
# NOTE(review): each rank checks the path on its own, and load_checkpoint
# contains barriers, so all ranks must see the same answer (assumes a shared
# filesystem).
if os.path.exists("checkpoint-dist"):
    load_checkpoint(model, optimizer, scheduler)

 

# start training
# NOTE(review): after resuming from a checkpoint the loop still starts at
# epoch 0, batch 0 — only model/optimizer/scheduler state is restored, not
# the position in the data.
print(f"({rank}) Starting training")
for epoch in range(epochs):
    # Iterating over the tqdm wrapper advances the bar once per batch by itself.
    pbar = tqdm.tqdm(dataloader, desc=f"({rank}) Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        # Periodic checkpoint. save_checkpoint is collective, so every rank
        # reaches it at the same batch_id. It also fires at batch 0 of every
        # epoch, before that epoch's first optimizer step.
        if batch_id % 1000 == 0:
            save_checkpoint(model, optimizer, scheduler)
        # Explicit prefetching before sending any data to model
        model.unshard()
        # Get batched data, move from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # Create attention mask: causal mask + padding mask. The (seq, seq)
        # causal mask broadcasts against the (batch, 1, seq, seq) padding mask.
        # Parenthesised instead of a trailing backslash, which breaks if a
        # blank line follows it.
        attn_mask = (
            create_causal_mask(input_ids)
            + create_padding_mask(input_ids, PAD_TOKEN_ID)
        )
        # Extract output from model
        logits = model(input_ids, attn_mask)
        optimizer.zero_grad()
        with loss_parallel():
            # Compute loss: cross-entropy between logits and target, ignoring padding tokens.
            # logits is still a vocab-sharded DTensor (lm_head keeps DTensor
            # output); loss_parallel lets the loss run on that sharded layout.
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            # Backward with loss on DTensor
            loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        # No manual pbar.update(1) here: the iteration over pbar has already
        # counted this batch, and a second update would advance the bar twice.
    pbar.close()

# Save the model
save_checkpoint(model, optimizer, scheduler)

# Clean up the distributed environment
dist.destroy_process_group()

Share this article