Skip to content

Commit 3990952

Browse files
alykhantejanifmassa
authored andcommitted
Refactor transforms and add transforms/functional.py (#311)
* refactor transforms and add transforms/functional.py * add __all__ to transforms.py
1 parent f9df932 commit 3990952

File tree

5 files changed

+1124
-1134
lines changed

5 files changed

+1124
-1134
lines changed

test/test_transforms.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torchvision.transforms as transforms
3+
import torchvision.transforms.functional as F
34
import unittest
45
import random
56
import numpy as np
@@ -14,7 +15,6 @@
1415
except ImportError:
1516
stats = None
1617

17-
1818
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
1919

2020

@@ -347,7 +347,7 @@ def verify_img_data(img_data, expected_output, mode):
347347
assert img.mode == mode
348348
split = img.split()
349349
for i in range(3):
350-
assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy())
350+
assert np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy())
351351

352352
img_data = torch.Tensor(3, 4, 4).uniform_()
353353
expected_output = img_data.mul(255).int().float().div(255)
@@ -391,7 +391,7 @@ def verify_img_data(img_data, expected_output, mode):
391391

392392
split = img.split()
393393
for i in range(4):
394-
assert np.allclose(expected_output[i].numpy(), transforms.to_tensor(split[i]).numpy())
394+
assert np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy())
395395

396396
img_data = torch.Tensor(4, 4, 4).uniform_()
397397
expected_output = img_data.mul(255).int().float().div(255)
@@ -491,19 +491,19 @@ def test_adjust_brightness(self):
491491
x_pil = Image.fromarray(x_np, mode='RGB')
492492

493493
# test 0
494-
y_pil = transforms.adjust_brightness(x_pil, 1)
494+
y_pil = F.adjust_brightness(x_pil, 1)
495495
y_np = np.array(y_pil)
496496
assert np.allclose(y_np, x_np)
497497

498498
# test 1
499-
y_pil = transforms.adjust_brightness(x_pil, 0.5)
499+
y_pil = F.adjust_brightness(x_pil, 0.5)
500500
y_np = np.array(y_pil)
501501
y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
502502
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
503503
assert np.allclose(y_np, y_ans)
504504

505505
# test 2
506-
y_pil = transforms.adjust_brightness(x_pil, 2)
506+
y_pil = F.adjust_brightness(x_pil, 2)
507507
y_np = np.array(y_pil)
508508
y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
509509
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
@@ -516,19 +516,19 @@ def test_adjust_contrast(self):
516516
x_pil = Image.fromarray(x_np, mode='RGB')
517517

518518
# test 0
519-
y_pil = transforms.adjust_contrast(x_pil, 1)
519+
y_pil = F.adjust_contrast(x_pil, 1)
520520
y_np = np.array(y_pil)
521521
assert np.allclose(y_np, x_np)
522522

523523
# test 1
524-
y_pil = transforms.adjust_contrast(x_pil, 0.5)
524+
y_pil = F.adjust_contrast(x_pil, 0.5)
525525
y_np = np.array(y_pil)
526526
y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
527527
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
528528
assert np.allclose(y_np, y_ans)
529529

530530
# test 2
531-
y_pil = transforms.adjust_contrast(x_pil, 2)
531+
y_pil = F.adjust_contrast(x_pil, 2)
532532
y_np = np.array(y_pil)
533533
y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
534534
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
@@ -541,19 +541,19 @@ def test_adjust_saturation(self):
541541
x_pil = Image.fromarray(x_np, mode='RGB')
542542

543543
# test 0
544-
y_pil = transforms.adjust_saturation(x_pil, 1)
544+
y_pil = F.adjust_saturation(x_pil, 1)
545545
y_np = np.array(y_pil)
546546
assert np.allclose(y_np, x_np)
547547

548548
# test 1
549-
y_pil = transforms.adjust_saturation(x_pil, 0.5)
549+
y_pil = F.adjust_saturation(x_pil, 0.5)
550550
y_np = np.array(y_pil)
551551
y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
552552
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
553553
assert np.allclose(y_np, y_ans)
554554

555555
# test 2
556-
y_pil = transforms.adjust_saturation(x_pil, 2)
556+
y_pil = F.adjust_saturation(x_pil, 2)
557557
y_np = np.array(y_pil)
558558
y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
559559
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
@@ -566,26 +566,26 @@ def test_adjust_hue(self):
566566
x_pil = Image.fromarray(x_np, mode='RGB')
567567

568568
with self.assertRaises(ValueError):
569-
transforms.adjust_hue(x_pil, -0.7)
570-
transforms.adjust_hue(x_pil, 1)
569+
F.adjust_hue(x_pil, -0.7)
570+
F.adjust_hue(x_pil, 1)
571571

572572
# test 0: almost same as x_data but not exact.
573573
# probably because hsv <-> rgb floating point ops
574-
y_pil = transforms.adjust_hue(x_pil, 0)
574+
y_pil = F.adjust_hue(x_pil, 0)
575575
y_np = np.array(y_pil)
576576
y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
577577
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
578578
assert np.allclose(y_np, y_ans)
579579

580580
# test 1
581-
y_pil = transforms.adjust_hue(x_pil, 0.25)
581+
y_pil = F.adjust_hue(x_pil, 0.25)
582582
y_np = np.array(y_pil)
583583
y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
584584
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
585585
assert np.allclose(y_np, y_ans)
586586

587587
# test 2
588-
y_pil = transforms.adjust_hue(x_pil, -0.25)
588+
y_pil = F.adjust_hue(x_pil, -0.25)
589589
y_np = np.array(y_pil)
590590
y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
591591
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
@@ -598,19 +598,19 @@ def test_adjust_gamma(self):
598598
x_pil = Image.fromarray(x_np, mode='RGB')
599599

600600
# test 0
601-
y_pil = transforms.adjust_gamma(x_pil, 1)
601+
y_pil = F.adjust_gamma(x_pil, 1)
602602
y_np = np.array(y_pil)
603603
assert np.allclose(y_np, x_np)
604604

605605
# test 1
606-
y_pil = transforms.adjust_gamma(x_pil, 0.5)
606+
y_pil = F.adjust_gamma(x_pil, 0.5)
607607
y_np = np.array(y_pil)
608608
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
609609
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
610610
assert np.allclose(y_np, y_ans)
611611

612612
# test 2
613-
y_pil = transforms.adjust_gamma(x_pil, 2)
613+
y_pil = F.adjust_gamma(x_pil, 2)
614614
y_np = np.array(y_pil)
615615
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
616616
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
@@ -623,11 +623,11 @@ def test_adjusts_L_mode(self):
623623
x_rgb = Image.fromarray(x_np, mode='RGB')
624624

625625
x_l = x_rgb.convert('L')
626-
assert transforms.adjust_brightness(x_l, 2).mode == 'L'
627-
assert transforms.adjust_saturation(x_l, 2).mode == 'L'
628-
assert transforms.adjust_contrast(x_l, 2).mode == 'L'
629-
assert transforms.adjust_hue(x_l, 0.4).mode == 'L'
630-
assert transforms.adjust_gamma(x_l, 0.5).mode == 'L'
626+
assert F.adjust_brightness(x_l, 2).mode == 'L'
627+
assert F.adjust_saturation(x_l, 2).mode == 'L'
628+
assert F.adjust_contrast(x_l, 2).mode == 'L'
629+
assert F.adjust_hue(x_l, 0.4).mode == 'L'
630+
assert F.adjust_gamma(x_l, 0.5).mode == 'L'
631631

632632
def test_color_jitter(self):
633633
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

0 commit comments

Comments
 (0)