Přeskočit na obsah
_CORE
AI & Agentic Systems Core Informační Systémy Cloud & Platform Engineering Data Platforma & Integrace Security & Compliance QA, Testing & Observability IoT, Automatizace & Robotika Mobile & Digital Banky & Finance Pojišťovnictví Veřejná správa Obrana & Bezpečnost Zdravotnictví Energetika & Utility Telco & Média Průmysl & Výroba Logistika & E-commerce Retail & Loyalty
Reference Technologie Blog Knowledge Base O nás Spolupráce Kariéra
Pojďme to probrat

Transfer Learning — využití předtrénovaných modelů

01. 01. 2024 4 min čtení intermediate

Transfer Learning je technika, která umožňuje využít znalosti naučené na jednom úkolu pro řešení jiného, podobného problému. Namísto trénování modelu od nuly můžete využít předtrénované modely a přizpůsobit je vašim specifickým potřebám.

Co je Transfer Learning

Transfer Learning představuje jednu z nejefektivnějších technik v moderním strojovém učení. Místo trénování modelu od nuly využíváme znalosti již naučené na rozsáhlých datasetech a přizpůsobíme je našemu specifickému problému. Tento přístup šetří čas, výpočetní zdroje a často dosahuje lepších výsledků než klasické učení od začátku.

Základní myšlenka je jednoduchá: model, který se naučil rozpoznávat obecné vzory v datech (například hrany, textury, nebo jazykové struktury), může tyto znalosti aplikovat i na příbuzné úlohy. Stačí jen “doladit” poslední vrstvy pro naši specifickou doménu.

Typy Transfer Learning

Rozlišujeme několik hlavních přístupů:

  • Feature Extraction - zmrazíme váhy předtrénovaného modelu a použijeme ho jako feature extraktor
  • Fine-tuning - postupně odemkneme a dotrénujeme některé vrstvy na našich datech
  • Domain Adaptation - přizpůsobíme model novému typu dat (např. z fotografií na kreslené obrázky)

Feature Extraction v praxi

Nejjednodušší přístup využívá předtrénovaný model jako black box pro extrakci příznaků:

import torch
import torchvision.models as models
from torch import nn

# Načteme předtrénovaný ResNet
base_model = models.resnet50(pretrained=True)

# Zmrazíme všechny parametry
for param in base_model.parameters():
    param.requires_grad = False

# Nahradíme classifier pro naši úlohu (např. 10 tříd)
base_model.fc = nn.Linear(base_model.fc.in_features, 10)

# Pouze nová vrstva se bude trénovat
optimizer = torch.optim.Adam(base_model.fc.parameters(), lr=0.001)

Postupný Fine-tuning

Sofistikovanější přístup postupně “odemyká” vrstvy pro dotrénování:

class TransferModel(nn.Module):
    def __init__(self, num_classes, freeze_layers=True):
        super().__init__()
        self.backbone = models.resnet50(pretrained=True)

        if freeze_layers:
            # Zmrazíme první vrstvy
            for param in self.backbone.layer1.parameters():
                param.requires_grad = False
            for param in self.backbone.layer2.parameters():
                param.requires_grad = False

        # Upravíme classifier
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def unfreeze_layers(self, layer_names):
        """Postupné odmrazování vrstev"""
        for name in layer_names:
            layer = getattr(self.backbone, name)
            for param in layer.parameters():
                param.requires_grad = True

model = TransferModel(num_classes=10)

# Po několika epochách můžeme odzmrazit další vrstvy
model.unfreeze_layers(['layer2', 'layer3'])

Transfer Learning pro NLP

V oblasti zpracování přirozeného jazyka je transfer learning ještě důležitější. Modely jako BERT, GPT nebo RoBERTa jsou natrénovány na obrovských textových korpusech a dokáží zachytit složité jazykové vzory.

Fine-tuning BERT pro klasifikaci

from transformers import BertForSequenceClassification, BertTokenizer
from transformers import TrainingArguments, Trainer

# Načteme předtrénovaný BERT
model = BertForSequenceClassification.from_pretrained(
    'bert-base-multilingual-cased',
    num_labels=3  # Např. sentiment: pozitivní, negativní, neutrální
)

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Příprava dat
def tokenize_function(examples):
    return tokenizer(
        examples['text'], 
        truncation=True, 
        padding=True, 
        max_length=512
    )

train_dataset = train_dataset.map(tokenize_function, batched=True)

# Nastavení tréninku
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    warmup_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Osvědčené postupy

Výběr learning rate

Při fine-tuningu je klíčové správně nastavit learning rate. Obecně platí:

  • Pro nové vrstvy: vyšší learning rate (1e-3 až 1e-4)
  • Pro předtrénované vrstvy: nižší learning rate (1e-5 až 1e-6)
  • Postupné snižování s pokračujícím tréninkem
# Différenciovaný learning rate pro různé části modelu
def get_optimizer_grouped_parameters(model, backbone_lr=1e-5, head_lr=1e-3):
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.backbone.named_parameters() 
                      if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
            "lr": backbone_lr
        },
        {
            "params": [p for n, p in model.backbone.named_parameters() 
                      if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
            "lr": backbone_lr
        },
        {
            "params": model.fc.parameters(),
            "lr": head_lr
        }
    ]

    return torch.optim.AdamW(optimizer_grouped_parameters)

Data Augmentation a regularizace

Při menších datasetech je důležité předcházet overfittingu:

import torchvision.transforms as transforms

# Augmentace pro computer vision
transform = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# Dropout v custom vrstvách
class FineTunedModel(nn.Module):
    def __init__(self, base_model, num_classes):
        super().__init__()
        self.backbone = base_model
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(base_model.fc.in_features, num_classes)

    def forward(self, x):
        features = self.backbone.features(x)
        pooled = nn.AdaptiveAvgPool2d((1, 1))(features)
        flattened = torch.flatten(pooled, 1)
        dropped = self.dropout(flattened)
        return self.classifier(dropped)

Praktické tipy pro úspěšný transfer

Podobnost domén: Čím podobnější je zdrojová a cílová doména, tím lepší výsledky můžeme očekávat. Model trénovaný na obecných fotografiích se lépe přizpůsobí lékařským snímkům než satelitním datům.

Velikost datasetu: Pro malé datasety (stovky vzorků) je feature extraction bezpečnější volba. Pro větší datasety (tisíce vzorků) můžeme experimentovat s fine-tuningem.

Gradual unfreezing: Místo okamžitého odemčení všech vrstev postupně odmrazujeme od vrchních k spodním vrstvám:

def gradual_unfreeze_schedule(model, epoch):
    """Postupné odmrazování podle epochy"""
    if epoch >= 5:
        # Od 5. epochy odmrazíme vrchní vrstvy
        for param in model.backbone.layer4.parameters():
            param.requires_grad = True

    if epoch >= 10:
        # Od 10. epochy další vrstvy
        for param in model.backbone.layer3.parameters():
            param.requires_grad = True

    # Learning rate také postupně snižujeme
    if epoch >= 5:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.5

Shrnutí

Transfer Learning představuje fundamental změnu v přístupu k strojovému učení. Místo trénování modelů od nuly využíváme kolektivní “znalosti” uložené v předtrénovaných modelech. Klíčem k úspěchu je správná volba strategie (feature extraction vs. fine-tuning), pečlivé nastavení learning rates pro různé části modelu a postupný přístup k odmrazování vrstev. S rostoucími předtrénovanými modely jako jsou foundation models se transfer learning stává ještě důležitějším nástrojem pro efektivní vývoj AI aplikací.

transfer learningpre-trainingfine-tuning