LLM Surgery:高效的大型語言模型知識遺忘與編輯

LLM Surgery 是一種高效的框架,通過優化三個目標函數(unlearn、update、retain)來修改 LLM 行為,實現問題知識的遺忘和新知識的整合,無需從頭重新訓練。

LLM Surgery:高效的大型語言模型知識遺忘與編輯

研究日期

2026-02-14

簡介

LLM Surgery 是一種創新的框架,用於高效地修改大型語言模型(LLM)的行為。它解決了 LLM 訓練後面臨的「不可逆困境」——一旦預訓練完成,模型權重幾乎無法修改。當出現過時事實、惡意 prompt 注入樣本、版權或隱私問題、法律合規風險,或需要更新產品知識時,傳統做法只有兩種選擇:從頭重訓(full retraining)或持續 pretrain(continued pretraining),這些方法的代價極高(百萬~千萬 GPU 小時級別)。

LLM Surgery 提出了一種工程化的落地方案,通過在權重空間做局部優化,而不是整體重訓,實現同時進行知識遺忘(unlearning)、知識更新(updating)和知識保留(retention)。在 Llama2-7B 上的實驗表明,LLM Surgery 可以在 unlearn set 上實現顯著的遺忘,在 update set 上提升 20% 的準確率,並在 retain set 上維持原有性能。

主要特性

  • 三重目標優化:同時優化 unlearn(反向梯度)、update(正常梯度下降)、retain(KL 約束)三個目標函數
  • 高效性:無需從頭重訓,直接使用 SGD + KL constraint,工程成本低
  • 可擴展性:不依賴精準定位 neuron,更 generalizable,可同時處理多個知識點
  • 穩定性:通過 KL 約束避免 catastrophic forgetting,保持模型原有能力

工作原理

核心思想

LLM Surgery 的核心思想是在權重空間做局部優化,而不是整體重訓。它通過定義三個資料集,並對每個資料集應用不同的優化策略:

  1. Unlearn Set(D_u):要忘記的資料(過時、錯誤或敏感資訊)
  2. Update Set(D_up):要學習的新資料
  3. Retain Set(D_r):要保留的資料(模型原有知識)

數學形式化

總目標函數為:

$$ \mathcal{L} = \lambda_{\text{u}} \mathcal{L}_{\text{unlearn}} + \lambda_{\text{up}} \mathcal{L}_{\text{update}} + \lambda_{\text{r}} \mathcal{L}_{\text{retain}} $$

1️⃣ Unlearning(反向梯度)

對要忘記的資料使用 gradient ascent(梯度上升):

$$ \mathcal{L}_{\text{unlearn}} = - \log P_{\theta}(y|x) $$

直觀理解:

  • 提高 loss
  • 降低模型對該知識的置信度
  • 讓模型遠離該分佈

這本質上是 局部反訓練(anti-training)

2️⃣ Updating(正常梯度下降)

對新知識使用標準交叉熵:

$$ \mathcal{L}_{\text{update}} = \log P_{\theta}(y_{\text{new}}|x) $$

即標準微調(fine-tuning)。

3️⃣ Retention(KL 約束)

為避免 catastrophic forgetting(災難性遺忘):

$$ \mathcal{L}_{\text{retain}} = D_{\text{KL}} \left( P_{\theta_{\text{orig}}}(y|x) || P_{\theta}(y|x) \right) $$

強迫:編輯後模型 ≈ 原模型(在保留集上)。

這是一種 distribution matching,通過最小化 KL 散度來確保編輯後的模型在保留集上的輸出分佈與原模型保持一致。

架構流程圖

  flowchart TD
    A[原始模型<br/>θ_orig] --> B[定義三個資料集]
    B --> C[Unlearn Set<br/>D_u]
    B --> D[Update Set<br/>D_up]
    B --> E[Retain Set<br/>D_r]
    C --> F[反向梯度<br/>λ_u · L_unlearn]
    D --> G[正常梯度下降<br/>λ_up · L_update]
    E --> H[KL 約束<br/>λ_r · L_retain]
    F --> I[加權優化<br/>L = λ_uL_u + λ_upL_up + λ_rL_r]
    G --> I
    H --> I
    I --> J[編輯後模型<br/>θ_edited]
    J --> K{評估結果}
    K -->|Unlearn Set<br/>降低準確率| L[✅ 遺忘成功]
    K -->|Update Set<br/>提升準確率| M[✅ 更新成功]
    K -->|Retain Set<br/>維持性能| N[✅ 保留成功]

