LLM 사전학습에서 모든 샘플을 동등하게 취급하는 관행은 낭비다 — 낮은 손실 샘플을 자동으로 줄여가며 배치마다 중요도를 재조정하면, 추가 비용 없이 더 빠른 수렴과 더 높은 성능을 동시에 얻을 수 있다.
arXiv 2502.06733 →현재 LLM 사전학습 패러다임은 크게 두 단계로 이루어집니다. 데이터 큐레이션(휴리스틱 필터, 품질 분류기, 수동 검토)으로 코퍼스를 구성하고, 이어서 균일 샘플링(uniform sampling)으로 학습합니다. 이 중 데이터 큐레이션이 핵심 작업처럼 여겨지지만, 여기에는 두 가지 근본적 한계가 있습니다.
이를 보완하려는 그룹 수준 재가중(group-level reweighting) 연구들 — DoReMi(Xie et al., 2023), DoGE(Fan et al., 2023) — 이 등장했습니다. 이 방법들은 도메인 단위로 샘플링 비율을 조정합니다. 그러나 이 역시 한계가 있습니다.
더 나아가, 컴퓨터 비전에서 효과적이었던 인스턴스 수준 재가중 기법들(Loshchilov & Hutter 2015; Jiang et al. 2019)도 LLM 사전학습에는 적합하지 않습니다. 그 이유가 이 논문이 해결하려는 핵심 기술 과제입니다.
"How can we dynamically leverage instance-level information to accelerate training and improve model performance without incurring significant computational overhead while potentially reducing the need for extensive data curation?"어떻게 하면 막대한 계산 오버헤드 없이, 데이터 큐레이션 필요성도 줄이면서, 인스턴스 수준 정보를 동적으로 활용하여 학습을 가속하고 모델 성능을 개선할 수 있을까?
이 논문의 핵심 통찰은 간단합니다: 현재 배치 내 각 샘플의 손실값(loss)은 그 샘플의 현재 시점 중요도를 알려주는 즉각적인 신호다. 이미 잘 아는 것(낮은 손실)보다 아직 모르는 것(높은 손실)에 더 집중하면 학습이 빨라진다 — 이미 안다면 그 반복에서 배울 게 없기 때문입니다.
직관: 모델이 이미 잘 예측하는 샘플(낮은 손실)은 두 가지 경우입니다. ① 해당 패턴을 충분히 배웠거나, ② 그 샘플이 너무 단순/반복적이어서 처음부터 배울 것이 없는 경우. 어느 쪽이든 이 스텝에서 해당 샘플의 그래디언트 기여는 낮습니다. 반면 높은 손실 샘플은 아직 모델이 이해하지 못한 패턴을 포함하고 있습니다.
보충 이는 컴퓨터 비전에서의 "Hard Example Mining"과 유사한 직관이지만, LLM에서는 한 번만 보는 데이터 특성상 과거 통계 없이 현재 배치만으로 즉시 결정해야 한다는 결정적 차이가 있습니다.
손실이 극단적으로 높은 샘플은 노이즈나 이상치(outlier)일 가능성이 높습니다. 이런 샘플에 너무 큰 가중치를 주면 훈련이 불안정해집니다. 논문은 이를 이론적으로도 뒷받침합니다: 가중치에 상한선 w_i ≤ 2/M을 두면 수렴 보장이 가능합니다.
정규화된 손실값 h ([-α, α] 범위)에 대해 각 전략이 어떤 모양의 가중치를 부여하는지 슬라이더로 확인하세요.
x축: 정규화된 손실값 h (낮을수록 쉬운 샘플). y축: 소프트맥스 후 최종 중요도 가중치. LinUpper는 낮은 손실 샘플의 가중치를 일관되게 낮추는 반면, Extremes는 양쪽 극단을 모두 강조합니다.
학습 루프의 각 스텝 t에서 일어나는 일을 블록으로 표현했습니다. 블록을 클릭하면 세부 설명을 볼 수 있습니다.
| 변수 | 의미 | 비고 |
|---|---|---|
θ_t | 시간 t에서의 모델 파라미터 | d차원 벡터 |
η | 학습률(stepsize) | 스케줄러로 조정 |
B | 미니배치 | |B| = b개 샘플 |
f(x_i; θ_t) | 샘플 x_i의 손실값 | NLL (음의 로그 우도) |
직관적 해설
이 수식의 한계
| 변수 | 의미 | 비고 |
|---|---|---|
w(x_i; θ_t) | 샘플 x_i의 동적 중요도 가중치 | ∑ᵢ w_i = 1 |
η_t | 시간 t에서의 학습률 | 코사인 스케줄 등 적용 |
직관적 해설
수학적 유도
| 변수 | 의미 | 비고 |
|---|---|---|
h_i,t | 정규화된 손실값 | 일반적으로 α=1로 설정 |
f_min, f_max | 현재 배치 내 최소/최대 손실 | 배치마다 재계산 |
직관적 해설
멀티-GPU 고려사항
| 전략 | 특성 | 언제 유리 |
|---|---|---|
LinUpper | 낮은 손실 → 낮은 s. 높은 손실 → s = α (상한 capping) | 일반적 최선 (특히 노이즈가 많은 도메인) |
Quadratic | 중간 손실 강조, 양 극단 모두 하향 | 매우 노이즈 많은 데이터, 이상치 방지 필요 시 |
Extremes | 낮은/높은 손실 모두 강조, 중간 손실 하향 | 이론적으로 일관성 있으나 실험에서 성능 저조 |
LinUpper의 직관
Quadratic의 직관
Extremes의 직관
| 변수 | 의미 | 비고 |
|---|---|---|
r | 온도 파라미터 (temperature) | 학습 중 어닐링(큰값→작은값) |
s_i | 전략 함수 출력값 | 식 (3)에서 계산 |
w_i | 최종 중요도 가중치 | 소프트맥스 정규화, ∑ w_i = 1 |
직관적 해설
왜 어닐링인가?
온도 r에 따른 가중치 분포 변화 (5개 샘플, s값 고정):
s값 = [-0.8, -0.3, 0.1, 0.5, 1.0] (LinUpper 기준). r↑ → 균일에 가까워짐, r↓ → 높은 s 샘플이 거의 전부 차지
이 논문의 중요한 기여 중 하나는 손실 기반 재가중이 수렴 속도에 미치는 영향을 최초로 이론적으로 분석했다는 점입니다. 핵심은 δ_t라는 수렴 편차 항입니다.
| 변수 | 의미 | 비고 |
|---|---|---|
δ_t | 수렴 편차 항 (수렴 속도의 핵심) | 음수이면 표준보다 빠른 수렴 |
θ* | 전역 최적해 | 보간 조건: ∀i에서 f(x_i; θ*)가 최소 |
T̄ | T번 반복의 평균 파라미터 | θ̄ᵀ = (1/T)Σθ_t |
M | 전체 데이터 수 | 균일 가중치 = 1/M |
δ_t 해석 — 핵심 통찰
상한 조건: w_i ≤ 2/M
| 변수 | 의미 | 비고 |
|---|---|---|
L | 리프시츠(Lipschitz) 스무스니스 상수 | Assumption 2 |
δ_t | 수렴 편차 (Theorem 1과 유사) | 현재 파라미터 손실 차이 기반 |
μ_t | 모멘텀 항 수렴 편차 | 이전/현재 파라미터 손실 차이 기반 |
λ_t | 모멘텀 가중치 | λ_t > 0 |
직관적 해설
비볼록 확장 및 수학적 유도 개요 [Appendix C.3]
| 변수 | 의미 | 비고 |
|---|---|---|
C | 정규화 상수 | ∑w_i = 1 보장 |
이 수식의 의미
실천적 의미
아래 스텝 플레이어에서 각 단계를 눌러 알고리즘의 흐름을 따라가보세요.
Input: η, θ₀, {r_t}ₜ, f_min, f_max
for t = 0, 1, ..., T-1:
B = {x_i}_{i=1}^b # (1) 미니배치 샘플링
{f_{i,t}} = forward(B, θ_t) # (2) 손실 계산
h_{i,t} = normalize(f_{i,t}) # (3) 정규화 [−α,α]
s_{i,t} = strategy(h_{i,t}) # (4) 전략 적용
w_{i,t} = softmax(s/r_t) # (5) 커리큘럼 가중치
θ_{t+1} = θ_t - η·Σ w·∇f # 가중 업데이트
# 1. 전략 함수 def apply_strategy(losses, delta=1.0, strategy="linupper"): if strategy == "linupper": return torch.minimum(losses + delta, delta * torch.ones_like(losses)) elif strategy == "quadratic": return 1 - losses**2 / delta**2 elif strategy == "extremes": return torch.abs(losses) # 2. 온도 스케일링 (소프트맥스용) def scale_losses(losses, r): return torch.exp(losses / r) # 3. [-δ, δ] 정규화 def normalize_losses(losses, delta=1., l_min=0., l_max=1.): return 2. * delta * losses / max(l_max - l_min, 1e-6) \ - delta * (l_max + l_min) / max(l_max - l_min, 1e-6) # 4. per-sample NLL 손실 계산 (패딩 토큰 제외) def get_batch_loss_from_logits(logits, labels): shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() num_active = (shift_labels != -100).sum(dim=1) loss_fct = torch.nn.CrossEntropyLoss(reduction='none') loss = loss_fct(shift_logits.view(-1, logits.size(-1)), shift_labels.view(-1).long()) return loss.view(logits.size(0), -1).sum(dim=1) / num_active, num_active # 5. FSDP 멀티-GPU 학습 루프 통합 # (다른 GPU 샘플 손실을 all_gather로 수집) gathered_losses = torch.zeros(dist.get_world_size(), len(device_losses), ...) dist.all_gather_into_tensor(gathered_losses, device_losses.detach()) r = r_scheduler(step, cfg.initial_r) with torch.no_grad(): normalized = normalize_losses(gathered_losses.view(-1), ...) reweighted = apply_strategy(normalized, strategy=STRATEGY) scaled = scale_losses(reweighted - reweighted.max(), l=r) weights = scaled / scaled.sum() device_weights = weights.view(dist.get_world_size(), -1)[local_rank, :] loss = torch.sum(device_weights * device_losses) * dist.get_world_size() loss.backward()
보충 * dist.get_world_size()로 스케일링하는 이유: 글로벌 배치 전체의 가중치가 1이 되도록 정규화했기 때문에, 각 GPU 손실의 합이 1/world_size가 됩니다. world_size를 곱해서 전통적인 학습률 스케일을 복원합니다.
| 모델 | 파라미터 | 배치 크기 | 학습률 | LR 스케줄 | 워밍업 | 총 스텝 | r (초기→최종) |
|---|---|---|---|---|---|---|---|
| GPT2-mini | 124M | 32 | 5×10⁻⁴ | Linear Warmup Cosine | 500 | 20,000 | 100→0.4 |
| GPT2-small | 210M | 48 | 5×10⁻⁴ | Linear Warmup Cosine | 500 | 20,000 | 100→0.4 |
| GPT2-medium | 300M | 48 | 5×10⁻⁴ | Linear Warmup Cosine | 500 | 20,000 | 100→0.4 |
| Llama-1.4B | 1.4B | 128 | 3×10⁻⁴ | Linear Warmup Cosine | 2,000 | 100,000 | 100→1.0 |
| Llama-7B | 7B | 128 | 3×10⁻⁴ | Linear Warmup Cosine | 2,000 | 170,000 | 100→2.0 |
공통 설정: Weight decay 0.01, Max Grad Norm 1.0, AdamW 옵티마이저(β1/β2 표준값). GPT2: HuggingFace Trainer + Accelerate. Llama: PyTorch FSDP. 코드: github.com/sowmaster/Sample-Level-Loss-Reweighting-ICLR-2025
| 아키텍처 | 이름 | 레이어 | 어텐션 헤드 | 임베딩 차원 | FFN 차원 | 최대 시퀀스 길이 |
|---|---|---|---|---|---|---|
| GPT-2 | Mini (124M) | 12 | 12 | 768 | 3072 | 512 |
| Small (210M) | 24 | 16 | 786 | 3072 | 512 | |
| Medium (300M) | 36 | 24 | 786 | 3072 | 512 | |
| Llama | 1.4B | 24 | 16 | 2048 | 2048 | 8192 |
| 7B | 32 | 32 | 4096 | 4096 | 8192 |
SlimPajama 7개 도메인 각각에서의 hold-out 검증 퍼플렉시티(낮을수록 좋음). GPT2-medium(300M) 기준.
Book, C4, CC 도메인에서 특히 두드러진 개선. 이 도메인들은 웹에서 수집된 노이즈가 많은 데이터로, LinUpper의 "낮은 손실 하향 조정"이 반복적/중복 콘텐츠를 효과적으로 줄임을 시사합니다.
| 벤치마크 | Baseline Uniform | Ours LinUpper | Quadratic | Extremes |
|---|---|---|---|---|
| LogiQA (논리 추론) | 25.7 | 27.9 | 25.5 | 26.5 |
| LogiQA 2 | 27.5 | 27.6 | 28.6 | 27.7 |
| SciQ (과학 추론) | 49.0 | 51.8 | 48.7 | 49.6 |
| PiQA (물리 상식) | 55.5 | 56.2 | 54.6 | 55.2 |
| 평균 | 39.4 | 40.9 | 39.3 | 39.8 |
LinUpper가 4개 중 3개 태스크에서 최고 성능. Extremes는 모든 태스크에서 다른 방법보다 낮음 → 낮은 손실 샘플을 줄이는 것이 핵심임을 반증.
| 벤치마크 | DoGE | Ours DoGE + LinUpper | DoReMi | Ours DoReMi + LinUpper |
|---|---|---|---|---|
| LogiQA | 27.2 | 28.6 | 27.2 | 27.6 |
| LogiQA 2 | 27.5 | 28.0 | 27.7 | 27.9 |
| SciQ | 52.8 | 53.2 | 53.3 | 54.5 |
| PiQA | 55.8 | 56.3 | 55.7 | 56.1 |
| 평균 | 40.8 | 41.5 | 41.0 | 41.5 |
FineWeb 15T 데이터셋에서 무작위 샘플링한 서브셋으로 학습. 1.4B: ~100B 토큰, 7B: ~175B 토큰.
Llama-1.4B 평균 성능
Llama-7B 평균 성능
| 모델 | 태스크 유형 | Baseline Uniform | Ours LinUpper | 개선 |
|---|---|---|---|---|
| Llama-1.4B | LUR (13개) | 48.51 | 49.16 | +0.66% |
| QA (6개) | 42.35 | 44.07 | +1.72% | |
| 전체 (19개) | 46.56 | 47.55 | +0.99% | |
| Llama-7B | LUR (13개) | 50.54 | 51.98 | +1.44% |
| QA (6개) | 48.15 | 49.31 | +1.16% | |
| 전체 (19개) | 49.79 | 51.14 | +1.35% |
모델이 클수록 개선 효과가 더 두드러집니다 (1.4B: +0.99% → 7B: +1.35%). 작은 모델은 재가중 방법의 세밀한 신호를 충분히 활용할 용량이 부족할 수 있습니다.
r 값이 너무 크면(=1) 균일 가중치와 거의 같아지고, 너무 작으면(=0.2) 낮은 손실 샘플을 과도하게 걸러내어 데이터 낭비가 발생합니다. 최적 r은 모델 규모에 따라 다릅니다(7B에서 r=2.0).
선형 회귀에 25% 이상치 데이터를 섞은 합성 데이터에서도 LinUpper가 가장 빠르게 수렴합니다. 이는 손실 기반 재가중의 이점이 LLM 특화 현상이 아니라 ML 전반에 걸친 근본적 원리임을 보여줍니다.
| 방법 | 이상치 처리 | 수렴 안정성 | 성능 |
|---|---|---|---|
| Uniform (기준) | 없음 | 안정 | 기준 |
| Ours LinUpper | 가중치 상한으로 제어 | 안정 | 최고 |
| DRO-KL (같은 LR) | 최악 케이스 집중 | 발산(diverge) | - |
| DRO-KL (낮은 LR) | 최악 케이스 집중 | 수렴 | LinUpper보다 낮음 |
얻은 것
포기한 것
이 논문은 LLM 사전학습 실무자에게 즉각적인 가치를 제공합니다. 기존 PyTorch/JAX 학습 루프에 몇 줄의 코드만 추가하면 됩니다. 특히 Common Crawl, C4처럼 노이즈가 많은 대규모 웹 코퍼스를 사용하는 팀에 효과적입니다.
보충 이 논문의 기여는 단순히 "새로운 기법"을 넘어, 손실 기반 재가중이 수렴에 미치는 영향의 첫 공식적 이론 분석을 제공한다는 점에서 학문적으로도 중요합니다. 이전 작업(Loshchilov & Hutter 2015; Jiang et al. 2019)은 경험적 관찰에 그쳤습니다.
논문 결론부에서 저자들은 다음 방향을 제안합니다: 이 재가중 메커니즘이 대규모 노이즈 데이터셋에서 광범위한 데이터 큐레이션 필요성을 줄이는 자동화된 온라인 데이터 선택의 토대가 될 수 있다. 또한 적대적 학습(adversarial ML), 도메인 적응, 데이터 증강, 불균형 분류 등 다른 ML 분야에서의 적용도 미래 연구로 남겨두었습니다.
전략 함수 s_i = min{h_i + α, α}는 정규화된 손실 h_i에 대해 선형으로 증가하다가 α에서 상한이 걸립니다. h_i가 α 이상이면 모두 동일하게 α로 캡핑됩니다. "Linear"는 하한 구간에서의 선형성, "Upper-bound"는 α에서의 상한 제약을 의미합니다. 이 상한이 이상치(극단적 고손실 샘플)가 학습을 지배하는 것을 막습니다.
"The functional form is s_i := min{h_i + α, α}."함수 형태는 s_i = min{h_i + α, α}로 정의됩니다.
이것이 이 논문의 가장 직관에 반하는 부분입니다. 핵심은 "버린다"가 아니라 "이 특정 학습 스텝에서의 기여도를 줄인다"는 것입니다. 낮은 손실 샘플은 두 가지를 의미합니다:
① 모델이 이미 배운 것 — 이 스텝에서 이 샘플의 그래디언트 기여는 학습에 이미 잘 반영된 방향입니다. 이를 줄여도 손실이 없습니다.
② 처음부터 단순/반복적인 것 — 웹 크롤 데이터의 많은 부분이 중복이거나 템플릿화된 텍스트입니다. LinUpper는 데이터 큐레이션 없이 이를 자동으로 경시합니다.
평가 다만, "낮은 손실 = 이미 배운 것"이라는 전제는 완전히 정확하지 않을 수 있습니다. 낮은 손실이 쉬운 데이터 때문인지 모델이 이미 학습했기 때문인지 구분하기 어렵고, 이 둘의 처리 방식이 다를 수 있습니다.
이 논문은 사전학습(pretraining)에 집중합니다. 그러나 관련 연구인 Chen et al. (2024)가 연속 학습(continual training)/인스트럭션 튜닝에서 중간 손실 샘플 재가중을 사용했음을 언급합니다.
평가 Fine-tuning에서는 데이터를 여러 에포크 반복하는 경우가 많으므로, 샘플이 한 번만 등장한다는 핵심 제약이 사라집니다. 이 경우 EWC(Elastic Weight Consolidation) 같은 기법이 더 적합할 수 있습니다. 하지만 온도 어닐링과 LinUpper의 핵심 아이디어 자체는 fine-tuning에도 적용 가능해 보입니다.
Table 1을 보면 GitHub 도메인에서 GPT2-mini/small/medium 모두 LinUpper의 퍼플렉시티가 Uniform보다 약간 높습니다. 논문은 이를 명시적으로 설명하지 않습니다.
평가 GitHub 코드는 다른 도메인에 비해 구조적 패턴이 매우 규칙적입니다(들여쓰기, 함수 선언 등). 이런 코드 데이터는 초기 학습 단계에서 손실이 낮은 패턴을 일찍 학습하므로, LinUpper가 이를 지나치게 하향 조정하면 후기 세밀한 학습에 불리할 수 있습니다. "낮은 손실 = 반복적"이라는 가정이 코드 도메인에서는 덜 성립할 수 있습니다.
2/M 상한은 이론 분석(Theorem 1, 2)에서 수렴 보장이 성립하기 위한 조건입니다. 직관적으로는 "어떤 단일 샘플도 균일 가중치의 2배 이상을 받을 수 없다"는 것으로, 극단적 이상치 집중을 막습니다.
"maxi w(xi;θt) ≤ 2/M essentially means that the weight of any sample (after reweighting) should not be more than twice the uniform weight."이 상한은 재가중 후 어떤 샘플의 가중치도 균일 가중치의 2배를 초과하지 않아야 한다는 의미입니다.
Appendix A.5의 실험에서 7B 모델 학습 시 LinUpper의 실제 최대 가중치는 항상 ~0.01 이하로, 이론 상한 0.0156(=2/128)을 만족합니다.
평가 더 큰 상한을 허용하면 이론 보장이 사라지고, 극단적 이상치 집중 리스크가 높아집니다. DRO가 이 상한을 지키지 않아 성능이 나쁜 사례(Appendix B.3)가 이를 잘 보여줍니다.
GPT2 실험 전체 코드: github.com/sowmaster/Sample-Level-Loss-Reweighting-ICLR-2025
Appendix A.1에 PyTorch FSDP 멀티-GPU 학습 루프 통합 코드가 상세히 제공됩니다. 핵심 유틸리티 함수 4개(apply_strategy, scale_losses, normalize_losses, get_batch_loss_from_logits)를 기존 학습 루프에 추가하는 방식입니다.
이 섹션은 논문 원문의 Figure와 Table을 그대로 보존합니다. 독자가 원본 논문의 어느 위치를 찾아야 하는지 확인할 수 있습니다.