Haru's 개발 블로그

[논문 리뷰] WizardMath-70B-V1.0 논문 리뷰 본문

논문 리뷰/Language Model

[논문 리뷰] WizardMath-70B-V1.0 논문 리뷰

Haru_29 2024. 4. 19. 11:39

링크

 

GitHub - nlpxucan/WizardLM: LLMs build upon Evol Insturct: WizardLM, WizardCoder, WizardMath

LLMs build upon Evol Insturct: WizardLM, WizardCoder, WizardMath - nlpxucan/WizardLM

github.com

 

들어가기에 앞서

최근에는 GPT-4와 같은 대형 언어 모델 (LLM)이 자연어 처리 (NLP) 작업에서 놀라운 성능을 보여주고 있으며, 이에는 어려운 수학적 추론도 포함됩니다. 다만, 이는 대부분의 기존 오픈 소스 모델은 대규모 인터넷 데이터에서만 사전 훈련되었으며 수학 관련 최적화가 이루어지지 않았습니다. 본 논문에서는 우리가 제안한 "Reinforcement Learning from Evol-Instruct Feedback (RLEIF)" 방법을 사용하여 Llama-2의 수학적 추론 능력을 향상시키는 WizardMath를 제시합니다.

개요

최근에는 대규모 언어 모델 (LLM)이 상당한 관심을 받아 다양한 자연어 처리 (NLP) 작업에 대한 주요 접근 방식이 되었습니다.

위의 그림을 확인해보면 기존의 ChatGPT는 답변을 바로 내는 반면, Wizard가 사용한 프롬프트인 Chain-of-thought (CoT)는 단계별 해결책을 생성하기 위해 더 나은 프롬프트를 디자인하는 것을 제안하여 성능 향상을 이끌어 낼 수 있습니다. 또한, Self-Consistency는 또한 모델에서 여러 가능한 답변을 생성하고 다수 투표를 기반으로 올바른 답변을 선택하는 것에 기초한 많은 추론 벤치마크에서 높은 성능을 달성합니다. 최근에는 어려운 수학 문제를 해결하기 위해 결과 지도보다 과정 지도 강화 학습을 사용할 때 현저한 성능을 달성한다고 발견되었습니다.

Evol-Instruct와 Process-supervised Reinforcement Learning에서 영감을 받아 본 연구는 SOTA(최고 수준) 오픈 소스 LLM인 Llama-2의 수학적 추론 능력을 향상시키기 위해 목표를 설정합니다. 위의 그림에 나와 있는 것처럼 우리는 Reinforcement Learning from Evol-Instruct Feedback (RLEIF)라는 새로운 방법을 제안합니다. 이 방법은 먼저 수학 특화 Evol-Instruct를 사용하여 다양한 수학 지시 데이터를 생성한 다음 지시 보상 모델 (IRM) 및 프로세스 지도 보상 모델 (PRM)을 교육합니다. 전자는 진화된 지시의 품질을 나타내며 후자는 솔루션의 각 단계에 대한 피드백을 받습니다.

 

방법

본 논문에서는 3가지 단계를 적용합니다.

  1. Supervised fine-tuning(지도된 세밀 조정)
  2. 지시 보상 모델 및 프로세스 지도 보상 모델 교육
  3. Active Evol-Instrct와 PPO training

1. 지도된 세밀 조정

InstructGPT를 따르면, 우리는 먼저 기본을 지도된 지시-응답 쌍으로 세밀하게 조정합니다.

  • 각 단계의 파싱을 쉽게 만들기 위해 GSM8k와 MATH에 대하여 70B 모델을 활용하여 15,000개의 답변을 few-shot으로 다시 생성하여 단계별 형식으로 솔루션을 생성한 뒤 올바른 답변이 있는 것을 찾아 이 데이터를 Llama 모델에 fine-tuning을 진행합니다.
  • 모델이 신경망 및 다양한 지시를 준수하는 능력을 향상시키기 위해 WizardLM의 훈련 데이터에서 1,500개의 오픈 도메인 대화를 샘플링을 진행한 뒤, 이를 위의 말뭉치와 병합하여 최종 SFT 훈련 데이터로 사용합니다.

2. 수학을 위한 Evol-Instuct 원칙

본 작업은 사전 훈련된 LLM을 향상시키기 위해 다양한 복잡성과 다양성을 가진 수학 지시를 만들려고 시도합니다. 구체적으로, Evol-Instruct를 다음과 같은 두 가지 진화 라인을 포함하는 새로운 패러다임으로 적용시킵니다.

  • Downward evolution : 질문을 더 쉽게 만들어 지시를 강화합니다. 예를 들어 고난도 질문을 낮은 난이도로 수정하거나, 다른 주제의 새로운 쉬운 질문을 생성합니다.
  • Upward evolution : 원래의 Evol-Instruct 방법에서 파생되었으며 더 많은 제약 추가, 구체화, 추론 증가를 통해 새롭고 어려운 질문을 생성하여 심화합니다.

