diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 266648ce7613..f53ec84f3c9f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -34,7 +34,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask): +def prepare_mask_and_masked_image(image, mask, height, width): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -62,6 +62,13 @@ def prepare_mask_and_masked_image(image, mask): tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") @@ -109,8 +116,9 @@ def prepare_mask_and_masked_image(image, mask): # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -124,6 +132,7 @@ def prepare_mask_and_masked_image(image, mask): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): @@ -787,12 +796,6 @@ def __call__( negative_prompt_embeds, ) - if image is None: - raise ValueError("`image` input cannot be undefined.") - - if mask_image is None: - raise ValueError("`mask_image` input cannot be undefined.") - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -818,8 +821,8 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 20977c346ecc..097093f2427e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -300,6 +300,25 @@ def test_inpaint_compile(self): assert np.abs(expected_slice - image_slice).max() < 1e-4 assert np.abs(expected_slice - image_slice).max() < 1e-3 + def test_stable_diffusion_inpaint_pil_input_resolution_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input image to a random size (one that would cause a tensor mismatch error) + inputs['image'] = inputs['image'].resize((127,127)) + inputs['mask_image'] = inputs['mask_image'].resize((127,127)) + inputs['height'] = 128 + inputs['width'] = 128 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, inputs['height'], inputs['width'], 3) + @nightly @require_torch_gpu @@ -397,12 +416,13 @@ def test_inpaint_dpm(self): class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): def test_pil_inputs(self): - im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im = Image.fromarray(im) - mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) @@ -410,8 +430,8 @@ def test_pil_inputs(self): self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_masked.ndim, 4) - self.assertEqual(t_mask.shape, (1, 1, 32, 32)) - self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + self.assertEqual(t_mask.shape, (1, 1, height, width)) + self.assertEqual(t_masked.shape, (1, 3, height, width)) self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32) @@ -424,86 +444,100 @@ def test_pil_inputs(self): self.assertTrue(t_mask.sum() > 0.0) def test_np_inputs(self): - im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + + im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im_pil = Image.fromarray(im_np) - mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5 mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) - t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) def test_torch_3D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_3D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_4D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_batch_4D_3D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5 im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -511,14 +545,16 @@ def test_torch_batch_4D_3D(self): self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_batch_4D_4D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5 im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -526,39 +562,47 @@ def test_torch_batch_4D_4D(self): self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_shape_mismatch(self): + height, width = 32, 32 + # test height and width with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width) def test_type_mismatch(self): + height, width = 32, 32 + # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width) # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width) def test_channels_first(self): + height, width = 32, 32 + # test channels first for 3D tensors with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width) def test_tensor_range(self): + height, width = 32, 32 + # test im <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width) # test im >= -1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width) # test mask <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width) # test mask >= 0 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)