主要使用案例

LLM Surgery 的應用場景主要包括:

  • 法律合規:移除版權或隱私敏感資料
  • 知識更新:更新過時事實(例如:政治人物職位、公司資訊)
  • 安全防禦:移除惡意 prompt 注入樣本
  • 產品迭代:快速更新產品知識庫
  • 隱私保護:刪除個人資訊(GDPR 合規)

與其他技術的差異

與 LoRA 微調的差異

技術 是否修改原權重 目的
LoRA 添加新能力
Fine-tune 整體調整
LLM Surgery 是(局部) 刪除 + 新增 + 保留

Surgery 更偏向「權重編輯」。

與 ROME / MEMIT 的比較

ROME / MEMIT:

  • 直接改變 transformer 中的 MLP 權重
  • 定位特定 fact 存儲位置
  • 單點知識重寫

LLM Surgery:

  • 不需要精準定位 neuron
  • 更 generalizable
  • 可同時 unlearn + update

與 Machine Unlearning 的比較

Machine Unlearning 多數方法:

  • 近似 retraining
  • 需要 influence function
  • 計算昂貴

LLM Surgery:

  • 直接使用 SGD + KL constraint
  • 工程成本低

實驗設計框架

Dataset 切分

  flowchart LR
    A[Full Dataset] --> B[Unlearn Set<br/>~5%]
    A --> C[Update Set<br/>~5%]
    A --> D[Retain Set<br/>~90%]

常見比例設計:

  • Unlearn: 5%
  • Update: 5%
  • Retain: 90%

Evaluation Metrics

1️⃣ Forgetting Score(遺忘分數)

$$ \Delta Acc_{\text{unlearn}} $$

理想情況:準確率接近隨機(表示完全遺忘)。

2️⃣ Update Accuracy(更新準確率)

$$ Acc_{\text{update}} $$

應顯著提升(表示成功學習新知識)。

3️⃣ Retention Gap(保留差距)

$$ |Acc_{\text{retain}}^{orig} - Acc_{\text{retain}}^{edited}| $$

應趨近 0(表示原有知識被保留)。

4️⃣ KL Drift(KL 偏移)

測量整體分佈偏移量,評估模型穩定性。

安裝與設定

前置需求

  • Python 3.8+
  • PyTorch 2.0+
  • Transformers 套件
  • Hugging Face Hub 存取權限

安裝步驟

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
# 安裝依賴
pip install torch transformers accelerate

# 載入預訓練模型
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from copy import deepcopy

# 載入 frozen model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 保存原始模型副本(用於 KL 約束)
orig_model = deepcopy(model).eval()

實作核心代碼

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch.nn.functional as F
from torch.optim import AdamW

# 定義三種 loss
def compute_losses(model, orig_model, unlearn_data, update_data, retain_data):
    """計算三種 loss"""
    # 1. Unlearning loss(反向梯度)
    with torch.set_grad_enabled(True):
        outputs_unlearn = model(**unlearn_data)
        logits_unlearn = outputs_unlearn.logits
        loss_unlearn = -F.cross_entropy(
            logits_unlearn.view(-1, logits_unlearn.size(-1)),
            unlearn_data["labels"].view(-1)
        )

    # 2. Update loss(正常梯度下降)
    with torch.set_grad_enabled(True):
        outputs_update = model(**update_data)
        logits_update = outputs_update.logits
        loss_update = F.cross_entropy(
            logits_update.view(-1, logits_update.size(-1)),
            update_data["labels"].view(-1)
        )

    # 3. Retain loss(KL 約束)
    with torch.no_grad():
        orig_outputs = orig_model(**retain_data)
        orig_logits = orig_outputs.logits

    with torch.set_grad_enabled(True):
        edited_outputs = model(**retain_data)
        edited_logits = edited_outputs.logits

    # KL divergence
    loss_retain = F.kl_div(
        F.log_softmax(edited_logits.view(-1, edited_logits.size(-1)), dim=-1),
        F.softmax(orig_logits.view(-1, orig_logits.size(-1)), dim=-1),
        reduction="batchmean"
    )

    return loss_unlearn, loss_update, loss_retain