3. Reinforcement Learning from Evol-Instruct Feedback (RLEIF)

  • 지시 보상 모델(IRM) : 이 모델은 진화된 지시의 품질을 판단하는 데 중점을 둡니다. 이때 품질은 정의, 정확도, 통합성 측면에서 측정을 진행합니다. IRM의 랭킹 리스트 훈련 데이터를 생성하기 위해 각 지시에 대해 ChatGPT와 Wizard-E를 사용하여 2~4개의 진화된 지시를 생성한 뒤, Wizard-E를 사용하여 해당 4~8개의 지시의 품질의 우선 순위를 매깁니다.
  • 프로세스 보상 모델(PRM) : 이 작업 이전에 강력한 오픈 소스 수학 추론 LLM이 없었기 때문에, 전문적인 인간 라벨러와 클로즈 소스 ChatGPT 없이는 고도로 정확한 프로세스 지도를 지원하는 간단한 방법이 없었습니다. 따라서 우리는 ChatGPT에 프로세스 지도를 제공하도록 의존하고, ChatGPT에 우리 모델이 생성한 솔루션의 각 단계의 정확성을 평가하도록 합니다.
  • PRO 훈련 : 원래 수학(GSM8k + MATH) 지시를 8회 진화시켜 데이터 크기를 15k에서 96k로 증가시킵니다.

결과

1. 기존의 모델들과 비교

Closed-source 모델들과 비교하면 GPT-4와 Claude2를 제외하면 성능이 높으며 Open-source와 비교를 하게 되면 WizardMath가 성능이 높다라는 것을 확인 할 수 있습니다.

추가 ) OpenAI의 GPT-3, GPT-3.5, ChatGPT5, GPT-4; Google의 PaLM 2, PaLM, Minerva; Anthropic의 Claude Instant, Claude 1.3, Claude 27 DeepMind의 Chinchilla

2. 평가 벤치마크

주로 WizardMath를 두 가지 벤치마크 (GSM8k 및 MATH)에서 평가합니다. GSM8k 데이터셋은 주로 초등 수학 문제를 다루고 있으며, 약 7500개의 교육 데이터와 1319개의 테스트 데이터를 포함하며, 각 문제는 기본 산술 연산(덧셈, 뺄셈, 곱셈 및 나눗셈)으로 이루어져 있으며 일반적으로 2에서 8단계가 필요합니다. MATH 데이터셋은 AMC 10, AMC 12 및 AIME와 같은 명문 수학 대회에서 문제를 수집하였으며, Prealgebra, Algebra, Number Theory, Counting and Probability, Geometry, Intermediate Algebra, 그리고 Precalculus라는 일곱 가지 학문 영역의 7500개의 교육 데이터와 5000개의 어려운 테스트 데이터를 포함하고 있습니다. 또한 이러한 문제는 난이도에 따라 '1'이 상대적으로 낮은 난이도를 나타내고 '5'가 가장 높은 난이도를 나타냅니다.

추가) AMC8 범위(확률, 추정, 백분율, 피타고라스 이론을 포함한 기하학, 공간 시각화, 일상적 응용 및 그래프 해석), AMC10(초등대수학, 기초기하학), AMC12(고등학교 전과정)

3. 훈련 및 평가 프롬프트

Llama 2베이스를 우리의 기본 모델로 사용합니다. 우리는 GSM8k 및 MATH 벤치마크를 평가하기 위해 다음과 같은 CoT 프롬프트를 사용합니다.

Below is an instruction that describes a task. Write a
response that appropriately completes the request.\n\n###
Instruction:\n{instruction}\n\n### Response: Let’s think step by step.

4. GSM8k 및 MATH 평가

WizardMath 70B는 GSM8k에서 일부Close-Source Models인 LLMs, ChatGPT, Claude Instant 및 PaLM 2 540B를 약간 능가합니다. 그리고 현재 모델이 모든 모델 중 상위 다섯에 속하는 것으로 나타납니다. 동시에 WizardMath 70B는 MATH에서 Text-davinci-002를 뛰어넘는 것으로 나타납니다.

Open-Source Models에서는 WizardMath 70B가 GSM8k 및 MATH 벤치마크 양쪽에서 모든 오픈 소스 모델에 비해 상당한 성능 우위를 나타낸다는 것을 명확히 보여줍니다.

 

Comments