This repository implements training and inference for a GAN built only from fully connected (fc) layers, trained on MNIST.
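For reference, the original GAN formulation (Goodfellow et al., 2014, cited at the end of this README) trains two networks against each other. The generator G maps a noise vector z to an image. The discriminator D outputs the probability that its input is a real image. They are trained on the minimax objective below. In this repository both G and D are stacks of fully connected layers. This README does not say whether training uses exactly this loss or the common non-saturating variant for G (maximizing log D(G(z))), so treat the formula as background rather than a description of the code.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```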
To set up the MNIST dataset, follow the data preparation instructions at https://github.com/explainingai-code/Pytorch-VAE#data-preparation
The directory structure should look like this:

    $REPO_ROOT
        -> data
            -> train
                -> images
                    -> 0
                        *.png
                    -> 1
                    ...
                    -> 9
                        *.png
            -> test
                -> images
                    -> 0
                        *.png
                    ...
        -> dataset
        -> tools

- Create a new conda environment with Python 3.8, then run the commands below.
- git clone https://github.com/explainingai-code/GANs-Pytorch.git
- cd GANs-Pytorch
- pip install -r requirements.txt
- python -m tools.train_gan to train the GAN and save inference samples
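Before running the training command, it can help to confirm that the MNIST images ended up in the layout shown above. The following is a minimal sketch and is not part of the repository. check_mnist_layout is a hypothetical helper that assumes one folder per digit (0 to 9) under data/train/images and data/test/images, each containing *.png files.

```python
from pathlib import Path

# Hypothetical helper (not part of the repository): verifies the layout
# $REPO_ROOT/data/{train,test}/images/<digit>/*.png described above.
def check_mnist_layout(repo_root="."):
    """Return {split: {digit: number_of_pngs}}.

    Raises FileNotFoundError for a missing digit folder and ValueError for a
    digit folder that contains no *.png files.
    """
    data_dir = Path(repo_root) / "data"
    counts = {}
    for split in ("train", "test"):
        per_digit = {}
        for digit in range(10):
            folder = data_dir / split / "images" / str(digit)
            if not folder.is_dir():
                raise FileNotFoundError(f"missing folder: {folder}")
            n_png = sum(1 for _ in folder.glob("*.png"))
            if n_png == 0:
                raise ValueError(f"no *.png files in {folder}")
            per_digit[digit] = n_png
        counts[split] = per_digit
    return counts
```

If you call check_mnist_layout(".") from $REPO_ROOT, it returns per-digit image counts for each split. Otherwise it raises an error that names the first missing or empty folder.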
To train on colored (3-channel) MNIST images:
- Ensure the dataset is prepared according to the data preparation instructions above.
- Change the IM_CHANNELS field to 3 in train_gan.py.
- Uncomment lines 56-59 in the dataset/mnist_dataset.py file.
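Because the generator and discriminator use only fully connected layers, each image has to be handled as one flat vector. The length of that vector depends directly on the channel count. The sketch below is plain Python with hypothetical helper names and assumes MNIST's 28x28 resolution. It illustrates why IM_CHANNELS has to match the data: a 1-channel digit flattens to 784 values, and a 3-channel one flattens to 2352.

```python
# Hypothetical helpers illustrating why IM_CHANNELS matters for an fc-only GAN:
# a fully connected layer sees each image as one flat vector of C*H*W values.

def fc_input_dim(channels, height=28, width=28):
    """Length of the flattened image vector fed to the first fc layer."""
    return channels * height * width

def flatten_chw(image):
    """Flatten a nested [C][H][W] list in channel-major (C, H, W) order."""
    return [v for channel in image for row in channel for v in row]

def unflatten_chw(vector, channels, height, width):
    """Inverse of flatten_chw: rebuild a [C][H][W] nested list."""
    if len(vector) != channels * height * width:
        raise ValueError("vector length does not match channels*height*width")
    plane = height * width
    return [
        [vector[c * plane + r * width:c * plane + (r + 1) * width] for r in range(height)]
        for c in range(channels)
    ]

print(fc_input_dim(1))  # grayscale MNIST -> 784
print(fc_input_dim(3))  # colored MNIST   -> 2352
```

Suppose IM_CHANNELS stayed at 1 while the dataset code produced 3-channel images. The flattened vectors would then no longer match the fc layer sizes, which is presumably why the config field and the dataset code have to change together.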
To train on your own images:
- Place all *.png files (or images in whatever format you have) in data/train/images.
- Comment out line 43 of dataset/mnist_dataset.py: https://github.com/explainingai-code/GANs-Pytorch/blob/main/dataset/mnist_dataset.py#L43
- The directory structure should be:

    data
        -> train
            -> images
                *.png

- Change the IM_PATH field to data/train in train_gan.py.
- Adjust the channel count and image size settings to match your images.
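To set the channels and image size, you need to know what your images contain. For PNG files, both values can be read from the fixed-layout IHDR header without decoding any pixels. The sketch below uses only the Python standard library. png_info and summarize_pngs are hypothetical helpers. The channel counts are the ones stored in the file, and image loaders usually expand palette images to RGB. Images in other formats are not handled.

```python
import struct
from pathlib import Path

# Channels stored per PNG color type (IHDR byte 25): gray, RGB, palette index,
# gray+alpha, RGBA. Palette images are usually expanded to RGB by loaders.
PNG_CHANNELS = {0: 1, 2: 3, 3: 1, 4: 2, 6: 4}
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

def png_info(path):
    """Read (width, height, channels) from a PNG's IHDR chunk without decoding pixels."""
    with open(path, "rb") as f:
        header = f.read(26)  # 8-byte signature + chunk length/type + IHDR fields up to color type
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a valid PNG")
    width, height = struct.unpack(">II", header[16:24])  # big-endian uint32s
    color_type = header[25]
    return width, height, PNG_CHANNELS[color_type]

def summarize_pngs(folder):
    """Count how many PNGs in `folder` share each (width, height, channels) combination."""
    summary = {}
    for path in sorted(Path(folder).glob("*.png")):
        key = png_info(path)
        summary[key] = summary.get(key, 0) + 1
    return summary
```

summarize_pngs("data/train/images") returns a dict that maps each (width, height, channels) combination to its file count. More than one key suggests that the images need resizing or converting to a common format before the channel and size settings can be chosen.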
During training of the GAN, the following output is saved:
- The latest model checkpoints for the generator and discriminator, in the $REPO_ROOT directory

Every 50 steps, inference is run and the following output is saved in the samples directory:
- A grid of sampled images, as samples/*.png
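The sampled image grid tiles a batch of generated images into a single picture. This README does not say how the grid is built. torchvision.utils.make_grid is a common choice in PyTorch code. The pure-Python sketch below is a hypothetical make_grid for single-channel images that follows a similar layout. Images are placed row by row, nrow per grid row, with a border of padding pixels around each tile.

```python
# Hypothetical sketch of a sample grid, similar in layout to
# torchvision.utils.make_grid but for single-channel images stored as
# nested Python lists (rows of pixel values).

def make_grid(images, nrow=8, padding=2, pad_value=0):
    """Tile equally sized 2-D images into one 2-D grid.

    Images are placed left to right, `nrow` per grid row, with a border of
    `padding` pixels filled with `pad_value` around every tile.
    """
    if not images:
        raise ValueError("need at least one image")
    h, w = len(images[0]), len(images[0][0])
    grid_cols = min(nrow, len(images))
    grid_rows = -(-len(images) // grid_cols)  # ceiling division
    grid_h = grid_rows * (h + padding) + padding
    grid_w = grid_cols * (w + padding) + padding
    grid = [[pad_value] * grid_w for _ in range(grid_h)]
    for idx, img in enumerate(images):
        r, c = divmod(idx, grid_cols)
        top = padding + r * (h + padding)
        left = padding + c * (w + padding)
        for y in range(h):
            grid[top + y][left:left + w] = img[y]  # copy one image row into the grid
    return grid
```

For example, 64 MNIST samples of 28x28 pixels with nrow=8 and padding=2 give a 242x242 grid (8 x (28 + 2) + 2 pixels per side).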
@misc{goodfellow2014generative,
    title={Generative Adversarial Networks},
    author={Ian J. Goodfellow and Jean Pouget-Abadie and Mehdi Mirza and Bing Xu and David Warde-Farley and Sherjil Ozair and Aaron Courville and Yoshua Bengio},
    year={2014},
    eprint={1406.2661},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}