본문 바로가기

Machine Learning/Deep Running

Pytorch 1.7 + colab TPU 병렬처리 사용법

colab 사용시 요약 BERT 돌리는데 자꾸 oom 문제 (out of memory)가 발생하여 Gpu 버전에서 TPU버전으로 바꿀때 바꿀것들

Pytorch는 기본적으로 GPU로 돌아가지만 xla라는 라이브러리를 이용하면 TPU를 사용가능하다.

별로 친절한 설명은 아니지만 참고바랍니다.

라이브러리 추가 :

import torch

import torch_xla

import torch_xla.utils.utils as xu

import torch_xla.core.xla_model as xm

import torch_xla.utils.serialization as xser

import torch_xla.distributed.parallel_loader as pl

import torch_xla.distributed.xla_multiprocessing as xmp

플래그 호출 :

flags = {}

flags['batch_size'] = 256

flags['num_workers'] = 4

flags['num_epochs'] = 100

flags['seed'] = 1234

flags['num_cores'] = 8

메소드 호출 시 :

real_method = xmp.spawn(method, args=(flags,), nprocs=8, start_method = 'fork') 

메소드 선언시 :

def method(self,index,flags) :

기존에 init에서 선언한 변수가 있기때문에 self도 괄호안에 선언

index는 코어 number

디바이스 선언 :

device = xm.xla_device()

멀티프로세싱 로더 :

para_loader = pl.ParallelLoader(dataloader, [device])  -

dataloader는 data_loader 메소드를 이용한다.

dataloader = data_loader(self.args, self.tensors, batch_size, flags)
iter_bar = tqdm(

para_loader.per_device_loader(device),

total = len(para_loader.per_device_loader(device)),

desc = ' lter (Loss:X.XXX LR:X.XXX)'

)

모델 저장 시 :

xm. save (self.model.module.state_dict(), epoch_file)