알쓸신잡

학습 시 두 개 이상의 데이터셋에서 batch 뽑아내는 법

재바기 2023. 4. 8. 18:52
728x90

종종 model의 input으로 두 개의 데이터가 들어갈 때가 있다.

따라서, dataloader도 각각 따로 필요할 수가 있고, 그로 인해 enumerate 함수의 인자를 어떻게 전달해야 할 지 헷갈릴 때가 있다.

 

그럴 때는 다음과 같이 enumerate안에 zip으로 두 dataloader를 묶어서 사용해보자.

    model.train()
    for epoch in range(num_epoch):
        print('EPOCH {}:'.format(epoch + 1))
        training_loss = 0.0
        for i, data in enumerate(zip(train_dataloader1, train_dataloader2)):
            # get the inputs; data is a list of [inputs, labels]
            data1, data2 = data
            input1, labels1 = data1
            input2, labels2 = data2

이렇게 enumerate안에서 두 dataloader를 zip하면 enumerate가 각각의 dataloader에서 batch를 잘 stack해서 전달해준다.

 

다만, 주의할 점은 이 때 enumerate의 리턴값으로 나오는 출력값이 조금 차이가 있다.

i 와 같이 index의 경우는 그대로 0부터 시작하지만, data는 두 dataloader에서 나오는 data가 합쳐져 있다.

따라서 data를 data1과 data2로 나눠주고, 이를 다시 input과 label로 나눠주면 된다.

728x90