Daouda A. Sow · Herbert Woisetschläger · Saikiran Bulusu · Shiqiang Wang · Hans-Arno Jacobsen · Yingbin Liang  ·  ICLR 2025

Dynamic Loss-Based Sample Reweighting for Improved Large Language Model Pretraining

LLM 사전학습에서 모든 샘플을 동등하게 취급하는 관행은 낭비다 — 낮은 손실 샘플을 자동으로 줄여가며 배치마다 중요도를 재조정하면, 추가 비용 없이 더 빠른 수렴과 더 높은 성능을 동시에 얻을 수 있다.

arXiv 2502.06733 →
LLM 사전학습 인스턴스 수준 재가중 커리큘럼 학습 손실 기반 중요도 샘플링 수렴 분석 데이터 효율적 학습

문제의 배경 — 기존 접근법의 한계

현재 LLM 사전학습 패러다임은 크게 두 단계로 이루어집니다. 데이터 큐레이션(휴리스틱 필터, 품질 분류기, 수동 검토)으로 코퍼스를 구성하고, 이어서 균일 샘플링(uniform sampling)으로 학습합니다. 이 중 데이터 큐레이션이 핵심 작업처럼 여겨지지만, 여기에는 두 가지 근본적 한계가 있습니다.

한계 1 정적 선택(Static Selection) — 학습 전 한 번 고정된 데이터셋은 학습이 진행되면서 바뀌는 각 샘플의 중요도를 반영할 수 없습니다. 초기에 유용했던 반복적 데이터도 모델이 패턴을 익힌 뒤에는 불필요해집니다.
한계 2 스케일 불가능성 — 수천억 토큰 코퍼스에서는 인간이 소량조차 직접 검토하기 어렵고, 계속 성장하는 데이터에 수동 큐레이션을 반복적으로 적용하는 것은 비현실적입니다.

이를 보완하려는 그룹 수준 재가중(group-level reweighting) 연구들 — DoReMi(Xie et al., 2023), DoGE(Fan et al., 2023) — 이 등장했습니다. 이 방법들은 도메인 단위로 샘플링 비율을 조정합니다. 그러나 이 역시 한계가 있습니다.

한계 3 거친 단위(Coarse Granularity) — 같은 "Wikipedia" 도메인 안에서도 모델에게 이미 친숙한 문서와 여전히 배울 것이 많은 문서는 전혀 다른 중요도를 가집니다. 도메인 단위 재가중은 이 차이를 포착하지 못합니다.
한계 4 동적 적응 불가 — 학습이 진행됨에 따라 개별 샘플의 중요도는 달라지지만, 기존 방법들은 이를 학습 중에 반영하지 못합니다.

더 나아가, 컴퓨터 비전에서 효과적이었던 인스턴스 수준 재가중 기법들(Loshchilov & Hutter 2015; Jiang et al. 2019)도 LLM 사전학습에는 적합하지 않습니다. 그 이유가 이 논문이 해결하려는 핵심 기술 과제입니다.

LLM 고유 도전 LLM 사전학습에서는 각 샘플이 학습 중 단 한 번만 등장합니다(컴퓨터 비전은 수백 에포크 반복). 따라서 이전 손실 값이나 그래디언트 노름 같은 과거 통계에 의존하는 방법은 사용 불가합니다. 또한 웹 규모 데이터셋에서 샘플별 통계를 저장하는 것도 메모리상 불가능합니다.
"How can we dynamically leverage instance-level information to accelerate training and improve model performance without incurring significant computational overhead while potentially reducing the need for extensive data curation?"

어떻게 하면 막대한 계산 오버헤드 없이, 데이터 큐레이션 필요성도 줄이면서, 인스턴스 수준 정보를 동적으로 활용하여 학습을 가속하고 모델 성능을 개선할 수 있을까?

이 논문의 선택 — 핵심 아이디어와 트레이드오프

이 논문의 핵심 통찰은 간단합니다: 현재 배치 내 각 샘플의 손실값(loss)은 그 샘플의 현재 시점 중요도를 알려주는 즉각적인 신호다. 이미 잘 아는 것(낮은 손실)보다 아직 모르는 것(높은 손실)에 더 집중하면 학습이 빨라진다 — 이미 안다면 그 반복에서 배울 게 없기 때문입니다.

핵심 아이디어 각 학습 스텝에서 현재 배치의 손실값만으로 샘플 중요도 가중치를 즉시 계산한다. 과거 통계 저장 불필요, 추가 모델 학습 불필요, 추가 forward/backward pass 불필요 — 거의 제로에 가까운 추가 연산 비용.

왜 "낮은 손실 샘플 줄이기"인가?

