2023. 2. 18. 00:15ㆍ🧡 Programming/💻 Python
모델을 만들면 항상 커스텀 데이터셋을 만들게 된다.
왜냐면 나는 이미지 폴더로 데이터로드 하는 것을 싫어한다.
아래와 같이 사용하는 것인데 이미지 폴더를 바로 가져와서 데이터셋을 구성하는 편리함이 있긴 하지만
데이터셋 내에서 데이터를 다양하게 조합하고 싶을 때마다 데이터를 매번 새로운 폴더로 만들 수는 없다.
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
dataset = ImageFolder(root='folder_path',
transform=transforms.Compose([
transforms.ToTensor(),
]))
data_loader = DataLoader(dataset,
batch_size=128,
shuffle=True,
num_workers=8
)
그래서 사실 나는 데이터를 텍스트 파일이든 뭐든 파일 형태로 보관하는 것을 좋아한다.
그러면 학습 버전에 따라서 데이터 리스트를 관리할 수 있고 학습 데이터도 바로 파일에서 꺼내와서 사용하면 되기 때문이다.
꼭 텍스트파일로 구성하지 않더라도 데이터 루트 폴더내에서 데이터셋을 어떤 조건에 맞게 가져오면서 데이터로더를 구성 할 수도 있다.
이렇게 이용하려면 커스텀 데이터셋을 구성해야 한다 .
https://tutorials.pytorch.kr/beginner/data_loading_tutorial.html
파이토치 공식 사이트에서도 커스텀 데이터셋과 데이터로더를 구성하는 예제를 제공하고 있다.
💡 Custom Dataset 작성하기
class CustomDataset(torch.utils.data.Dataset):
def __init__(self): #데이터셋의 전처리
def __len__(self): # 데이터셋 길이, 총 샘플의 수를 적어주는 부분
def __getitem__(self, idx): # 데이터셋에서 특정 1개의 샘플을 가져오는 함수
여기서 torch.utils.data.Dataset은 파이토치에서 데이터셋을 제공하는 추상 클래스이다. Dataset을 상속받아서 3개 메소드들을 오버라이드하여 커스텀 데이터셋을 작성하면 된다.
__init__(self):
데이터를 정의할 때 필요한 변수들을 불러오면 된다. 보통 이미지 폴더와 transform 정도를 불러오고 그 외에는 본인의 필요 용도에 맞게 커스텀 하면 된다.
__getitem__(self,index):
전체 데이터 중에서 idx번째 데이터를 가져오는 함수이다.
__len__(self):
사용하는 데이터들 중 아무거나 길이를 반환하면 된다. 이미지의 길이든 라벨의 길이든 상관없다.
Example
커스텀 데이터셋을 이용한 예시를 보자
데이터셋이 각각 train/val 폴더로 나누어져 있고 각각의 폴더에 json 형식으로 이미지 이름당 라벨이 저장되어 있다고 가정하면 다음과 같이 코드를 사용 할 수 있다.
꼭 아래와 같은 형식을 사용 할 필요는 없다. 각자의 모델과 학습 방식에 따라서 자유롭게 구성하면 된다.
예를 들어서 내가 train 폴더 내에서 이미지 전체를 가져와서 그 중 원본 이미지 사이즈가 100x100 이하인 데이터는 제외시키거나, 랜덤으로 이미지의 10프로는 학습에서 제외한다거나 다양한 방법으로 데이터셋을 구성 할 수 있다.
import os
import json
from glob import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self,root_dir,mode,transform=None):
"""
mode : train/val을 구분하기위해 편의상 만든 변수, 필요없다면 삭제 가능
"""
self.root_dir = root_dir
self.transform = transform
self.imglist,self.labellist = [],[]
self.mode = mode
"""
root_dir + 'train'/'val' 형태로 폴더가 저장되어 있을 경우 사용 가능
label 이 json 파일에 저장되어 있는 경우
"""
self.data_path = os.path.join(self.root_dir,self.mode)
self.caption_path = os.path.join(self.root_dir,self.mode+'.json')
self.imglist = glob(self.data_path + '/*')
with open(self.caption_path,'r') as file:
data = json.load(file)
for imgname in self.imglist:
self.labellist.append(data[imgname])
assert len(self.imglist) == len(self.labellist)
def __len__(self):
return len(self.imglist)
def __getitem__(self,idx):
image = Image.open(self.imglist[idx])
if self.transform:
image = self.transform(image)
label = self.labellist[idx]
return image,label
'🧡 Programming > 💻 Python' 카테고리의 다른 글
[OpenCV] 이미지 ROI 마스크 적용 방법들, 컬러 마스크 중복 적용 하기 (x연산, bitwise, copyTo, addWeighted) (0) | 2023.02.18 |
---|---|
[Matplotlib] 이미지 여러 장 plot 할 때 그리드,칸 나누기 & 그리드,칸 합치기(subplot,GridSpec,add_subplot) (0) | 2023.02.10 |