@@ -397,21 +397,22 @@ def test_inpaint_dpm(self):
397
397
398
398
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests (unittest .TestCase ):
399
399
def test_pil_inputs (self ):
400
- im = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
400
+ height , width = 32 , 32
401
+ im = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
401
402
im = Image .fromarray (im )
402
- mask = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
403
+ mask = np .random .randint (0 , 255 , (height , width ), dtype = np .uint8 ) > 127.5
403
404
mask = Image .fromarray ((mask * 255 ).astype (np .uint8 ))
404
405
405
- t_mask , t_masked = prepare_mask_and_masked_image (im , mask )
406
+ t_mask , t_masked = prepare_mask_and_masked_image (im , mask , height , width )
406
407
407
408
self .assertTrue (isinstance (t_mask , torch .Tensor ))
408
409
self .assertTrue (isinstance (t_masked , torch .Tensor ))
409
410
410
411
self .assertEqual (t_mask .ndim , 4 )
411
412
self .assertEqual (t_masked .ndim , 4 )
412
413
413
- self .assertEqual (t_mask .shape , (1 , 1 , 32 , 32 ))
414
- self .assertEqual (t_masked .shape , (1 , 3 , 32 , 32 ))
414
+ self .assertEqual (t_mask .shape , (1 , 1 , height , width ))
415
+ self .assertEqual (t_masked .shape , (1 , 3 , height , width ))
415
416
416
417
self .assertTrue (t_mask .dtype == torch .float32 )
417
418
self .assertTrue (t_masked .dtype == torch .float32 )
@@ -424,141 +425,165 @@ def test_pil_inputs(self):
424
425
self .assertTrue (t_mask .sum () > 0.0 )
425
426
426
427
def test_np_inputs(self):
    """NumPy inputs and the equivalent PIL inputs must yield identical tensors."""
    height, width = 32, 32

    # Build one random image/mask pair and express it both as NumPy arrays
    # and as PIL images so the two input paths can be compared directly.
    im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
    im_pil = Image.fromarray(im_np)
    mask_np = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
    # The boolean mask is rescaled to 0/255 uint8 before wrapping it in PIL.
    mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
    t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)

    self.assertTrue((t_mask_np == t_mask_pil).all())
    self.assertTrue((t_masked_np == t_masked_pil).all())
437
440
438
441
def test_torch_3D_2D_inputs(self):
    """A 3D image tensor with a 2D mask tensor must match the equivalent NumPy inputs."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5
    # NumPy reference: same pixels, moved from channels-first to channels-last.
    im_np = im_tensor.numpy().transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()

    # The tensor image is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # image is passed as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
449
454
450
455
def test_torch_3D_3D_inputs(self):
    """A 3D image tensor with a 3D mask tensor must match the equivalent NumPy inputs."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5
    # NumPy reference: channels-last image and the mask with its leading
    # size-1 dimension dropped.
    im_np = im_tensor.numpy().transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0]

    # The tensor image is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # image is passed as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
461
468
462
469
def test_torch_4D_2D_inputs(self):
    """A 4D image tensor with a 2D mask tensor must match the equivalent NumPy inputs."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (height, width), dtype=torch.uint8) > 127.5
    # NumPy reference: first (only) image of the batch, channels-last.
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()

    # The tensor image is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # image is passed as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
473
482
474
483
def test_torch_4D_3D_inputs(self):
    """A 4D image tensor with a 3D mask tensor must match the equivalent NumPy inputs."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, height, width), dtype=torch.uint8) > 127.5
    # NumPy reference: first (only) image of the batch, channels-last, and
    # the first (only) mask.
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0]

    # The tensor image is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # image is passed as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
