Přeskočit na obsah
_CORE
AI & Agentic Systems Core Informační Systémy Cloud & Platform Engineering Data Platforma & Integrace Security & Compliance QA, Testing & Observability IoT, Automatizace & Robotika Mobile & Digital Banky & Finance Pojišťovnictví Veřejná správa Obrana & Bezpečnost Zdravotnictví Energetika & Utility Telco & Média Průmysl & Výroba Logistika & E-commerce Retail & Loyalty
Reference Technologie Blog Knowledge Base O nás Spolupráce Kariéra
Pojďme to probrat

Mixture of Experts — efektivní škálování

01. 01. 2024 4 min čtení intermediate

Mixture of Experts představuje revoluci ve škálování AI modelů. Namísto aktivace celé neuronové sítě se používají pouze relevantní “experti” pro konkrétní úkol, což dramaticky snižuje výpočetní náklady při zachování vysoké výkonnosti.

Co je Mixture of Experts (MoE)

Mixture of Experts představuje architektonický přístup, který umožňuje dramatické škálování kapacity neuronových sítí bez proporcionálního nárůstu výpočetních nákladů. Základní myšlenka spočívá v rozdělení modelu na více specializovaných “expertů”, přičemž pro každý vstup se aktivuje pouze podmnožina z nich.

MoE architektura se skládá ze tří klíčových komponent:

  • Expert networks - specializované feed-forward sítě
  • Gating network - router určující, kteří experti se aktivují
  • Sparsity mechanismus - zajišťuje aktivaci pouze top-k expertů

Fungování MoE vrstvy

Tradiční dense vrstva zpracuje vstup přes všechny parametry. MoE vrstva místo toho:

# Pseudokód MoE vrstvy
def moe_layer(x, experts, gate_network):
    # 1. Gating - určení pravděpodobností expertů
    gate_scores = gate_network(x)  # shape: [batch, num_experts]

    # 2. Top-k selekce
    top_k_gates, top_k_indices = topk(gate_scores, k=2)

    # 3. Normalizace gate weights
    top_k_gates = softmax(top_k_gates)

    # 4. Výpočet pouze vybraných expertů
    expert_outputs = []
    for i in range(k):
        expert_idx = top_k_indices[i]
        expert_output = experts[expert_idx](x)
        weighted_output = top_k_gates[i] * expert_output
        expert_outputs.append(weighted_output)

    # 5. Kombinace výstupů
    return sum(expert_outputs)

Klíčová výhoda: při aktivaci pouze 2 z 8 expertů dosáhneme 4x menší výpočetní složitosti než u ekvivalentní dense vrstvy, ale zachováme podobnou expresivitu.

Load Balancing problem

Bez vhodné regularizace má gating network tendenci posílat většinu tokenů k malému počtu expertů. To vede k neefektivnímu využití kapacity a bottleneckům. Řešením je auxiliary loss:

def load_balancing_loss(gate_scores, top_k_indices, num_experts):
    # Frekvence využití expertů
    expert_counts = torch.bincount(top_k_indices.flatten(), 
                                  minlength=num_experts)
    expert_fractions = expert_counts.float() / top_k_indices.numel()

    # Průměrné gate pravděpodobnosti
    gate_means = gate_scores.mean(dim=0)

    # Load balancing loss (chceme uniform distribuci)
    uniform_target = 1.0 / num_experts
    load_loss = num_experts * torch.sum(expert_fractions * gate_means)

    return load_loss

Praktická implementace s PyTorch

Ukázka jednoduché MoE implementace:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)))