직관: 모델이 이미 잘 예측하는 샘플(낮은 손실)은 두 가지 경우입니다. ① 해당 패턴을 충분히 배웠거나, ② 그 샘플이 너무 단순/반복적이어서 처음부터 배울 것이 없는 경우. 어느 쪽이든 이 스텝에서 해당 샘플의 그래디언트 기여는 낮습니다. 반면 높은 손실 샘플은 아직 모델이 이해하지 못한 패턴을 포함하고 있습니다.

보충 이는 컴퓨터 비전에서의 "Hard Example Mining"과 유사한 직관이지만, LLM에서는 한 번만 보는 데이터 특성상 과거 통계 없이 현재 배치만으로 즉시 결정해야 한다는 결정적 차이가 있습니다.

왜 "너무 높은 손실 샘플"도 과도하게 올리지 않는가?

손실이 극단적으로 높은 샘플은 노이즈나 이상치(outlier)일 가능성이 높습니다. 이런 샘플에 너무 큰 가중치를 주면 훈련이 불안정해집니다. 논문은 이를 이론적으로도 뒷받침합니다: 가중치에 상한선 w_i ≤ 2/M을 두면 수렴 보장이 가능합니다.

트레이드오프: 얻는 것 수렴 속도 향상, 특히 노이즈가 많은 도메인(Common Crawl, C4, Book)에서 퍼플렉시티(perplexity) 개선. 기존 도메인 수준 재가중(DoGE, DoReMi)과 결합 시 추가 성능 향상.
트레이드오프: 포기하는 것 그래디언트가 편향(biased)됩니다. 균일 가중치에서는 미니배치 그래디언트가 불편(unbiased) 추정량이지만, 손실 기반 가중치는 가중치 자체가 손실의 함수이므로 편향이 생깁니다. 이론 분석이 더 복잡해지는 이유입니다.

세 가지 전략의 형태

정규화된 손실값 h ([-α, α] 범위)에 대해 각 전략이 어떤 모양의 가중치를 부여하는지 슬라이더로 확인하세요.

x축: 정규화된 손실값 h (낮을수록 쉬운 샘플). y축: 소프트맥스 후 최종 중요도 가중치. LinUpper는 낮은 손실 샘플의 가중치를 일관되게 낮추는 반면, Extremes는 양쪽 극단을 모두 강조합니다.

방법론

전체 프레임워크 — 완전 온라인 인스턴스 재가중

학습 루프의 각 스텝 t에서 일어나는 일을 블록으로 표현했습니다. 블록을 클릭하면 세부 설명을 볼 수 있습니다.

미니배치 B {x_i}ᵢ, |B|=b Forward Pass {f_i,t} 계산 손실 정규화 h_i ∈ [-α, α] 전략 적용 LinUpper/ Quadratic/Extremes 가중 업데이트 θ_t+1 ← θ_t - η∇̃ 온도 r 어닐링 r: 큰값→작은값 (학습 초기→후기)

핵심 수식

1. 표준 SGD 업데이트 (기준선)

\[ \theta_{t+1} = \theta_t - \frac{\eta}{|B|} \sum_{i \in B} \nabla f(x_i; \theta_t) \tag{1} \]
변수의미비고
θ_t시간 t에서의 모델 파라미터d차원 벡터
η학습률(stepsize)스케줄러로 조정
B미니배치|B| = b개 샘플
f(x_i; θ_t)샘플 x_i의 손실값NLL (음의 로그 우도)

직관적 해설

배치 내 모든 샘플의 그래디언트를 동등하게 평균 내어 파라미터를 업데이트합니다. 각 샘플 기여도 = 1/|B|. 이것이 기준이 되는 "균일 가중치" 방법입니다.

이 수식의 한계

이미 잘 알고 있는 쉬운 샘플과 아직 배울 것이 많은 어려운 샘플이 동등한 비중으로 파라미터 업데이트에 기여합니다. 웹 규모 이종 데이터셋에서 이는 낭비입니다.

2. 가중 SGD 업데이트 (이 논문의 방법)

\[ \theta_{t+1} = \theta_t - \eta_t \sum_{i \in B} w(x_i; \theta_t) \nabla f(x_i; \theta_t) \tag{2} \]
변수의미비고
w(x_i; θ_t)샘플 x_i의 동적 중요도 가중치∑ᵢ w_i = 1
η_t시간 t에서의 학습률코사인 스케줄 등 적용

직관적 해설

가중치 합이 1로 정규화되어 있어, 학습률 스케줄을 변경할 필요가 없습니다. 단지 각 샘플의 그래디언트 기여도를 손실값에 따라 재분배하는 것입니다. 중요도가 높은 샘플 → 더 많은 그래디언트 기여 → 해당 패턴 방향으로 파라미터가 더 크게 움직입니다.

수학적 유도

식 (1)에서 각 샘플에 1/|B| 대신 w_i를 곱하면 됩니다. ∑ w_i = 1 조건이 학습률의 전체 스케일을 보존합니다. 핵심 질문: w_i를 어떻게 정의하느냐? → 이것이 다음 수식들(전략 함수)의 역할입니다.