485
496
486
497
def test_torch_4D_4D_inputs(self):
    """A 4D image tensor with a 4D mask tensor must match the equivalent NumPy inputs."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (1, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (1, 1, height, width), dtype=torch.uint8) > 127.5
    # NumPy reference: first (only) image of the batch, channels-last, and
    # the mask with both leading size-1 dimensions dropped.
    im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
    mask_np = mask_tensor.numpy()[0][0]

    # The tensor image is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # image is passed as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
497
510
498
511
def test_torch_batch_4D_3D(self):
    """A batch of two images with 3D batched masks must match per-sample NumPy results."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (2, height, width), dtype=torch.uint8) > 127.5

    # Split the batch into per-sample NumPy inputs (channels-last images).
    im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
    mask_nps = [mask.numpy() for mask in mask_tensor]

    # The tensor batch is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # images are passed one at a time as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
    # Re-assemble the per-sample results so they can be compared to the batched call.
    t_mask_np = torch.cat([n[0] for n in nps])
    t_masked_np = torch.cat([n[1] for n in nps])

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
512
527
513
528
def test_torch_batch_4D_4D(self):
    """A batch of two images with 4D batched masks must match per-sample NumPy results."""
    height, width = 32, 32

    im_tensor = torch.randint(0, 255, (2, 3, height, width), dtype=torch.uint8)
    mask_tensor = torch.randint(0, 255, (2, 1, height, width), dtype=torch.uint8) > 127.5

    # Split the batch into per-sample NumPy inputs (channels-last images,
    # masks with their size-1 channel dimension dropped).
    im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
    mask_nps = [mask.numpy()[0] for mask in mask_tensor]

    # The tensor batch is passed pre-scaled with `/ 127.5 - 1`; the NumPy
    # images are passed one at a time as raw uint8.
    t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
        im_tensor / 127.5 - 1, mask_tensor, height, width
    )
    nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
    # Re-assemble the per-sample results so they can be compared to the batched call.
    t_mask_np = torch.cat([n[0] for n in nps])
    t_masked_np = torch.cat([n[1] for n in nps])

    self.assertTrue((t_mask_tensor == t_mask_np).all())
    self.assertTrue((t_masked_tensor == t_masked_np).all())
527
544
528
545
def test_shape_mismatch(self):
    """Image/mask pairs whose shapes disagree are expected to raise AssertionError."""
    height, width = 32, 32

    # test height and width: 32x32 image vs. 64x64 mask
    with self.assertRaises(AssertionError):
        prepare_mask_and_masked_image(
            torch.randn(3, height, width), torch.randn(64, 64), height, width
        )
    # test batch dim with a 3D mask: 2 images vs. 4 masks
    with self.assertRaises(AssertionError):
        prepare_mask_and_masked_image(
            torch.randn(2, 3, height, width), torch.randn(4, 64, 64), height, width
        )
    # test batch dim with a 4D mask: 2 images vs. 4 masks
    with self.assertRaises(AssertionError):
        prepare_mask_and_masked_image(
            torch.randn(2, 3, height, width), torch.randn(4, 1, 64, 64), height, width
        )
538
557
539
558
def test_type_mismatch(self):
    """Mixing a torch tensor with a NumPy array is expected to raise TypeError."""
    height, width = 32, 32

    # test tensors-only: tensor image with a NumPy mask
    with self.assertRaises(TypeError):
        prepare_mask_and_masked_image(
            torch.rand(3, height, width),
            torch.rand(3, height, width).numpy(),
            height,
            width,
        )
    # test tensors-only: NumPy image with a tensor mask
    with self.assertRaises(TypeError):
        prepare_mask_and_masked_image(
            torch.rand(3, height, width).numpy(),
            torch.rand(3, height, width),
            height,
            width,
        )
546
567
547
568
def test_channels_first(self):
    """A channels-last 3D image tensor is expected to raise AssertionError."""
    height, width = 32, 32

    # test channels first for 3D tensors: the image is given as (H, W, 3)
    with self.assertRaises(AssertionError):
        prepare_mask_and_masked_image(
            torch.rand(height, width, 3), torch.rand(3, height, width), height, width
        )
551
574
552
575
def test_tensor_range(self):
    """Out-of-range image or mask tensor values are expected to raise ValueError."""
    height, width = 32, 32

    # test im <= 1
    with self.assertRaises(ValueError):
        prepare_mask_and_masked_image(
            torch.ones(3, height, width) * 2, torch.rand(height, width), height, width
        )
    # test im >= -1
    with self.assertRaises(ValueError):
        prepare_mask_and_masked_image(
            torch.ones(3, height, width) * (-2), torch.rand(height, width), height, width
        )
    # test mask <= 1
    with self.assertRaises(ValueError):
        prepare_mask_and_masked_image(
            torch.rand(3, height, width), torch.ones(height, width) * 2, height, width
        )
    # test mask >= 0
    with self.assertRaises(ValueError):
        prepare_mask_and_masked_image(
            torch.rand(3, height, width), torch.ones(height, width) * -1, height, width
        )
0 commit comments