처음부터 완전 양방향 어텐션으로 학습된 8B 마스크 디퓨전 언어 모델이다. 자기회귀 인수분해와 인과 어텐션이 지배하는 현 패러다임과 달리, 사전학습부터 SFT까지 마스크 디퓨전 목적함수를 일관되게 유지하며 강력한 언어 모델로 가는 경쟁력 있는 경로를 보인다.
현재 대규모 언어 모델은 자기회귀 패러다임이 지배한다. 최근 디퓨전 언어 모델은 언어 생성에 대한 다른 접근으로 주목받고 있으며, LLaDA는 마스크 디퓨전 정식화를 따라 완전 양방향 어텐션으로 언어 모델을 처음부터 학습했다.
LLaDA는 비자기회귀 모델도 문맥 내 학습, 지시 따르기 같은 핵심 LLM 능력을 획득할 수 있음을 보여, 언어 지능이 반드시 자기회귀 모델링에 의존해야 한다는 통념에 도전했다. 이 개념적 함의를 넘어, 양방향 디퓨전 언어 모델은 역방향·양방향 추론, 장기 지평 계획, 멀티모달·옴니 모델링에서 이점을 보였다. 최근 연구는 양방향 디퓨전 사전학습이 반복 학습하에서 제한된 데이터를 더 잘 활용해, 데이터 제약 환경에서 자기회귀 모델을 능가하게 함을 보였다.
그러나 LLaDA는 여전히 초기 대규모 시도였고, Qwen2·Qwen2.5 같은 강력한 자기회귀 모델에는 미치지 못해 개선 여지가 상당히 남아 있었다. iLLaDA(improved LLaDA)는 8B 완전 양방향 마스크 디퓨전 모델을 처음부터 학습해 이 격차를 좁힌다.
좌에서 우로 한 토큰씩 순차 생성하며, 각 토큰은 이전 토큰만 참조한다. 현 LLM의 주류 패러다임이다.
전 위치를 동시에 보며 마스크된 토큰을 신뢰도 순으로 드러낸다. 사전학습부터 SFT, 추론까지 동일한 마스크 디퓨전 목적을 유지한다.
코퍼스를 12T로 확장하고, 그룹 쿼리 어텐션(GQA)으로 캐시형 추론 메모리를 줄이며, 입출력 임베딩을 묶어 파라미터를 절감한다. 대규모 학습용 학습률 스케줄을 수정한다.
가변 길이 생성을 위해 SFT 전략을 수정하고, 25B 토큰 지시 코퍼스로 12 에폭 미세조정한다.
효율을 위한 가변 길이 생성을 사용한다.
객관식 벤치마크에 신뢰도 기반 채점을 도입해 우도 상계보다 나은 성능을 얻는다.
이 변경들은 LLaDA를 실질적으로 개선한다. 처음부터 학습한 LLaDA와 Qwen2.5에서 미세조정한 Dream을 포함한 기존 양방향 디퓨전 언어 모델 대비, iLLaDA는 베이스·지시 평가 모두에서 최고 평균 성능을 얻는다. Qwen2.5 7B와 비교하면 iLLaDA-Base는 평균에서 약간 더 강하고, iLLaDA-Instruct는 아직 뒤처진다. 절제 연구는 신뢰도 기반 채점이 객관식 평가를 개선하고, iLLaDA가 다중 에폭 SFT에서 계속 이득을 얻음을 보인다.
LLaDA의 마스크 디퓨전 정식화를 유지하되, 확장·사후학습·평가에 중요한 몇 가지 실용적 변경을 가한다.
iLLaDA는 LLaDA와 동일한 사전학습 목적을 따른다. 길이 L의 깨끗한 시퀀스 x₀가 주어지면 마스킹 비율 t~U[0,1]를 샘플링하고, 각 토큰을 확률 t로 독립적으로 마스크 토큰 M으로 대체해 손상 시퀀스 xₜ를 얻는다. 모델은 마스크된 모든 토큰을 예측하도록 학습된다.
백본은 RMSNorm, SwiGLU, RoPE를 쓰고 어텐션·MLP 바이어스가 없는 밀집 트랜스포머다. 멀티헤드 어텐션을 쓰는 LLaDA와 달리 iLLaDA는 GQA를 사용한다. KV 캐시류 메커니즘이 디퓨전 언어 모델에도 적용될 수 있음이 최근 밝혀졌고, 이런 캐시형 구현하에서 GQA는 캐시된 키/값 상태의 메모리 사용량을 줄인다. 파라미터 수를 더 통제하기 위해 입력 임베딩과 LM 헤드 파라미터를 묶는다.
최대 시퀀스 길이 8192로 사전학습하며, 마스크 디퓨전의 무작위 길이 학습에서 영감을 받아 30% 확률로 8192 시퀀스를 두 짧은 세그먼트로 무작위 분할한다. 가변 길이 예시를 배치마다 패킹하고 FlashAttention 기반 가변 길이 어텐션 커널로 패딩 없이 예시를 분리한다. 학습률은 2×10⁻⁴까지 선형 워밍업 후 일정하게 유지하다가, 사전학습 손실이 더 줄지 않을 때 최소 5×10⁻⁶의 코사인 감쇠로 전환하니 손실이 다시 개선됐다. 가중치 감쇠 0.1의 AdamW를 쓴다.
| 항목 | iLLaDA 8B | LLaDA 8B |
|---|---|---|
| 레이어 / 모델 차원 | 32 / 4096 | 32 / 4096 |
| 어텐션 헤드 | 32 | 32 |
| Key/Value 헤드 | 8 | 32 |
| FFN 차원 | 14,336 | 12,288 |
| 어휘 크기 | 155,136 | 126,464 |
| 최대 시퀀스 길이 | 8192 | 4096 |
| 임베딩 / LM 헤드 | Tied (묶음) | Untied |
| 총 파라미터 | 7.62B | 8.02B |
| 비임베딩 파라미터 | 6.98B | 6.98B |
선행 연구는 보통 프롬프트와 전체 참조 응답을 이어붙여 SFT 인스턴스를 구성하고, 학습 중 프롬프트 토큰은 가시 상태로 두고 응답 영역에만 마스크를 적용했다. 미니배치 내 짧은 응답은 가장 긴 응답 길이에 맞춰 |EOS|로 패딩됐다.
iLLaDA는 대신 사전학습과 동일한 데이터 처리·마스킹 방식을 쓴다. 각 지시 예시를 프롬프트-응답 시퀀스와 단일 종단 |EOS|로 포맷하고, 모든 예시를 연속 지시 코퍼스로 이어붙인 뒤 8192 토큰 학습 시퀀스를 샘플링한다. 그다음 전체 시퀀스에 무작위 마스크를 적용해 식 (1)을 최적화하므로, 프롬프트·응답·|EOS| 토큰이 모두 마스크될 수 있다. 사전학습과 같은 무작위 길이 학습도 쓴다. 이 SFT 포맷은 가변 길이 블록 생성을 자연스럽게 지원한다.
SFT 코퍼스는 약 250억 토큰이며 12 에폭 미세조정한다. 학습률은 5×10⁻⁶까지 선형 워밍업 후 일정하게 유지하다가 마지막 10% 구간에서 5×10⁻⁷까지 선형 감쇠한다.
iLLaDA는 LLaDA와 동일한 확률적 정식화를 쓴다. 두 모델 모두 식 (1)의 마스크 디퓨전 목적으로 학습되며, 이는 모델 분포의 음의 로그우도 상계에 해당한다.
많은 벤치마크가 HellaSwag, PIQA, ARC-Challenge처럼 객관식으로 정식화된다. 접두 p와 유한한 후보 연속들이 주어지면, 평가는 각 후보에 점수를 매겨 최고점을 고른다. iLLaDA는 결정론적 신뢰도 기반 채점 규칙을 쓰며, 이는 객관식에서 로그우도 상계보다 경험적으로 더 낫다. 전부 마스크된 후보에서 시작해, 남은 마스크 위치 중 모델이 가장 높은 신뢰도를 부여하는 토큰을 매 단계 하나씩 정답 토큰으로 드러낸다.
프롬프트에 마스크 토큰 블록을 덧붙이고 그 블록 안에서 디퓨전 샘플러를 돌린다. 각 샘플링 단계에서 모델은 모든 마스크 위치를 예측하고, MaskGIT·LLaDA의 저신뢰 리마스킹 전략을 따라 가장 확신하는 예측을 가시 토큰으로 전환하되 저신뢰 위치는 마스크로 유지한다. 블록이 디코딩되면 |EOS|나 정지 토큰이 나타날 때 종료하고, 아니면 새 마스크 블록을 덧붙여 최대 생성 예산까지 계속한다.
베이스·지시 평가 모두에서 iLLaDA는 기존 디퓨전 언어 모델을 실질적으로 개선하며, 여러 추론 벤치마크에서 강력한 자기회귀 베이스라인과 경쟁력을 유지한다.
| 벤치마크 | iLLaDA 8B | LLaDA 8B | Dream 7B | Qwen2.5 7B |
|---|---|---|---|---|
| 방식 / 학습 토큰 | Diff · 12T | Diff · 2.3T | Diff · 18T+0.6T | AR · 18T |
| General Tasks | ||||
| MMLU | 74.8 | 65.9 | 69.5 | 71.9 |
| BBH | 71.3 | 49.7 | 57.9 | 63.9 |
| ARC-C | 60.8 | 45.9 | 59.8 | 51.5 |
| HellaSwag | 76.6 | 70.5 | 73.3 | 79.0 |
| Mathematics & Science | ||||
| GSM8K | 81.9 | 70.3 | 77.2 | 78.9 |
| MATH | 38.4 | 31.4 | 39.6 | 41.1 |
| Code | ||||
| HumanEval | 50.0 | 35.4 | 57.9 | 56.7 |
| MBPP | 57.8 | 40.0 | 56.2 | 63.6 |
| 평균 | 63.9 | 51.1 | 61.4 | 63.3 |
iLLaDA는 모든 과제에서 LLaDA를 크게 개선하며 BBH·ARC-C·GSM8K·HumanEval·MBPP에서 특히 큰 폭으로 향상된다. Dream보다 대부분의 일반·수학 벤치마크에서 앞서고(Dream은 HumanEval에서 우위), 디퓨전임에도 Qwen2.5 7B와 경쟁하며 MMLU·BBH·ARC-C·GSM8K에서 표 내 최고를 기록한다.
| 벤치마크 | iLLaDA 8B | LLaDA 8B | Dream 7B | Qwen2.5 7B |
|---|---|---|---|---|
| General Tasks | ||||
| MMLU | 71.6 | 65.5 | 67.0 | 76.6 |
| MMLU-Pro | 52.3 | 37.0 | 43.3 | 56.3 |
| MMLU-Redux | 76.4 | 68.9 | 76.3 | 75.7 |
| Mathematics & Science | ||||
| GSM8K | 89.0 | 77.5 | 81.0 | 91.6 |
| MATH | 56.7 | 42.2 | 39.2 | 75.5 |
| Code | ||||
| HumanEval | 65.9 | 49.4 | 55.5 | 84.8 |
| MBPP | 58.0 | 41.0 | 58.8 | 79.2 |
| 평균 | 67.1 | 54.5 | 60.2 | 77.1 |
SFT 후에도 iLLaDA는 대부분 벤치마크에서 LLaDA·Dream을 능가하며 GSM8K·MATH·HumanEval에서 향상이 두드러진다. Qwen2.5 7B 대비 여러 수학·코드 벤치마크에서 뒤지지만 MMLU-Redux에서는 경쟁력 있고 디퓨전과 강한 AR 베이스라인의 격차를 크게 좁힌다. iLLaDA-Base가 이미 Qwen2.5-Base와 경쟁하므로, 남은 지시 격차는 주로 Qwen2.5가 SFT 후 추가한 강화학습 정렬 때문으로 본다.
| 채점 규칙 | PIQA | ARC-C | HellaSwag |
|---|---|---|---|
| Likelihood | 77.2 | 60.2 | 74.3 |
| Confidence | 78.5 | 60.8 | 76.6 |
신뢰도 기반 채점이 우도형 베이스라인보다 일관되게 개선된다(PIQA +1.3, ARC-C +0.6, HellaSwag +2.3). 이 결과가 객관식 평가에 신뢰도 채점을 쓰는 근거다.
SFT 에폭이 늘수록 성능이 대체로 향상되며, 특히 추론 중심 벤치마크에서 긴 SFT를 뒷받침한다. GSM8K는 6에폭에서 일시적으로 떨어졌다가 9·12에폭에서 크게 오르고, MATH·MMLU-Pro는 단조 상승한다. 이는 데이터 제약 환경의 디퓨전 모델 연구와 일치한다 — 한 연구는 1B 고유 토큰을 96에폭 학습하는 극단적 반복에서도 디퓨전 모델이 계속 개선됨을 보였다. 사전학습보다 훨씬 작은 지시 코퍼스를 쓰는 SFT에도 유사한 데이터 재사용 효과가 나타난다. 연산 제약으로 12 에폭을 넘기지는 않았다.
처음부터 학습된 8B 완전 양방향 디퓨전 언어 모델 iLLaDA를 제시했다. 사전학습을 12T 토큰으로 확장하고 모델 설계, 학습률 스케줄, SFT 포맷, 신뢰도 기반 객관식 채점, 가변 길이 생성 등 실용 레시피의 여러 부분을 갱신했다. iLLaDA가 다중 에폭 SFT에서 계속 이득을 얻음도 확인했다. 베이스·지시 평가 전반에서 이 변경들은 일반·수학·코드 벤치마크에서 LLaDA 대비 실질적 개선으로 이어져, 처음부터의 완전 양방향 디퓨전 학습이 강한 언어 모델링 성능을 달성할 수 있음을 시사한다.
BBH·GSM8K·MATH·HumanEval·MBPP에 개방형 생성을 쓴다. BBH·GSM8K·MATH·MBPP는 최대 생성 길이 1024, 블록 길이 32로 설정한다. HumanEval은 준자기회귀 블록 샘플링이 성능을 해치는 것을 관찰해 최대 생성 길이와 블록 길이를 모두 512로 둔다.
벤치마크별 추론 설정을 쓴다. 답 글자 하나만 생성하면 되는 MMLU·MMLU-Redux는 최대 생성/블록 길이를 각각 4/4, 3/3으로 둔다. GSM8K·HumanEval은 2048/32, MMLU-Pro·MATH는 4096/32, MBPP는 2048/16으로 설정한다.
일부 어려운 문제에서 모델이 "Wait, let me check again" 같은 문구를 반복하며 최종 답을 못 내는 반복 추론 루프가 관찰된다. 이는 추론 모델이 생성한 구조화된 사고연쇄 흔적을 담은 SFT 코퍼스 일부에 기인한다고 본다. 완화를 위해 생성이 길어질수록 정지 사고 토큰 </think> 방출 확률을 점진적으로 높여, 모델이 추론 흔적을 종료하고 최종 답을 내도록 유도한다.