# 設定 optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)

# 加權優化
def train_step(model, orig_model, unlearn_data, update_data, retain_data,
               lambda_u=1.0, lambda_up=1.0, lambda_r=0.5):
    loss_unlearn, loss_update, loss_retain = compute_losses(
        model, orig_model, unlearn_data, update_data, retain_data
    )

    # 加權總 loss
    total_loss = (lambda_u * loss_unlearn +
                  lambda_up * loss_update +
                  lambda_r * loss_retain)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), {
        "loss_unlearn": loss_unlearn.item(),
        "loss_update": loss_update.item(),
        "loss_retain": loss_retain.item()
    }

# 訓練循環
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        unlearn_data = batch["unlearn"]
        update_data = batch["update"]
        retain_data = batch["retain"]

        total_loss, losses = train_step(
            model, orig_model, unlearn_data, update_data, retain_data,
            lambda_u=1.0, lambda_up=1.0, lambda_r=0.5
        )

        print(f"Epoch {epoch}, Loss: {total_loss:.4f}, "
              f"Unlearn: {losses['loss_unlearn']:.4f}, "
              f"Update: {losses['loss_update']:.4f}, "
              f"Retain: {losses['loss_retain']:.4f}")

最佳實踐

超參數設定

  • lambda_u(unlearn 權重):1.0-2.0,控制遺忘強度
  • lambda_up(update 權重):1.0,控制更新強度
  • lambda_r(retain 權重):0.5-1.0,防止 catastrophic forgetting
  • learning rate:1e-5 - 5e-5,較小的學習率避免過大改動
  • epochs:5-20,避免過度訓練

實作建議

  1. 更新步數少:避免過度優化,本質上是 small perturbation in weight space
  2. 使用小 learning rate:確保局部修改不會影響整體模型
  3. Retain Set 選擇:使用多樣化的保留集,確保模型能力不被退化
  4. 評估頻率:定期在三個資料集上評估,及時調整超參數

進階功能

多次疊加 Surgery

可以多次執行 Surgery 操作,但需要注意:

  • 累積誤差可能導致 drift
  • 建議保留中間 checkpoint
  • 定期評估整體模型性能

與 LoRA 結合

可以將 Surgery 應用於 LoRA 適配器:

  • 在 LoRA 層執行 unlearn/update
  • 保持基礎模型不變
  • 更容易實現可逆操作

為什麼不會崩壞模型?

LLM Surgery 不會導致模型崩壞的關鍵原因:

  1. 更新步數少:只進行局部微調
  2. KL 約束:強迫模型在保留集上保持原分佈
  3. 小 learning rate:避免過大的權重改動
  4. 局部資料:只在特定資料集上優化

本質上是 small perturbation in weight space,在 loss landscape 中做 constrained optimization:

  • Move away from subspace A(unlearn set)
  • Move toward subspace B(update set)
  • Stay near original basin(retain set)

演進脈絡

2022–2023:Knowledge Editing

  • ROME(Rank-One Model Editing)
  • MEMIT(Mass Editing Memory in a Transformer)
  • SERAC
  • MEND

目標:改一個 fact。

2023–2024:Machine Unlearning for LLM

  • Selective forgetting
  • Privacy removal
  • Influence function-based removal

重點:合法刪除資料。

2024:LLM Surgery

重點

  • 三目標函數(unlearn + update + retain)
  • 同時進行知識遺忘和更新
  • 低成本實現

2025 之後方向

研究重點轉向:

  • Selective token-level unlearning
  • Circuit localization
  • Causal tracing
  • Weight subspace editing
  • Adapter-based unlearning
  • Continual knowledge editing

成本對比分析

商業價值

方法 成本 風險 可持續更新
重訓 極高
持續 pretrain
LoRA 不刪除
Surgery

LLM Surgery 的商業價值:

  • 法律風險快速修補:快速移除敏感資料
  • 模型版本持續更新:無需停機大規模重訓
  • 成本效益:比重訓低 2-3 個數量級

