Releases: PrasannaPulakurthi/MMD-AdversarialNAS-GAN

Initial public release

@PrasannaPulakurthi PrasannaPulakurthi released this 03 Sep 17:51
5b6bf82

MMD-AdversarialNAS-GAN is a pipeline that improves GANs via MMD-guided Adversarial Neural Architecture Search (NAS) and tensor-decomposition-based compression for efficient deployment. This release corresponds to the ICASSP 2024 codebase and includes search, training, evaluation, and compression scripts.


Highlights

  • Adversarial NAS (MMD-guided): Search generator/discriminator designs using an MMD objective.
  • Model compression: Tensor decomposition + fine-tuning to cut parameters/compute with minimal quality loss.
  • Full pipeline included: search → train → evaluate → compress with ready-to-run scripts and example configs.
  • Reproducibility: Deterministic seeds, documented checkpoints, and FID statistic files for consistent evaluation.
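
For readers unfamiliar with the MMD objective mentioned above, the following is a minimal sketch of a squared Maximum Mean Discrepancy estimate between two sample sets. It is not the repository's loss: the Gaussian kernel, the bandwidth sigma, the biased estimator, and the function names are all assumptions chosen for illustration.

```python
import math


def gaussian_kernel(a, b, sigma=1.0):
    """Gaussian (RBF) kernel between two equal-length vectors."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    """Average kernel value over all pairs (x, y)."""
    total = sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd2_biased(real, fake, sigma=1.0):
    """Biased estimate of MMD^2: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    return (mean_kernel(real, real, sigma)
            + mean_kernel(fake, fake, sigma)
            - 2.0 * mean_kernel(real, fake, sigma))


# Toy 2-D "features" (illustrative values, not real data).
real = [[0.0, 0.0], [1.0, 0.0]]
fake = [[3.0, 3.0], [4.0, 3.0]]
same = mmd2_biased(real, real)   # identical sets: estimate is zero
diff = mmd2_biased(real, fake)   # separated sets: estimate is positive
```

The estimate shrinks toward zero as the generated samples become statistically indistinguishable from the real ones under the chosen kernel, which is what makes it usable as a search or training signal.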

What’s included

  • Core training/eval/search code: MGPU_search_arch.py, MGPU_train_arch.py, MGPU_test_arch.py.
  • Compression utilities: decompose.py, MGPU_cpcompress_arch.py, MGPU_test_cpcompress.py, MGPU_layersinfo_arch.py.
  • Experiment scaffolding & genotypes: exps/arch_cifar10/Genotypes/…
  • Datasets/utilities: datasets.py, celeba.py, utils/, archs/, architect.py, network.py, cfg*.py, scripts/.
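
To give a feel for why the decomposition utilities listed above can reduce parameters, here is a back-of-the-envelope sketch. It assumes a rank-R CP (CANDECOMP/PARAFAC) factorization of a k x k convolution into a 1x1, a k x 1, a 1 x k, and a 1x1 stage; the repository's actual factorization, ranks, and layer sizes may differ, and the numbers below are made up for illustration.

```python
def dense_conv_params(c_in, c_out, k):
    """Weights in a standard k x k convolution (bias ignored)."""
    return c_in * c_out * k * k


def cp_conv_params(c_in, c_out, k, rank):
    """Weights after an assumed rank-R CP factorization:
    1x1 (c_in -> R), k x 1 per-channel, 1 x k per-channel, 1x1 (R -> c_out)."""
    return rank * (c_in + k + k + c_out)


# Hypothetical layer: 256 -> 256 channels, 3x3 kernel, rank 128.
dense = dense_conv_params(256, 256, 3)      # 589824
compressed = cp_conv_params(256, 256, 3, 128)  # 66304
ratio = dense / compressed
```

Lower ranks give larger savings but usually cost more quality, which is why the pipeline pairs decomposition with fine-tuning.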

Getting started (summary)

  1. Environment
    • Python 3.9 is recommended.
    • pip install -r requirements.txt
    • Example (CUDA 11.6) PyTorch install:
      pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
  2. Artifacts
    • Download pretrained models to ./exps/...
    • Download FID statistics to ./fid_stat/...
  3. Run
    • Search (example): scripts/search_arch_cifar10.sh → calls MGPU_search_arch.py
    • Train (example): scripts/train_arch_cifar10_large.sh → calls MGPU_train_arch.py
    • Evaluate (example): scripts/test_arch_cifar10.sh → calls MGPU_test_arch.py
    • Compress (examples in README) → calls MGPU_cpcompress_arch.py and MGPU_test_cpcompress.py

Tip: Ensure folder names match the paths expected by the scripts (e.g., checkpoints under exps/<exp_name>/Model).


Compatibility

  • Tested with Python 3.9 and PyTorch 1.13.x (CUDA builds supported).
  • Example dataset loaders provided for CIFAR-10 and CelebA.
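
To confirm that an environment matches the tested versions above, a check like the following may help. It is a sketch with hypothetical helper names; it compares only the major and minor components and ignores local build tags such as "+cu116".

```python
import sys


def matches_series(version, major, minor):
    """True if a version string such as '1.13.1+cu116' is in the major.minor series."""
    core = version.split("+", 1)[0]  # drop the local build tag, if any
    parts = core.split(".")
    return len(parts) >= 2 and parts[0] == str(major) and parts[1] == str(minor)


def python_is(major, minor):
    """True if the running interpreter is exactly major.minor."""
    return sys.version_info[:2] == (major, minor)


if __name__ == "__main__":
    print("Python 3.9:", python_is(3, 9))
    try:
        import torch  # only needed for the PyTorch check
        print("PyTorch 1.13.x:", matches_series(torch.__version__, 1, 13))
    except ImportError:
        print("PyTorch is not installed.")
```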

Notes & limitations

  • Some steps rely on external downloads (pretrained models, FID stats). Keep the directory structure intact.
  • Multi-GPU scripts are provided; adjust batch sizes/workers for your hardware.

Cite this work

If this code helps your research, please cite the ICASSP 2024 paper and the extended IEEE Access 2024 article. (BibTeX and links are provided in the repository README.)


Changelog

  • First official release of the ICASSP-aligned code:
    • Added NAS search, training, testing, and compression flows.
    • Included dataset loaders, experiment scaffolding, and helper scripts.
    • Documented environment setup and artifact download locations.