Table of Contents

DeepLearning

torch 모델 구조 graph로 그리기

꼬꼬마코더 2024. 7. 26. 00:59
728x90

1. vscode에서 Ctrl + ~ : bash창에서 graphviz 설치

apt-get install graphviz




2. 다음 코드 실행

import torch
import torchvision.models as models
import torchviz

# 모델과 입력 데이터를 같은 디바이스로 이동
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18().to(device)
x = torch.randn(1, 3, 224, 224).to(device).requires_grad_(True)

# 모델 실행
y = model(x)

# 계산 그래프 시각화
dot = torchviz.make_dot(y, params=dict(list(model.named_parameters()) + [('input', x)]))
dot.render("model_graph", format='png')  # 그래프 이미지 파일로 저장




3. 그림과 같이 png 파일 생성됨