3. 손실 정규화

\[ h_{i,t} = \frac{2(f_{i,t} - f_{\min})}{f_{\max} - f_{\min}} - 1 \in [-1, 1] \]
변수의미비고
h_i,t정규화된 손실값일반적으로 α=1로 설정
f_min, f_max현재 배치 내 최소/최대 손실배치마다 재계산

직관적 해설

절대 손실값은 모델 크기, 어휘 크기, 학습 단계에 따라 크게 달라집니다. 이를 [-1,1] 범위로 선형 스케일링하면 전략 함수의 의미가 학습 전반에 걸쳐 일관되게 유지됩니다. h=-1은 배치 내 가장 쉬운 샘플, h=+1은 가장 어려운 샘플입니다.

멀티-GPU 고려사항

분산 학습 시 f_min, f_max는 모든 GPU에서 수집(all_gather)한 후 글로벌 배치 기준으로 계산해야 합니다. 그렇지 않으면 GPU마다 다른 정규화 기준이 적용되어 불일치가 발생합니다. (Appendix A.1 코드 참고)

4. 세 가지 재가중 전략

\[ \text{LinUpper:} \quad s_i = \min\{h_i + \alpha,\ \alpha\} \] \[ \text{Quadratic:} \quad s_i = \alpha\!\left(1 - \frac{h_i^2}{\alpha^2}\right) \] \[ \text{Extremes:} \quad s_i = |h_i| \]
전략특성언제 유리
LinUpper낮은 손실 → 낮은 s. 높은 손실 → s = α (상한 capping)일반적 최선 (특히 노이즈가 많은 도메인)
Quadratic중간 손실 강조, 양 극단 모두 하향매우 노이즈 많은 데이터, 이상치 방지 필요 시
Extremes낮은/높은 손실 모두 강조, 중간 손실 하향이론적으로 일관성 있으나 실험에서 성능 저조

LinUpper의 직관

h+α: h=-1(가장 쉬운)이면 s=0, h=0이면 s=α, h=+1(가장 어려운)이면 s=2α → min{2α, α}=α. 즉, 쉬운 샘플은 가중치 0에 수렴, 어려운 샘플은 최대 α까지. 상한 capping이 극단 이상치를 막습니다.

Quadratic의 직관

h=0 (중간)에서 s=α 최대, h=±α (양 극단)에서 s=0. 포물선 형태. 너무 어려운 샘플(잠재적 이상치)도 낮추는 보수적 전략. 그러나 실험에서는 LinUpper보다 전반적으로 성능이 낮습니다.

Extremes의 직관

|h|: 쉬운 것(h≈-1)과 어려운 것(h≈+1) 모두 높은 s, 중간(h≈0)은 낮은 s. 양쪽 극단에서 배운다는 아이디어. 그러나 실험에서 가장 성능이 낮아, 낮은 손실 샘플을 줄이는 것이 핵심임을 반증합니다.

5. 커리큘럼 기반 온도 조정 (최종 가중치)

\[ w_i = \frac{e^{s_i / r}}{\sum_j e^{s_j / r}} \tag{4} \]
변수의미비고
r온도 파라미터 (temperature)학습 중 어닐링(큰값→작은값)
s_i전략 함수 출력값식 (3)에서 계산
w_i최종 중요도 가중치소프트맥스 정규화, ∑ w_i = 1

직관적 해설

소프트맥스와 동일한 형태입니다. r이 크면(예: r=100) 지수 차이가 줄어들어 모든 w_i가 균일에 가까워집니다. r이 작으면(예: r=0.4) 차이가 극대화되어 높은 s_i 샘플이 거의 모든 가중치를 가져갑니다. 학습 초기에 r=100으로 시작해 후기에 r=0.4~1.0으로 줄이는 어닐링 전략을 사용합니다.

왜 어닐링인가?

학습 초기에는 손실값이 아직 불안정하고 모든 샘플이 미지의 상태입니다. 이때 과도한 재가중은 오히려 다양성을 해칩니다. 학습이 진행되면서 손실 분포가 정보를 담기 시작하면 재가중 효과를 점점 강하게 합니다.

온도 r에 따른 가중치 분포 변화 (5개 샘플, s값 고정):

s값 = [-0.8, -0.3, 0.1, 0.5, 1.0] (LinUpper 기준). r↑ → 균일에 가까워짐, r↓ → 높은 s 샘플이 거의 전부 차지

이론적 프레임워크 — 수렴 분석

이 논문의 중요한 기여 중 하나는 손실 기반 재가중이 수렴 속도에 미치는 영향을 최초로 이론적으로 분석했다는 점입니다. 핵심은 δ_t라는 수렴 편차 항입니다.

Theorem 1 — 볼록 함수, 전체 그래디언트 하강

