Table of Contents
728x90
테스트 시간 증강(Test-Time Augmentation, TTA)은 머신 러닝 모델의 강건성과 성능을 개선하기 위해 사용하는 기법으로, 특히 컴퓨터 비전 작업에서 유용합니다. TTA는 테스트 단계에서 테스트 이미지에 다양한 증강을 적용하여 각 증강된 버전에 대해 예측을 수행하고, 이러한 예측을 결합(예: 평균화)하여 최종 출력을 얻는 방식입니다. 이 방법은 데이터의 노이즈와 변동성을 완화하여 더 신뢰할 수 있고 정확한 예측을 가능하게 합니다.
다음은 TTA 프로세스에 대한 자세한 설명입니다:
1. TTA 변환 정의
- 각 테스트 이미지에 적용할 증강 세트를 정의합니다.
- 이러한 변환에는 좌우 반전, 상하 반전, 회전, 밝기 조정, 스케일링 등이 포함될 수 있습니다.
2. TTA 변환 적용
- 각 테스트 이미지에 대해 정의된 변환을 사용하여 여러 증강된 버전을 생성합니다.
- 일반적으로 각 변환을 한 번씩 적용하여 원본 이미지의 다양한 변형 세트를 만듭니다.
3. 증강된 이미지에 대한 예측 수행
- 증강된 각 이미지에 대해 모델을 사용하여 예측을 수행합니다.
- 각 증강된 이미지 버전에 대해 모델의 출력을 얻습니다.
4. 예측 결합
- 각 증강된 이미지에 대한 예측을 결합하여 최종 예측을 만듭니다.
- 결합 방법은 평균, 가중 평균 또는 다수결 투표 방식이 될 수 있습니다.
5. 최종 예측 생성
- 결합된 예측을 사용하여 최종 출력을 생성합니다.
아래는 TTA를 적용한 코드의 예입니다:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, HorizontalFlip, VerticalFlip, Rotate, LongestMaxSize, PadIfNeeded, Resize
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import pandas as pd
import os
from PIL import Image
# Albumentations transform 정의
pre_img_size = 384 # 사전 정의된 이미지 크기 (필요에 따라 수정)
img_size = 256
# TTA 변환 정의
tta_transforms = [
Compose([HorizontalFlip(p=1.0), ToTensorV2()]),
Compose([VerticalFlip(p=1.0), ToTensorV2()]),
Compose([Rotate(limit=45, p=1.0), ToTensorV2()]), # 각도 0도에서 45도 사이에서 무작위 회전
Compose([LongestMaxSize(max_size=pre_img_size, always_apply=True), PadIfNeeded(min_height=pre_img_size, min_width=pre_img_size, border_mode=0, value=(255, 255, 255)), ToTensorV2()]),
Compose([Resize(height=img_size, width=img_size), ToTensorV2()]),
Compose([ToTensorV2()]) # 원본 이미지
]
class ImageDataset2(Dataset):
def __init__(self, csv_file, path, transform=None):
self.df = pd.read_csv(csv_file)
self.path = path
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
name = self.df.iloc[idx, 0]
target = self.df.iloc[idx, 1]
img = np.array(Image.open(os.path.join(self.path, name)).convert("RGB"))
if self.transform:
img = self.transform(image=img)['image']
return img, target
tta_dataset = ImageDataset2(
"/data/ephemeral/home/datasets_fin/sample_submission.csv",
"/data/ephemeral/home/datasets_fin/test/",
transform=Compose([ToTensorV2()])
)
tta_loader = DataLoader(
tta_dataset,
batch_size=32,
shuffle=False,
num_workers=0,
pin_memory=True
)
def tta_inference(loader, model, device, tta_transforms):
model.eval()
all_outputs = []
for images, _ in tqdm(loader):
images = images.to(device).float()
batch_outputs = torch.zeros(images.size(0), 17).to(device) # 17은 클래스 수
for tta_transform in tta_transforms:
# 각 TTA 변형 적용
tta_images = torch.stack([torch.tensor(tta_transform(image=image.permute(1, 2, 0).cpu().numpy())['image']).to(device) for image in images])
with torch.no_grad():
preds = model(tta_images)
batch_outputs += preds
# TTA 평균 내기 (Soft Voting)
batch_outputs /= len(tta_transforms)
# 최종 예측 값을 리스트에 저장
all_outputs.append(batch_outputs.cpu().numpy())
# 모든 배치의 예측 값을 연결
all_outputs = np.concatenate(all_outputs, axis=0)
return all_outputs
# TTA를 적용한 예측
all_outputs = tta_inference(tta_loader, model, device, tta_transforms)
preds_list = np.argmax(all_outputs, axis=1)
# 예측 결과를 데이터프레임으로 저장
pred_df = pd.DataFrame(tta_dataset.df, columns=['ID', 'target'])
pred_df['target'] = preds_list
# 제출 형식 파일을 읽어와 ID 열이 일치하는지 확인
sample_submission_df = pd.read_csv("/data/ephemeral/home/datasets_fin/sample_submission.csv")
assert (sample_submission_df['ID'] == pred_df['ID']).all()
# 예측 결과를 CSV 파일로 저장
pred_df.to_csv(f"{model_name}_{img_size}SIZE_{BATCH_SIZE}BATCH_{EPOCHS}EPOCH_{augment_ratio}_TTA_pred.csv", index=False)
print(pred_df.head())
이 코드는 TTA를 적용하여 모델의 예측 성능을 향상시키는 과정을 보여줍니다. tta_inference
함수는 각 이미지에 대해 여러 증강을 적용하고, 각각의 증강된 이미지에 대해 예측을 수행한 후 평균을 내어 최종 예측을 생성합니다.
'DeepLearning' 카테고리의 다른 글
GPU 메모리memory확인 및 process kill (0) | 2024.08.10 |
---|---|
Global Average Pooling (GAP), Adaptive Average Pooling (0) | 2024.08.09 |
Normalize (0) | 2024.08.07 |
이미지 크기 (0) | 2024.08.07 |
submission 점수가 낮게 나오는 이유 (0) | 2024.08.07 |
250x250
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- recursion #재귀 #자료구조 # 알고리즘
- #패스트캠퍼스 #패스트캠퍼스ai부트캠프 #업스테이지패스트캠퍼스 #upstageailab#국비지원 #패스트캠퍼스업스테이지에이아이랩#패스트캠퍼스업스테이지부트캠프
- Python
- 손실함수
- #패스트캠퍼스 #패스트캠퍼스AI부트캠프 #업스테이지패스트캠퍼스 #UpstageAILab#국비지원 #패스트캠퍼스업스테이지에이아이랩#패스트캠퍼스업스테이지부트캠프
- classification
- PEFT
- #패스트캠퍼스 #UpstageAILab #Upstage #부트캠프 #AI #데이터분석 #데이터사이언스 #무료교육 #국비지원 #국비지원취업 #데이터분석취업 등
- LIST
- git
- 리스트
- 티스토리챌린지
- cnn
- Array
- LLM
- Hugging Face
- Transformer
- nlp
- 오블완
- 코딩테스트
- Numpy
- speaking
- Github
- clustering
- t5
- 파이썬
- English
- Lora
- RAG
- 해시
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
글 보관함