알쓸신잡
학습 시 두 개 이상의 데이터셋에서 batch 뽑아내는 법
재바기
2023. 4. 8. 18:52
728x90
종종 model의 input으로 두 개의 데이터가 들어갈 때가 있다.
따라서, dataloader도 각각 따로 필요할 수가 있고, 그로 인해 enumerate 함수의 인자를 어떻게 전달해야 할 지 헷갈릴 때가 있다.
그럴 때는 다음과 같이 enumerate안에 zip으로 두 dataloader를 묶어서 사용해보자.
model.train()
for epoch in range(num_epoch):
print('EPOCH {}:'.format(epoch + 1))
training_loss = 0.0
for i, data in enumerate(zip(train_dataloader1, train_dataloader2)):
# get the inputs; data is a list of [inputs, labels]
data1, data2 = data
input1, labels1 = data1
input2, labels2 = data2
이렇게 enumerate안에서 두 dataloader를 zip하면 enumerate가 각각의 dataloader에서 batch를 잘 stack해서 전달해준다.
다만, 주의할 점은 이 때 enumerate의 리턴값으로 나오는 출력값이 조금 차이가 있다.
i 와 같이 index의 경우는 그대로 0부터 시작하지만, data는 두 dataloader에서 나오는 data가 합쳐져 있다.
따라서 data를 data1과 data2로 나눠주고, 이를 다시 input과 label로 나눠주면 된다.
728x90