\[ \frac{1}{M}\sum_{i=1}^M f(x_i; \bar\theta^T) - \frac{1}{M}\sum_{i=1}^M f(x_i; \theta^*) \leq O\!\left(\frac{\|\theta_0 - \theta^*\|^2}{T}\right) + \frac{1}{T}\sum_{t=0}^{T-1}\delta_t \tag{3} \] \[ \text{where}\quad \delta_t = \sum_{i=1}^M \left(\frac{1}{M} - w(x_i;\theta_t)\right)\!\bigl(f(x_i;\theta_t) - f(x_i;\theta^*)\bigr) \]
변수의미비고
δ_t수렴 편차 항 (수렴 속도의 핵심)음수이면 표준보다 빠른 수렴
θ*전역 최적해보간 조건: ∀i에서 f(x_i; θ*)가 최소
T번 반복의 평균 파라미터θ̄ᵀ = (1/T)Σθ_t
M전체 데이터 수균일 가중치 = 1/M

δ_t 해석 — 핵심 통찰

δ_t = 0: w_i = 1/M (균일 가중치) → 전통적 수렴 보장 그대로
δ_t > 0: 낮은 손실 차이(f-f*)를 가진 샘플에 높은 가중치 → 수렴 보장이 느슨해짐
δ_t < 0: 낮은 손실 샘플에 낮은 가중치(= 높은 손실 샘플에 높은 가중치) → 더 빠른 수렴!

즉: 낮은 손실 샘플의 가중치를 줄이면 이론적으로도 더 빠른 수렴이 보장됩니다.

상한 조건: w_i ≤ 2/M

식 (3)의 성립 조건: 어떤 단일 샘플도 균일 가중치의 2배를 초과해서는 안 됩니다. 이 조건이 DRO(분포적 강건 최적화) 같은 worst-case 방법을 배제하는 이유입니다 — DRO는 이 상한을 위반하며 이상치에 과적합됩니다. LinUpper 방법은 실험에서 이 조건을 항상 만족함이 확인됩니다 (7B 모델 실험에서 최대 가중치 ≈ 0.01 < 2/128 ≈ 0.0156).

Theorem 2 — 볼록 함수, 미니배치 SGD with Momentum

\[ \mathbb{E}[f(\theta^T) - f(\theta^*)] \leq \frac{8L\|\theta_0 - \theta^*\|^2}{\sqrt{T+1}} + \frac{2}{T+1}\sum_{t=0}^{T-1}\delta_t + \frac{2}{T+1}\sum_{t=0}^{T-1}\lambda_t\mu_t \tag{10} \]
변수의미비고
L리프시츠(Lipschitz) 스무스니스 상수Assumption 2
δ_t수렴 편차 (Theorem 1과 유사)현재 파라미터 손실 차이 기반
μ_t모멘텀 항 수렴 편차이전/현재 파라미터 손실 차이 기반
λ_t모멘텀 가중치λ_t > 0

직관적 해설

실제 LLM 학습에 사용되는 Adam/SGD+momentum에도 동일한 결론이 적용됩니다. 첫 항은 초기화-최적해 거리에 비례하는 전통적 항, 두 번째와 세 번째 항이 재가중 효과를 담습니다. 낮은 손실 샘플 가중치를 줄이면 δ_t가 음수가 되어 전체 수렴 보장이 개선됩니다.

비볼록 확장 및 수학적 유도 개요 [Appendix C.3]

비볼록 손실(실제 신경망 대부분)에 대한 분석은 Appendix C.4에 있습니다. 이 경우에도 동일한 결론이 성립합니다.

[Appendix 부연 — C.3 핵심 단계]
① ‖θ_{t+1}-θ*‖² 전개: 볼록성(Assumption 1)으로 ⟨∇f_i(θ_t), θ_t-θ*⟩ ≥ f_i(θ_t)-f_i(θ*) 적용
② L-스무스니스(Assumption 2)로 그래디언트 노름 상한: ‖∇f_i(θ_t)‖² ≤ 4L(f_i(θ_t)-f_i(θ*)) + 2σ*²
③ w_max,t ≤ 2/b 조건 하에 η = 1/(8L√(T+1))으로 설정
④ t=0,...,T-1에 걸쳐 텔레스코핑 → 식 (10)의 최종 수렴 한계 도출

Proposition 1 — LinUpper = KL-정규화 최적 전략

\[ w_i = C \cdot \min\!\left\{\exp\!\left(\frac{h_i}{r}\right),\ \frac{2}{M}\right\} \tag{5} \]
변수의미비고
C정규화 상수∑w_i = 1 보장

이 수식의 의미

Proposition 1은 "KL-발산 정규화된 δ_t를 최소화하는 최적 가중치 전략"이 LinUpper(온도 조정 포함)과 같은 형태임을 보입니다. KL 정규화가 없다면 최적해는 상위 M/2개 샘플에만 2/M 가중치, 나머지는 0이 됩니다 — 이는 데이터의 절반을 버리는 극단적 방법이며 실제로 성능이 좋지 않습니다. KL 정규화가 이 극단을 막고 데이터 다양성을 유지합니다.

실천적 의미

LinUpper가 단순히 경험적으로 잘 작동하는 게 아니라, 이론적으로 최적임이 증명됩니다. 이것이 "왜 LinUpper인가?"라는 질문에 대한 완전한 답입니다.

알고리즘 1 — 완전 온라인 인스턴스 재가중

아래 스텝 플레이어에서 각 단계를 눌러 알고리즘의 흐름을 따라가보세요.

스텝 1 / 5
Input: η, θ₀, {r_t}ₜ, f_min, f_max

for t = 0, 1, ..., T-1:
  B = {x_i}_{i=1}^b          # (1) 미니배치 샘플링
  {f_{i,t}} = forward(B, θ_t)  # (2) 손실 계산
  h_{i,t} = normalize(f_{i,t}) # (3) 정규화 [−α,α]
  s_{i,t} = strategy(h_{i,t})  # (4) 전략 적용
  w_{i,t} = softmax(s/r_t)     # (5) 커리큘럼 가중치
  θ_{t+1} = θ_t - η·Σ w·∇f    # 가중 업데이트

구현 세부사항

PyTorch 구현 핵심 코드 [Appendix A.1]

# 1. 전략 함수
def apply_strategy(losses, delta=1.0, strategy="linupper"):
    if strategy == "linupper":
        return torch.minimum(losses + delta, delta * torch.ones_like(losses))
    elif strategy == "quadratic":
        return 1 - losses**2 / delta**2
    elif strategy == "extremes":
        return torch.abs(losses)

# 2. 온도 스케일링 (소프트맥스용)
def scale_losses(losses, r):
    return torch.exp(losses / r)

# 3. [-δ, δ] 정규화
def normalize_losses(losses, delta=1., l_min=0., l_max=1.):
    return 2. * delta * losses / max(l_max - l_min, 1e-6) \
           - delta * (l_max + l_min) / max(l_max - l_min, 1e-6)

# 4. per-sample NLL 손실 계산 (패딩 토큰 제외)
def get_batch_loss_from_logits(logits, labels):
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    num_active = (shift_labels != -100).sum(dim=1)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, logits.size(-1)),
                    shift_labels.view(-1).long())
    return loss.view(logits.size(0), -1).sum(dim=1) / num_active, num_active

# 5. FSDP 멀티-GPU 학습 루프 통합
# (다른 GPU 샘플 손실을 all_gather로 수집)
gathered_losses = torch.zeros(dist.get_world_size(), len(device_losses), ...)
dist.all_gather_into_tensor(gathered_losses, device_losses.detach())

r = r_scheduler(step, cfg.initial_r)
with torch.no_grad():
    normalized = normalize_losses(gathered_losses.view(-1), ...)
    reweighted = apply_strategy(normalized, strategy=STRATEGY)
    scaled = scale_losses(reweighted - reweighted.max(), l=r)
    weights = scaled / scaled.sum()
    device_weights = weights.view(dist.get_world_size(), -1)[local_rank, :]

loss = torch.sum(device_weights * device_losses) * dist.get_world_size()
loss.backward()

보충 * dist.get_world_size()로 스케일링하는 이유: 글로벌 배치 전체의 가중치가 1이 되도록 정규화했기 때문에, 각 GPU 손실의 합이 1/world_size가 됩니다. world_size를 곱해서 전통적인 학습률 스케일을 복원합니다.

학습 하이퍼파라미터 [Appendix A.2]

모델파라미터배치 크기학습률LR 스케줄워밍업총 스텝r (초기→최종)
GPT2-mini124M325×10⁻⁴Linear Warmup Cosine50020,000100→0.4
GPT2-small210M485×10⁻⁴Linear Warmup Cosine50020,000100→0.4
GPT2-medium300M485×10⁻⁴Linear Warmup Cosine50020,000100→0.4
Llama-1.4B1.4B1283×10⁻⁴Linear Warmup Cosine2,000100,000100→1.0
Llama-7B7B1283×10⁻⁴Linear Warmup Cosine2,000170,000100→2.0

공통 설정: Weight decay 0.01, Max Grad Norm 1.0, AdamW 옵티마이저(β1/β2 표준값). GPT2: HuggingFace Trainer + Accelerate. Llama: PyTorch FSDP. 코드: github.com/sowmaster/Sample-Level-Loss-Reweighting-ICLR-2025

모델 아키텍처 [Appendix A.2, Table 6]

아키텍처이름레이어어텐션 헤드임베딩 차원FFN 차원최대 시퀀스 길이
GPT-2Mini (124M)12127683072512
Small (210M)24167863072512
Medium (300M)36247863072512
Llama1.4B2416204820488192
7B3232409640968192

결과 — 수치 비교 & 시각화

