Skip to content

Commit c4f3210

Browse files
author
Rupert Menneer
committed
Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests
Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution
1 parent 49d8702 commit c4f3210

File tree

1 file changed

+72
-47
lines changed

1 file changed

+72
-47
lines changed

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -397,21 +397,22 @@ def test_inpaint_dpm(self):
397397

398398
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
399399
def test_pil_inputs(self):
400-
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
400+
height, width = 32, 32
401+
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
401402
im = Image.fromarray(im)
402-
mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
403+
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
403404
mask = Image.fromarray((mask * 255).astype(np.uint8))
404405

405-
t_mask, t_masked = prepare_mask_and_masked_image(im, mask)
406+
t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)
406407

407408
self.assertTrue(isinstance(t_mask, torch.Tensor))
408409
self.assertTrue(isinstance(t_masked, torch.Tensor))
409410

410411
self.assertEqual(t_mask.ndim, 4)
411412
self.assertEqual(t_masked.ndim, 4)
412413

413-
self.assertEqual(t_mask.shape, (1, 1, 32, 32))
414-
self.assertEqual(t_masked.shape, (1, 3, 32, 32))
414+
self.assertEqual(t_mask.shape, (1, 1, height, width))
415+
self.assertEqual(t_masked.shape, (1, 3, height, width))
415416

416417
self.assertTrue(t_mask.dtype == torch.float32)
417418
self.assertTrue(t_masked.dtype == torch.float32)
@@ -424,141 +425,165 @@ def test_pil_inputs(self):
424425
self.assertTrue(t_mask.sum() > 0.0)
425426

426427
def test_np_inputs(self):
427-
im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
428+
height, width = 32, 32
429+
430+
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
428431
im_pil = Image.fromarray(im_np)
429-
mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
432+
mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5
430433
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
431434

432-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
433-
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)
435+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
436+
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)
434437

435438
self.assertTrue((t_mask_np == t_mask_pil).all())
436439
self.assertTrue((t_masked_np == t_masked_pil).all())
437440

438441
def test_torch_3D_2D_inputs(self):
439-
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
440-
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
442+
height, width = 32, 32
443+
444+
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
445+
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
441446
im_np = im_tensor.numpy().transpose(1, 2, 0)
442447
mask_np = mask_tensor.numpy()
443448

444-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
445-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
449+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
450+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
446451

447452
self.assertTrue((t_mask_tensor == t_mask_np).all())
448453
self.assertTrue((t_masked_tensor == t_masked_np).all())
449454

450455
def test_torch_3D_3D_inputs(self):
451-
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
452-
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
456+
height, width = 32, 32
457+
458+
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
459+
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
453460
im_np = im_tensor.numpy().transpose(1, 2, 0)
454461
mask_np = mask_tensor.numpy()[0]
455462

456-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
457-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
463+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
464+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
458465

459466
self.assertTrue((t_mask_tensor == t_mask_np).all())
460467
self.assertTrue((t_masked_tensor == t_masked_np).all())
461468

462469
def test_torch_4D_2D_inputs(self):
463-
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
464-
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
470+
height, width = 32, 32
471+
472+
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
473+
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
465474
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
466475
mask_np = mask_tensor.numpy()
467476

468-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
469-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
477+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
478+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
470479

471480
self.assertTrue((t_mask_tensor == t_mask_np).all())
472481
self.assertTrue((t_masked_tensor == t_masked_np).all())
473482

474483
def test_torch_4D_3D_inputs(self):
475-
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
476-
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
484+
height, width = 32, 32
485+
486+
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
487+
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
477488
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
478489
mask_np = mask_tensor.numpy()[0]
479490

480-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
481-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
491+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
492+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
482493

483494
self.assertTrue((t_mask_tensor == t_mask_np).all())
484495
self.assertTrue((t_masked_tensor == t_masked_np).all())
485496

486497
def test_torch_4D_4D_inputs(self):
487-
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
488-
mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
498+
height, width = 32, 32
499+
500+
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
501+
mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5
489502
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
490503
mask_np = mask_tensor.numpy()[0][0]
491504

492-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
493-
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
505+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
506+
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
494507

495508
self.assertTrue((t_mask_tensor == t_mask_np).all())
496509
self.assertTrue((t_masked_tensor == t_masked_np).all())
497510

498511
def test_torch_batch_4D_3D(self):
499-
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
500-
mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5
512+
height, width = 32, 32
513+
514+
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
515+
mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5
501516

502517
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
503518
mask_nps = [mask.numpy() for mask in mask_tensor]
504519

505-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
506-
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
520+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
521+
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
507522
t_mask_np = torch.cat([n[0] for n in nps])
508523
t_masked_np = torch.cat([n[1] for n in nps])
509524

510525
self.assertTrue((t_mask_tensor == t_mask_np).all())
511526
self.assertTrue((t_masked_tensor == t_masked_np).all())
512527

513528
def test_torch_batch_4D_4D(self):
514-
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
515-
mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5
529+
height, width = 32, 32
530+
531+
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
532+
mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5
516533

517534
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
518535
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
519536

520-
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
521-
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
537+
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
538+
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
522539
t_mask_np = torch.cat([n[0] for n in nps])
523540
t_masked_np = torch.cat([n[1] for n in nps])
524541

525542
self.assertTrue((t_mask_tensor == t_mask_np).all())
526543
self.assertTrue((t_masked_tensor == t_masked_np).all())
527544

528545
def test_shape_mismatch(self):
546+
height, width = 32, 32
547+
529548
# test height and width
530549
with self.assertRaises(AssertionError):
531-
prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
550+
prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width)
532551
# test batch dim
533552
with self.assertRaises(AssertionError):
534-
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
553+
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width)
535554
# test batch dim
536555
with self.assertRaises(AssertionError):
537-
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))
556+
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width)
538557

539558
def test_type_mismatch(self):
559+
height, width = 32, 32
560+
540561
# test tensors-only
541562
with self.assertRaises(TypeError):
542-
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
563+
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width)
543564
# test tensors-only
544565
with self.assertRaises(TypeError):
545-
prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))
566+
prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width)
546567

547568
def test_channels_first(self):
569+
height, width = 32, 32
570+
548571
# test channels first for 3D tensors
549572
with self.assertRaises(AssertionError):
550-
prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))
573+
prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width)
551574

552575
def test_tensor_range(self):
576+
height, width = 32, 32
577+
553578
# test im <= 1
554579
with self.assertRaises(ValueError):
555-
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
580+
prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width)
556581
# test im >= -1
557582
with self.assertRaises(ValueError):
558-
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
583+
prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width)
559584
# test mask <= 1
560585
with self.assertRaises(ValueError):
561-
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
586+
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width)
562587
# test mask >= 0
563588
with self.assertRaises(ValueError):
564-
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)
589+
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)

0 commit comments

Comments
 (0)