diff --git a/code/analyze.py b/code/analyze.py index c965dd7..13540e7 100644 --- a/code/analyze.py +++ b/code/analyze.py @@ -106,7 +106,7 @@ def latex_table_certified_accuracy(outfile: str, radius_start: float, radius_sto f.write("& $r = {:.3}$".format(radius)) f.write("\\\\\n") - f.write("\midrule\n") + f.write(r"\midrule\n") for i, method in enumerate(methods): f.write(method.legend) @@ -153,54 +153,54 @@ def markdown_table_certified_accuracy(outfile: str, radius_start: float, radius_ if __name__ == "__main__": latex_table_certified_accuracy( "analysis/latex/vary_noise_cifar10", 0.25, 1.5, 0.25, [ - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"), ]) markdown_table_certified_accuracy( "analysis/markdown/vary_noise_cifar10", 0.25, 1.5, 0.25, [ - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "σ = 0.12"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "σ = 0.25"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "σ = 0.50"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "σ = 1.00"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"σ = 0.12"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"σ = 0.25"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"σ = 0.50"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"σ = 1.00"), ]) latex_table_certified_accuracy( "analysis/latex/vary_noise_imagenet", 0.5, 3.0, 0.5, [ - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"), ]) markdown_table_certified_accuracy( "analysis/markdown/vary_noise_imagenet", 0.5, 3.0, 0.5, [ - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "σ = 0.25"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "σ = 0.50"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "σ = 1.00"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"σ = 0.25"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"σ = 0.50"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"σ = 1.00"), ]) plot_certified_accuracy( - "analysis/plots/vary_noise_cifar10", "CIFAR-10, vary $\sigma$", 1.5, [ - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), "$\sigma = 0.12$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), + "analysis/plots/vary_noise_cifar10", r"CIFAR-10, vary $\sigma$", 1.5, [ + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.12/test/sigma_0.12"), r"$\sigma = 0.12$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"), ]) plot_certified_accuracy( - "analysis/plots/vary_train_noise_cifar_050", "CIFAR-10, vary train noise, $\sigma=0.5$", 1.5, [ - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"), + "analysis/plots/vary_train_noise_cifar_050", r"CIFAR-10, vary train noise, $\sigma=0.5$", 1.5, [ + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/cifar10/resnet110/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"), ]) plot_certified_accuracy( - "analysis/plots/vary_train_noise_imagenet_050", "ImageNet, vary train noise, $\sigma=0.5$", 1.5, [ - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), "train $\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "train $\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), "train $\sigma = 1.00$"), + "analysis/plots/vary_train_noise_imagenet_050", r"ImageNet, vary train noise, $\sigma=0.5$", 1.5, [ + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.50"), r"train $\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"train $\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_0.50"), r"train $\sigma = 1.00$"), ]) plot_certified_accuracy( - "analysis/plots/vary_noise_imagenet", "ImageNet, vary $\sigma$", 4, [ - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), "$\sigma = 0.25$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), "$\sigma = 0.50$"), - Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), "$\sigma = 1.00$"), + "analysis/plots/vary_noise_imagenet", r"ImageNet, vary $\sigma$", 4, [ + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.25/test/sigma_0.25"), r"$\sigma = 0.25$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_0.50/test/sigma_0.50"), r"$\sigma = 0.50$"), + Line(ApproximateAccuracy("data/certify/imagenet/resnet50/noise_1.00/test/sigma_1.00"), r"$\sigma = 1.00$"), ]) plot_certified_accuracy( "analysis/plots/high_prob", "Approximate vs. High-Probability", 2.0, [ diff --git a/code/datasets.py b/code/datasets.py index dc34462..19c42c5 100644 --- a/code/datasets.py +++ b/code/datasets.py @@ -70,7 +70,7 @@ def _imagenet(split: str) -> Dataset: elif split == "test": subdir = os.path.join(dir, "val") transform = transforms.Compose([ - transforms.Scale(256), + transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor() ]) diff --git a/code/train_utils.py b/code/train_utils.py index 123ae8c..d8cd859 100644 --- a/code/train_utils.py +++ b/code/train_utils.py @@ -30,7 +30,7 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5ec3925 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,67 @@ +ansicon==1.89.0 +asttokens==3.0.0 +blessed==1.21.0 +colorama==0.4.6 +comm==0.2.2 +contourpy==1.3.3 +cycler==0.12.1 +debugpy==1.8.14 +decorator==5.2.1 +executing==2.2.0 +filelock==3.13.1 +fonttools==4.59.1 +fsspec==2024.6.1 +gpustat==1.1.1 +ipykernel==6.29.5 +ipython==9.2.0 +ipython_pygments_lexers==1.1.1 +jedi==0.19.2 +Jinja2==3.1.4 +jinxed==1.3.0 +joblib==1.5.1 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +kiwisolver==1.4.9 +MarkupSafe==2.1.5 +matplotlib==3.10.5 +matplotlib-inline==0.1.7 +mpmath==1.3.0 +nest-asyncio==1.6.0 +networkx==3.3 +numpy==1.26.4 +nvidia-ml-py==13.580.65 +opencv-python==4.11.0.86 +packaging==25.0 +pandas==2.3.2 +parso==0.8.4 +patsy==1.0.1 +pillow==11.0.0 +platformdirs==4.3.8 +prompt_toolkit==3.0.51 +psutil==7.0.0 +pure_eval==0.2.3 +Pygments==2.19.1 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytz==2025.2 +pywin32==310 +pyzmq==26.4.0 +scikit-learn==1.7.0 +scipy==1.16.0 +seaborn==0.13.2 +setGPU==0.0.7 +six==1.17.0 +skorch==1.1.0 +stack-data==0.6.3 +statsmodels==0.14.5 +sympy==1.13.3 +tabulate==0.9.0 +threadpoolctl==3.6.0 +torch==2.8.0+cu126 +torchvision==0.23.0+cu126 +tornado==6.4.2 +tqdm==4.67.1 +traitlets==5.14.3 +typing_extensions==4.13.2 +tzdata==2025.2 +wcwidth==0.2.13