+1.35%
7B 모델 평균 성능
19개 태스크 전체 평균
Uniform 대비
5/7
도메인 퍼플렉시티 개선
GPT2-medium 기준
LinUpper vs Uniform
≈0
추가 계산 비용
손실 기반 재가중
extra overhead
+1.72%
1.4B QA 태스크
6개 QA 벤치마크 평균
Uniform 대비

GPT2 모델 — 균일 도메인 샘플링 하 퍼플렉시티

SlimPajama 7개 도메인 각각에서의 hold-out 검증 퍼플렉시티(낮을수록 좋음). GPT2-medium(300M) 기준.

Book, C4, CC 도메인에서 특히 두드러진 개선. 이 도메인들은 웹에서 수집된 노이즈가 많은 데이터로, LinUpper의 "낮은 손실 하향 조정"이 반복적/중복 콘텐츠를 효과적으로 줄임을 시사합니다.

GPT2 모델 — 5-shot 추론 벤치마크 (GPT2-medium)

벤치마크Baseline UniformOurs LinUpperQuadraticExtremes
LogiQA (논리 추론)25.727.925.526.5
LogiQA 227.527.628.627.7
SciQ (과학 추론)49.051.848.749.6
PiQA (물리 상식)55.556.254.655.2
평균39.440.939.339.8

LinUpper가 4개 중 3개 태스크에서 최고 성능. Extremes는 모든 태스크에서 다른 방법보다 낮음 → 낮은 손실 샘플을 줄이는 것이 핵심임을 반증.

비균일 도메인 샘플링 — DoGE/DoReMi와의 결합 (GPT2-medium)

벤치마크DoGEOurs DoGE + LinUpperDoReMiOurs DoReMi + LinUpper
LogiQA27.228.627.227.6
LogiQA 227.528.027.727.9
SciQ52.853.253.354.5
PiQA55.856.355.756.1
평균40.841.541.041.5
시너지 효과 인스턴스 수준 재가중(LinUpper)은 도메인 수준 재가중(DoGE, DoReMi)과 직교(orthogonal)한 정보를 활용합니다. 결합 시 두 방법이 각각 포착하지 못하는 정보를 서로 보완하여 일관된 성능 개선을 보입니다.

대규모 모델 — Llama 1.4B 및 7B 결과

FineWeb 15T 데이터셋에서 무작위 샘플링한 서브셋으로 학습. 1.4B: ~100B 토큰, 7B: ~175B 토큰.

Llama-1.4B 평균 성능

Llama-7B 평균 성능

모델태스크 유형Baseline UniformOurs LinUpper개선
Llama-1.4BLUR (13개)48.5149.16+0.66%
QA (6개)42.3544.07+1.72%
전체 (19개)46.5647.55+0.99%
Llama-7BLUR (13개)50.5451.98+1.44%
QA (6개)48.1549.31+1.16%
전체 (19개)49.7951.14+1.35%

모델이 클수록 개선 효과가 더 두드러집니다 (1.4B: +0.99% → 7B: +1.35%). 작은 모델은 재가중 방법의 세밀한 신호를 충분히 활용할 용량이 부족할 수 있습니다.

온도 r 민감도 분석 [Appendix A.4]

r 값이 너무 크면(=1) 균일 가중치와 거의 같아지고, 너무 작으면(=0.2) 낮은 손실 샘플을 과도하게 걸러내어 데이터 낭비가 발생합니다. 최적 r은 모델 규모에 따라 다릅니다(7B에서 r=2.0).

장난감 회귀 문제 실험 [Appendix B]

선형 회귀에 25% 이상치 데이터를 섞은 합성 데이터에서도 LinUpper가 가장 빠르게 수렴합니다. 이는 손실 기반 재가중의 이점이 LLM 특화 현상이 아니라 ML 전반에 걸친 근본적 원리임을 보여줍니다.

DRO (분포적 강건 최적화)와의 비교 [Appendix B.3]

방법이상치 처리수렴 안정성성능
Uniform (기준)없음안정기준
Ours LinUpper가중치 상한으로 제어안정최고
DRO-KL (같은 LR)최악 케이스 집중발산(diverge)-
DRO-KL (낮은 LR)최악 케이스 집중수렴LinUpper보다 낮음

한계점 & 트레이드오프