class MoELayer(nn.Module):
    def __init__(self, dim, num_experts=8, top_k=2, hidden_dim=None):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        if hidden_dim is None:
            hidden_dim = 4 * dim

        # Gating network
        self.gate = nn.Linear(dim, num_experts, bias=False)

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(dim, hidden_dim) for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size, seq_len, dim = x.shape
        x_flat = x.view(-1, dim)

        # Gating
        gate_logits = self.gate(x_flat)

        # Top-k selection
        top_k_logits, top_k_indices = torch.topk(
            gate_logits, self.top_k, dim=-1
        )
        top_k_gates = F.softmax(top_k_logits, dim=-1)

        # Initialize output
        output = torch.zeros_like(x_flat)

        # Process each expert
        for i in range(self.top_k):
            expert_mask = top_k_indices[:, i]
            expert_weights = top_k_gates[:, i:i+1]

            for expert_idx in range(self.num_experts):
                token_mask = expert_mask == expert_idx
                if token_mask.any():
                    expert_tokens = x_flat[token_mask]
                    expert_output = self.experts[expert_idx](expert_tokens)
                    output[token_mask] += expert_weights[token_mask] * expert_output

        return output.view(batch_size, seq_len, dim)

MoE v produkčních modelech

Největší úspěch zaznamenala MoE architektura v jazykových modelech. Mixtral 8x7B představuje výrazný milestone - dosahuje výkonu srovnatelného s mnohem většími dense modely:

  • Mixtral 8x7B: 8 expertů, aktivace 2, celkem 46.7B parametrů
  • Efektivní parametry: 12.9B aktivních při inferenci
  • Výkon: konkuruje modelům s 70B+ parametry

Deployment considerations

MoE modely přinášejí specifické výzvy v produkci:

# Memory requirements calculation
def calculate_moe_memory(
    num_experts, expert_params, 
    active_experts, batch_size
):
    # Všichni experti musí být v paměti
    total_params = num_experts * expert_params

    # Ale výpočet jen pro aktivní experty
    compute_params = active_experts * expert_params

    # Memory pro gradients (pokud trénujeme)
    gradient_memory = total_params * 4  # float32

    # Aktivační memory
    activation_memory = batch_size * compute_params * 4

    return {
        'model_memory': total_params * 4,
        'gradient_memory': gradient_memory,
        'activation_memory': activation_memory,
        'total_training': total_params * 12  # params + grads + optimizer states
    }

Klíčové optimalizace zahrnují expert parallelism, kde různé GPU hostují různé experty, a komunikaci pouze aktivních expertů.

Výhody a omezení MoE

Výhody:

  • Škálování kapacity bez lineárního růstu FLOPs
  • Specializace - experti se učí různé aspekty dat
  • Inference efektivita - konstantní výpočetní složitost

Omezení:

  • Memory overhead - všichni experti v paměti
  • Load balancing - složitá optimalizace využití
  • Communication cost - při distribuovaném tréningu
  • Routing collapse - tendence využívat jen některé experty

Pokročilé techniky

Moderní MoE implementace používají sofistikovanější přístupy:

class SwitchMoE(nn.Module):
    """Switch Transformer - top-1 routing s capacity faktorem"""

    def __init__(self, dim, num_experts, capacity_factor=1.25):
        super().__init__()
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

        self.gate = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList([
            Expert(dim) for _ in range(num_experts)
        ])

    def forward(self, x):
        # Capacity - max tokenů na experta
        capacity = int(self.capacity_factor * x.size(0) / self.num_experts)

        # Top-1 gating
        gate_logits = self.gate(x)
        gates = F.softmax(gate_logits, dim=-1)

        # Expert assignment
        expert_indices = torch.argmax(gates, dim=-1)
        expert_weights = torch.max(gates, dim=-1)[0]

        # Capacity-based batching
        outputs = []
        for expert_idx in range(self.num_experts):
            mask = expert_indices == expert_idx
            tokens = x[mask][:capacity]  # Capacity limit
            weights = expert_weights[mask][:capacity]

            if tokens.size(0) > 0:
                expert_out = self.experts[expert_idx](tokens)
                outputs.append((expert_out * weights.unsqueeze(-1), mask))

        return self._combine_outputs(outputs, x.shape)

Shrnutí

Mixture of Experts představuje elegantní řešení pro škálování neuronových sítí bez proporcionálního nárůstu výpočetních nákladů. Klíčem k úspěchu je správná implementace gating mechanismu, load balancingu a efektivní využití distribuované architektury. Pro produkční nasazení je důležité zvážit memory requirements a komunikační overhead, ale potenciální výhody v podobě lepšího poměru výkon/cena činí MoE atraktivní volbou pro velké jazykové modely.

moemixtralarchitektura