Přeskočit na obsah
_CORE
AI & Agentic Systems Core Informační Systémy Cloud & Platform Engineering Data Platforma & Integrace Security & Compliance QA, Testing & Observability IoT, Automatizace & Robotika Mobile & Digital Banky & Finance Pojišťovnictví Veřejná správa Obrana & Bezpečnost Zdravotnictví Energetika & Utility Telco & Média Průmysl & Výroba Logistika & E-commerce Retail & Loyalty
Reference Technologie Blog Knowledge Base O nás Spolupráce Kariéra
Pojďme to probrat

Loss funkce — přehled a výběr pro váš model

01. 01. 2024 4 min čtení intermediate

Loss funkce jsou klíčovým prvkem tréninku machine learning modelů - určují, jak model měří své chyby a učí se je opravovat. Správný výběr loss funkce může dramaticky ovlivnit výkon vašeho modelu. Projdeme si nejpoužívanější typy a ukážeme, kdy kterou použít.

Co jsou loss funkce a proč jsou klíčové

Loss funkce (ztrátové funkce) představují srdce každého machine learning modelu. Definují, jak měříme “vzdálenost” mezi predikovanými a skutečnými hodnotami, a tím přímo ovlivňují, jak se model učí. Správný výběr loss funkce může znamenat rozdíl mezi modelem, který funguje výborně, a tím, který nikdy nekonverguje.

V podstatě jde o matematickou formulaci toho, co považujeme za “chybu”. Model během trénování minimalizuje tuto chybu pomocí optimalizačních algoritmů jako SGD nebo Adam. Různé typy problémů vyžadují různé loss funkce – to, co funguje pro klasifikaci, nemusí být vhodné pro regresi.

Regression loss funkce

Mean Squared Error (MSE)

MSE je nejpopulárnější loss funkcí pro regresní úlohy. Počítá průměr čtverců rozdílů mezi predikovanými a skutečnými hodnotami:

import torch
import torch.nn as nn

# PyTorch implementace
mse_loss = nn.MSELoss()
predictions = torch.tensor([2.5, 0.0, 2.1])
targets = torch.tensor([3.0, -0.5, 2.0])
loss = mse_loss(predictions, targets)
print(f"MSE Loss: {loss.item()}")

# Manuální implementace
def mse_manual(y_pred, y_true):
    return torch.mean((y_pred - y_true) ** 2)

Výhody MSE: Jednoduché na implementaci, silně penalizuje velké chyby díky umocnění na druhou. Nevýhody: Citlivé na outliery, které mohou výrazně zkreslit trénování.

Mean Absolute Error (MAE)

MAE počítá průměr absolutních hodnot rozdílů. Je robustnější vůči outlierům než MSE:

mae_loss = nn.L1Loss()  # L1Loss = MAE v PyTorch
loss = mae_loss(predictions, targets)

# Manuální implementace
def mae_manual(y_pred, y_true):
    return torch.mean(torch.abs(y_pred - y_true))

Huber Loss

Huber Loss kombinuje výhody MSE a MAE. Chová se jako MSE pro malé chyby a jako MAE pro velké chyby:

huber_loss = nn.HuberLoss(delta=1.0)
loss = huber_loss(predictions, targets)

# Manuální implementace
def huber_loss_manual(y_pred, y_true, delta=1.0):
    error = torch.abs(y_pred - y_true)
    is_small_error = error <= delta
    squared_loss = 0.5 * error ** 2
    linear_loss = delta * error - 0.5 * delta ** 2
    return torch.mean(torch.where(is_small_error, squared_loss, linear_loss))

Classification loss funkce

Cross-Entropy Loss

Cross-entropy je standardní volbou pro klasifikační úlohy. Měří “vzdálenost” mezi pravděpodobnostními distribucemi:

# Binární klasifikace
binary_ce = nn.BCELoss()
sigmoid_output = torch.sigmoid(torch.tensor([0.8, -1.2, 2.1]))
binary_targets = torch.tensor([1.0, 0.0, 1.0])
loss = binary_ce(sigmoid_output, binary_targets)

# Multi-class klasifikace
ce_loss = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.2]])
targets = torch.tensor([0, 1])  # Class indices
loss = ce_loss(logits, targets)

Cross-entropy má důležitou vlastnost – rychle konverguje, když je model velmi špatný, ale zpomaluje, když se blíží k optimu. To vede k stabilnímu učení.

Focal Loss

Focal Loss řeší problém nevyvážených datasetů tím, že snižuje váhu “jednoduchých” příkladů a zaměřuje se na složité případy:

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        return focal_loss.mean()

# Použití
focal_loss = FocalLoss(alpha=1, gamma=2)
loss = focal_loss(logits, targets)

Pokročilé loss funkce

Contrastive Loss

Používá se v metric learning, kde chceme naučit model rozpoznávat podobnost mezi objekty:

def contrastive_loss(output1, output2, label, margin=1.0):
    euclidean_distance = nn.functional.pairwise_distance(output1, output2)
    loss_contrastive = torch.mean(
        (1-label) * torch.pow(euclidean_distance, 2) +
        label * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2)
    )
    return loss_contrastive

Dice Loss

Populární v segmentačních úlohách, zejména v medicínském imaging:

def dice_loss(inputs, targets, smooth=1):
    inputs = torch.sigmoid(inputs)
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

    return 1 - dice

Praktické tipy pro výběr loss funkce

Výběr správné loss funkce závisí na několika faktorech:

  • Typ problému: Regrese vs. klasifikace vs. ranking
  • Distribuce dat: Vyvážené vs. nevyvážené třídy
  • Citlivost na outliery: MSE vs. MAE pro regresi
  • Interpretabilita: Některé loss funkce mají jasný statistický význam

Pro debugging a monitoring doporučuji sledovat více metrik současně:

# Kombinace více loss funkcí pro lepší insight
class CombinedLoss(nn.Module):
    def __init__(self, weights={'mse': 0.7, 'mae': 0.3}):
        super().__init__()
        self.weights = weights
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()

    def forward(self, predictions, targets):
        mse_loss = self.mse(predictions, targets)
        mae_loss = self.mae(predictions, targets)

        total_loss = (self.weights['mse'] * mse_loss + 
                     self.weights['mae'] * mae_loss)

        return total_loss, {'mse': mse_loss.item(), 'mae': mae_loss.item()}

Nezapomeňte experimentovat! Často se vyplatí začít se standardními funkcemi (MSE pro regresi, Cross-Entropy pro klasifikaci) a postupně optimalizovat podle specifických potřeb vašeho problému.

Shrnutí

Loss funkce jsou fundamentálním stavebním kamenem machine learning modelů. Pro regresní úlohy začněte s MSE nebo MAE, pro klasifikaci s Cross-Entropy. Pokročilé funkce jako Focal Loss nebo Dice Loss řeší specifické problémy jako nevyvážené datasety nebo segmentační úlohy. Klíčem je experimentování a porozumění tomu, jak různé funkce ovlivňují chování vašeho modelu. Sledujte více metrik současně a nezapomeňte na validační data při vyhodnocování výkonu.

loss functionsmsecross-entropy