한계 1 소형 모델에서의 미미한 효과 — GPT2-mini(124M), GPT2-small(210M)에서는 개선 폭이 작습니다. 소형 모델은 재가중이 만들어내는 미세한 그래디언트 신호를 충분히 활용할 표현 능력이 부족한 것으로 추측됩니다.
한계 2 편향된 그래디언트(Biased Gradients) — 손실 기반 가중치는 가중치 자체가 손실의 함수이므로 미니배치 그래디언트가 전체 그래디언트의 불편(unbiased) 추정량이 아닙니다. 이론 분석에서 추가 항(δ_t, μ_t)이 생기는 이유입니다. 실제 수렴에 지장은 없으나 이론적 분석이 더 복잡해집니다.
한계 3 r 하이퍼파라미터 민감도 — 최적 온도 r은 모델 크기에 따라 다릅니다(GPT2: r=0.4, Llama-7B: r=2.0). 새로운 모델 규모에 대해 r을 어닐링하는 최적 일정(schedule)을 결정하려면 별도의 탐색이 필요합니다.
한계 4 분산 학습 통신 오버헤드 — 멀티-GPU 환경에서 글로벌 f_min/f_max를 구하려면 all_gather 통신이 필요합니다. 소규모 배치 크기나 많은 GPU 수에서는 이 통신 비용이 무시할 수 없을 수도 있습니다.
한계 5 커리큘럼 학습과의 명시적 통합 미탐색 — 저자들은 이 방법이 암묵적으로 커리큘럼 학습 효과를 가진다고 언급하지만, 명시적 커리큘럼 전략(예: 학습 단계별 도메인 비율 동적 조정)과의 결합은 이 논문에서 탐색되지 않았습니다.

얻은 것 vs. 포기한 것 요약

얻은 것

  • 수렴 속도 향상 (노이즈 많은 도메인에서 특히)
  • 추가 계산 비용 거의 없음
  • 기존 방법(DoGE, DoReMi)과 결합 가능
  • 이론적 수렴 보장 포함
  • 이상치에 강인(상한 조건)

포기한 것

  • 그래디언트 불편성(unbiasedness)
  • 소형 모델에서의 큰 효과 기대 어려움
  • 추가 하이퍼파라미터 r 튜닝 필요
  • 멀티-GPU 통신 약간 증가

영향력 & 후속 연구

누구에게 도움이 되는가?

이 논문은 LLM 사전학습 실무자에게 즉각적인 가치를 제공합니다. 기존 PyTorch/JAX 학습 루프에 몇 줄의 코드만 추가하면 됩니다. 특히 Common Crawl, C4처럼 노이즈가 많은 대규모 웹 코퍼스를 사용하는 팀에 효과적입니다.

보충 이 논문의 기여는 단순히 "새로운 기법"을 넘어, 손실 기반 재가중이 수렴에 미치는 영향의 첫 공식적 이론 분석을 제공한다는 점에서 학문적으로도 중요합니다. 이전 작업(Loshchilov & Hutter 2015; Jiang et al. 2019)은 경험적 관찰에 그쳤습니다.

저자가 제안한 후속 과제

논문 결론부에서 저자들은 다음 방향을 제안합니다: 이 재가중 메커니즘이 대규모 노이즈 데이터셋에서 광범위한 데이터 큐레이션 필요성을 줄이는 자동화된 온라인 데이터 선택의 토대가 될 수 있다. 또한 적대적 학습(adversarial ML), 도메인 적응, 데이터 증강, 불균형 분류 등 다른 ML 분야에서의 적용도 미래 연구로 남겨두었습니다.

관련 분야 더 탐색하기

Q&A — 연구자의 고민과 독자의 질문

Q. LinUpper가 왜 "선형 상한(linear upper-bound)"인가?

전략 함수 s_i = min{h_i + α, α}는 정규화된 손실 h_i에 대해 선형으로 증가하다가 α에서 상한이 걸립니다. h_i가 α 이상이면 모두 동일하게 α로 캡핑됩니다. "Linear"는 하한 구간에서의 선형성, "Upper-bound"는 α에서의 상한 제약을 의미합니다. 이 상한이 이상치(극단적 고손실 샘플)가 학습을 지배하는 것을 막습니다.

"The functional form is s_i := min{h_i + α, α}."

함수 형태는 s_i = min{h_i + α, α}로 정의됩니다.

Q. "낮은 손실 샘플을 줄인다"는 게 좋은 데이터를 버리는 게 아닌가?

이것이 이 논문의 가장 직관에 반하는 부분입니다. 핵심은 "버린다"가 아니라 "이 특정 학습 스텝에서의 기여도를 줄인다"는 것입니다. 낮은 손실 샘플은 두 가지를 의미합니다:

① 모델이 이미 배운 것 — 이 스텝에서 이 샘플의 그래디언트 기여는 학습에 이미 잘 반영된 방향입니다. 이를 줄여도 손실이 없습니다.

② 처음부터 단순/반복적인 것 — 웹 크롤 데이터의 많은 부분이 중복이거나 템플릿화된 텍스트입니다. LinUpper는 데이터 큐레이션 없이 이를 자동으로 경시합니다.

평가 다만, "낮은 손실 = 이미 배운 것"이라는 전제는 완전히 정확하지 않을 수 있습니다. 낮은 손실이 쉬운 데이터 때문인지 모델이 이미 학습했기 때문인지 구분하기 어렵고, 이 둘의 처리 방식이 다를 수 있습니다.

