Table of Contents
T5 모델에서의 프롬프트 튜닝(prompt tuning)은 미세 조정과 달리, 모델의 모든 파라미터를 고정하고 학습 가능한 프롬프트 벡터만 조정하는 방식입니다. 이를 통해 모델의 전체 구조는 변하지 않으면서도 특정 작업에 맞게 성능을 최적화할 수 있습니다. 아래는 transformers
라이브러리를 사용해 T5 모델에서 프롬프트 튜닝을 하는 예시 코드입니다.
1. 설치 준비
프롬프트 튜닝을 하려면 Hugging Face의 transformers
와 datasets
라이브러리를 설치해야 합니다. 먼저 아래 명령어로 설치하세요:
pip install transformers datasets
2. T5 프롬프트 튜닝 예시 코드
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
# 1. 모델과 토크나이저 준비
model_name = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
# 2. 데이터셋 불러오기
dataset = load_dataset("cnn_dailymail", "3.0.0", split='train[:1%]') # 작은 서브셋 사용
train_data = dataset.map(lambda x: tokenizer(x['article'], padding='max_length', truncation=True, return_tensors="pt", max_length=512), batched=True)
# 3. 프롬프트 튜닝 파라미터 추가
# 여기서 프롬프트 벡터를 추가하고 학습할 수 있도록 세팅합니다.
prompt_length = 10 # 학습할 프롬프트의 길이
prompt_embedding = torch.nn.Parameter(torch.randn(prompt_length, model.config.d_model))
# 4. 모델의 파라미터는 고정하고 프롬프트 벡터만 학습하도록 설정
for param in model.parameters():
param.requires_grad = False
# 프롬프트 벡터는 학습 가능하게 설정
optimizer = torch.optim.Adam([prompt_embedding], lr=1e-3)
# 5. 학습 루프 설정
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=1e-4,
per_device_train_batch_size=4,
num_train_epochs=3,
weight_decay=0.01,
)
# 6. Trainer로 학습 시작
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
optimizers=(optimizer, None),
)
trainer.train()
핵심 설명:
- 프롬프트 벡터:
prompt_embedding
으로 학습할 벡터를 정의하고, 이를 학습 가능한 파라미터로 만듭니다. - 고정된 모델 파라미터: 모델의 모든 파라미터는 학습되지 않도록
param.requires_grad = False
로 고정하고, 오직 프롬프트 벡터만 학습되도록 설정합니다. - Trainer 사용: Hugging Face의
Trainer
클래스를 사용하여 손쉽게 학습 루프를 설정하고 실행합니다.
위 코드는 프롬프트 튜닝의 기본 개념을 보여주는 간단한 예시입니다. 실제 작업에서는 프롬프트 임베딩을 입력과 결합하는 방식이나 여러 추가적인 기법을 활용해 성능을 더 최적화할 수 있습니다.
LLM(대규모 언어 모델)에서는 프롬프트 튜닝이 꼭 필요하지는 않지만, 특정 상황이나 목적에 따라 효율적인 선택이 될 수 있습니다. LLM은 기본적으로 매우 강력한 성능을 제공하기 때문에 별도의 튜닝 없이도 다양한 작업을 수행할 수 있지만, 프롬프트 튜닝이 적합한 경우가 있습니다.
LLM에서 프롬프트 튜닝이 필요한 경우
특정 작업에 대한 성능 향상:
- LLM은 범용적인 모델이기 때문에 모든 작업에서 최적의 성능을 내지는 않습니다. 특정 도메인이나 작업(예: 의료, 법률, 또는 대화형 요약 작업)에서 더 높은 성능을 원한다면 프롬프트 튜닝을 통해 해당 작업에 더 적합한 출력을 유도할 수 있습니다.
작업에 적합한 결과 유도:
- LLM은 특정 작업에 적합한 출력을 생성하도록 미세 조정 없이 사용하면, 예상과 다른 결과를 낼 수 있습니다. 예를 들어, 특정 질문에 대한 답변을 형식화된 구조로 출력하고 싶을 때, 프롬프트 튜닝을 통해 모델이 일관되게 답변을 생성하도록 할 수 있습니다.
데이터 라벨이 적을 때:
- LLM을 사용하는 많은 상황에서는 충분한 양의 학습 데이터를 구하기 어려울 수 있습니다. 이 경우, 모델 전체를 미세 조정하지 않고 프롬프트 튜닝을 통해 일부 프롬프트 벡터만 학습하여 소량의 데이터로도 좋은 성능을 낼 수 있습니다.
메모리와 시간 절약:
- LLM을 전체 미세 조정하는 것은 시간이 오래 걸리고 많은 자원을 소모합니다. 반면, 프롬프트 튜닝은 모델 파라미터를 고정하고 입력 프롬프트에만 벡터를 추가하므로, 메모리 사용량을 절약하면서 더 빠르게 학습할 수 있습니다. 이 방법은 리소스가 제한된 환경에서 특히 유용합니다.
모델의 보존:
- 어떤 상황에서는 기본 모델의 가중치를 수정하지 않고 유지해야 할 때가 있습니다. 예를 들어, 사전 학습된 LLM을 여러 작업에 맞게 응용하고 싶을 때, 모델을 변경하지 않고 프롬프트 튜닝을 사용해 여러 작업에 맞는 최적의 성능을 이끌어낼 수 있습니다. 이를 통해 모델의 범용성을 유지하면서도 특정 작업에서 최적의 성능을 낼 수 있습니다.
프롬프트 튜닝이 불필요할 수 있는 경우
범용 성능이 충분할 때:
- LLM은 기본적으로 다양한 범용 작업에서 좋은 성능을 보입니다. 만약 기본적인 질문 응답이나 텍스트 생성 작업에서 이미 충분히 좋은 결과를 얻고 있다면, 굳이 프롬프트 튜닝을 추가로 할 필요는 없습니다.
미세 조정으로 충분할 때:
- 대규모 데이터를 가지고 있고, 특정 작업에 최적화된 모델이 필요하다면, 전체 모델을 미세 조정(Fine-Tuning)하는 것이 더 좋은 성능을 낼 수 있습니다. 이 경우, 프롬프트 튜닝보다는 모델의 모든 가중치를 학습시키는 것이 더 효과적일 수 있습니다.
복잡한 작업일 때:
- 프롬프트 튜닝은 모델의 가중치를 변경하지 않고도 성능을 향상시킬 수 있지만, 특정 작업에서는 이 방법으로는 충분하지 않을 수 있습니다. 모델의 심층적인 이해와 조정이 필요한 작업(예: 매우 복잡한 논리적 추론)에서는 프롬프트 튜닝만으로는 한계가 있을 수 있습니다.
결론:
LLM의 경우, 프롬프트 튜닝이 꼭 필요한 것은 아니지만, 특정 작업에서 성능을 최적화하거나 자원 절약을 목적으로 할 때 매우 유용할 수 있습니다. 따라서, 프롬프트 튜닝을 선택할지는 작업의 특성과 목표에 따라 달라집니다.
- 범용 작업이나 간단한 텍스트 생성에서는 기본 LLM 성능으로도 충분하지만, 특정 도메인 최적화나 리소스 절약이 필요할 때 프롬프트 튜닝은 효율적인 대안이 될 수 있습니다.
'DeepLearning > NLP' 카테고리의 다른 글
IA3 (Input-Activated Attention Adaptation) (1) | 2024.09.19 |
---|---|
LoRA(Low-Rank Adaptation)와 프롬프트 튜닝(Prompt Tuning) 차이 (0) | 2024.09.19 |
[논문리뷰] Scaling Laws for Neural Language Models (0) | 2024.09.19 |
[LLM] LLM 모델이 LM 모델과 달라진 점 (0) | 2024.09.19 |
[LLM] LM에서 LLM으로 발전하는 과정에서의 주요 변화 (0) | 2024.09.19 |
- Total
- Today
- Yesterday
- English
- classification
- Array
- LLM
- speaking
- 코딩테스트
- RAG
- 티스토리챌린지
- nlp
- Python
- Hugging Face
- Transformer
- #패스트캠퍼스 #패스트캠퍼스AI부트캠프 #업스테이지패스트캠퍼스 #UpstageAILab#국비지원 #패스트캠퍼스업스테이지에이아이랩#패스트캠퍼스업스테이지부트캠프
- 파이썬
- Lora
- 오블완
- 손실함수
- clustering
- #패스트캠퍼스 #UpstageAILab #Upstage #부트캠프 #AI #데이터분석 #데이터사이언스 #무료교육 #국비지원 #국비지원취업 #데이터분석취업 등
- LIST
- Numpy
- #패스트캠퍼스 #패스트캠퍼스ai부트캠프 #업스테이지패스트캠퍼스 #upstageailab#국비지원 #패스트캠퍼스업스테이지에이아이랩#패스트캠퍼스업스테이지부트캠프
- 해시
- Github
- PEFT
- t5
- git
- cnn
- 리스트
- recursion #재귀 #자료구조 # 알고리즘
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |