|  | import os | 
					
						
						|  | import numpy as np | 
					
						
						|  | from abc import abstractmethod | 
					
						
						|  | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Txt2ImgIterableBaseDataset(IterableDataset): | 
					
						
						|  | ''' | 
					
						
						|  | Define an interface to make the IterableDatasets for text2img data chainable | 
					
						
						|  | ''' | 
					
						
						|  | def __init__(self, num_records=0, valid_ids=None, size=256): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.num_records = num_records | 
					
						
						|  | self.valid_ids = valid_ids | 
					
						
						|  | self.sample_ids = valid_ids | 
					
						
						|  | self.size = size | 
					
						
						|  |  | 
					
						
						|  | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') | 
					
						
						|  |  | 
					
						
						|  | def __len__(self): | 
					
						
						|  | return self.num_records | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def __iter__(self): | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class PRNGMixin(object): | 
					
						
						|  | """ | 
					
						
						|  | Adds a prng property which is a numpy RandomState which gets | 
					
						
						|  | reinitialized whenever the pid changes to avoid synchronized sampling | 
					
						
						|  | behavior when used in conjunction with multiprocessing. | 
					
						
						|  | """ | 
					
						
						|  | @property | 
					
						
						|  | def prng(self): | 
					
						
						|  | currentpid = os.getpid() | 
					
						
						|  | if getattr(self, "_initpid", None) != currentpid: | 
					
						
						|  | self._initpid = currentpid | 
					
						
						|  | self._prng = np.random.RandomState() | 
					
						
						|  | return self._prng | 
					
						
						|  |  |