Q. 이 방법을 fine-tuning이나 instruction tuning에도 사용할 수 있는가?

이 논문은 사전학습(pretraining)에 집중합니다. 그러나 관련 연구인 Chen et al. (2024)가 연속 학습(continual training)/인스트럭션 튜닝에서 중간 손실 샘플 재가중을 사용했음을 언급합니다.

평가 Fine-tuning에서는 데이터를 여러 에포크 반복하는 경우가 많으므로, 샘플이 한 번만 등장한다는 핵심 제약이 사라집니다. 이 경우 EWC(Elastic Weight Consolidation) 같은 기법이 더 적합할 수 있습니다. 하지만 온도 어닐링과 LinUpper의 핵심 아이디어 자체는 fine-tuning에도 적용 가능해 보입니다.

Q. 왜 GitHub 도메인에서는 LinUpper가 Uniform보다 퍼플렉시티가 높은가?

Table 1을 보면 GitHub 도메인에서 GPT2-mini/small/medium 모두 LinUpper의 퍼플렉시티가 Uniform보다 약간 높습니다. 논문은 이를 명시적으로 설명하지 않습니다.

평가 GitHub 코드는 다른 도메인에 비해 구조적 패턴이 매우 규칙적입니다(들여쓰기, 함수 선언 등). 이런 코드 데이터는 초기 학습 단계에서 손실이 낮은 패턴을 일찍 학습하므로, LinUpper가 이를 지나치게 하향 조정하면 후기 세밀한 학습에 불리할 수 있습니다. "낮은 손실 = 반복적"이라는 가정이 코드 도메인에서는 덜 성립할 수 있습니다.

Q. 왜 w_i ≤ 2/M 상한이 필요한가? 더 큰 상한은 어떤가?

2/M 상한은 이론 분석(Theorem 1, 2)에서 수렴 보장이 성립하기 위한 조건입니다. 직관적으로는 "어떤 단일 샘플도 균일 가중치의 2배 이상을 받을 수 없다"는 것으로, 극단적 이상치 집중을 막습니다.

"maxi w(xi;θt) ≤ 2/M essentially means that the weight of any sample (after reweighting) should not be more than twice the uniform weight."

이 상한은 재가중 후 어떤 샘플의 가중치도 균일 가중치의 2배를 초과하지 않아야 한다는 의미입니다.

Appendix A.5의 실험에서 7B 모델 학습 시 LinUpper의 실제 최대 가중치는 항상 ~0.01 이하로, 이론 상한 0.0156(=2/128)을 만족합니다.

평가 더 큰 상한을 허용하면 이론 보장이 사라지고, 극단적 이상치 집중 리스크가 높아집니다. DRO가 이 상한을 지키지 않아 성능이 나쁜 사례(Appendix B.3)가 이를 잘 보여줍니다.

Q. 구현 코드는 어디서 볼 수 있는가?

GPT2 실험 전체 코드: github.com/sowmaster/Sample-Level-Loss-Reweighting-ICLR-2025

Appendix A.1에 PyTorch FSDP 멀티-GPU 학습 루프 통합 코드가 상세히 제공됩니다. 핵심 유틸리티 함수 4개(apply_strategy, scale_losses, normalize_losses, get_batch_loss_from_logits)를 기존 학습 루프에 추가하는 방식입니다.

원본 논문 자료 — Figure & Table

이 섹션은 논문 원문의 Figure와 Table을 그대로 보존합니다. 독자가 원본 논문의 어느 위치를 찾아야 하는지 확인할 수 있습니다.

Figure 1 — 재가중 전략 곡선
Figure 1 (논문 §4.2, p.5): 재가중 전략들의 기하학적 곡선(왼쪽)과 LinUpper에 온도 r을 적용했을 때의 형태 변화(오른쪽). r이 커질수록 균일 가중치에 수렴함을 보여줍니다. 배치 크기 128의 균일 샘플 손실에서 계산된 결과입니다. 이 Figure는 본 페이지의 인터랙티브 슬라이더로도 확인할 수 있습니다.
Figure 2 — GPT2-medium 도메인별 퍼플렉시티
Figure 2 (논문 §6.2, p.8): 균일 도메인 샘플링 설정에서 GPT2-medium 모델의 도메인별 hold-out 검증 퍼플렉시티. LinUpper 전략은 7개 도메인 중 5개에서 더 낮거나 동등한 퍼플렉시티를 달성합니다. 특히 CC, C4, Book에서 뚜렷한 개선이 보입니다.
Figure 2 (계속) 및 Table 1, 2
Figure 2 (계속) + Table 1 & 2 (논문 §6.2, p.8-9): GitHub, StackExchange, Wikipedia 도메인 퍼플렉시티와 GPT2 모델별 최종 퍼플렉시티 수치(Table 1), 5-shot 추론 벤치마크 결과(Table 2). LinUpper가 4개 중 3개 태스크에서 최고 성능.