風險與限制

  1. 無法保證完全刪除:可能 latent 保留部分知識
  2. 分佈外推可能復活舊知識:在相似情境下可能回憶起已刪除的知識
  3. 多次手術可能產生 drift:累積誤差導致模型性能退化
  4. 未解決 representation-level 交纏:知識在模型內部交纏,難以精確刪除

進階分析

權重空間觀點

Surgery 本質:在 loss landscape 中做 constrained optimization。

目標

  • Move away from subspace A(unlearn)
  • Move toward subspace B(update)
  • Stay near original basin(retain)

這類似:Projection in parameter space under KL regularization

Fisher Information 關聯

Retention 約束可視為:保持高 Fisher direction 不變。

Fisher Information Matrix 衡量參數對模型輸出的敏感度,KL 約束實際上是在保持對重要方向的敏感度不變。

與 EWC 關聯

Elastic Weight Consolidation (EWC):

$$ \sum_i F_i (\theta_i - \theta_i^*)^2 $$

EWC 通過約束重要參數的變化來防止 catastrophic forgetting,與 KL 約束高度相關。LLM Surgery 的 KL 約束實際上是 EWC 在分佈層面的擴展。

進階分析圖示

  flowchart LR
    A[權重空間<br/>Parameter Space] --> B{LLM Surgery}
    B --> C[Unlearn<br/>遠離 Subspace A]
    B --> D[Update<br/>靠近 Subspace B]
    B --> E[Retain<br/>KL 約束<br/>保持 Original Basin]
    C --> F[Loss Landscape]
    D --> F
    E --> F
    F --> G[Constrained Optimization<br/>受限優化]
    G --> H[Small Perturbation<br/>微小擾動]
    H --> I[模型穩定性<br/>Model Stability]

常見問題解答

Q: LLM Surgery 與傳統 fine-tuning 有什麼區別?

A: 傳統 fine-tuning 只做單向優化(提升某個任務的性能),而 LLM Surgery 同時進行三個方向的優化:遺忘(反向梯度)、更新(正向梯度)、保留(KL 約束)。這使得它能夠同時刪除舊知識、添加新知識,並保持模型原有能力。

Q: 如何評估 Surgery 是否成功?

A: 通過三個指標評估:

  1. Forgetting Score:unlearn set 上準確率應接近隨機
  2. Update Accuracy:update set 上準確率應顯著提升
  3. Retention Gap:retain set 上準確率應與原模型接近

Q: LLM Surgery 適用於哪些場景?

A: 主要適用於需要快速修改模型知識的場景,如:法律合規(移除敏感資料)、知識更新(更正過時資訊)、安全防禦(移除惡意樣本)、產品迭代(更新產品知識)。

Q: Surgery 是否可以多次執行?

A: 理論上可以,但多次疊加可能產生累積誤差(drift)。建議保留中間 checkpoint,並定期評估整體模型性能。

Q: 與 ROME/MEMIT 相比,LLM Surgery 的優勢是什麼?

A: LLM Surgery 不需要精準定位特定 neuron,更 generalizable,可以同時處理多個知識點,且實現更簡單(直接使用 SGD + KL constraint)。ROME/MEMIT 適合精確修改單一 fact,而 Surgery 適合大規模知識編輯。

參考資料

可延伸討論主題

  • 與 LoRA 結合:如何在 LoRA 適配器上執行 Surgery?
  • 線上更新(online surgery):是否可以在運行時動態修改模型?
  • 多次疊加:如何量化多次手術的累積影響?
  • 完全遺忘的量化:如何定義和測量「完全忘記」?
  • 安全防禦應用:如何使用 Surgery 進行 anti-prompt injection?

總結

LLM Surgery 提供了一種高效、實用的框架來修改大型語言模型的知識,解決了傳統 LLM 訓練後面臨的不可逆困境。通過同時優化 unlearn、update、retain 三個目標函數,LLM Surgery 能夠在不進行大規模重訓練的情況下,實現知識的遺忘、更新和保留。這對於需要快速響應法律合規、知識更新和安全防禦的場景具有重要價值,商業應用前景廣闊。