Skip to content

Commit c79de88

Browse files
author
Rupert Menneer
committed
Added a resolution test to StableDiffusionInpaintPipelineSlowTests
this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width
1 parent c4f3210 commit c79de88

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,25 @@ def test_inpaint_compile(self):
300300
assert np.abs(expected_slice - image_slice).max() < 1e-4
301301
assert np.abs(expected_slice - image_slice).max() < 1e-3
302302

303+
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
304+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
305+
"runwayml/stable-diffusion-inpainting", safety_checker=None
306+
)
307+
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
308+
pipe.to(torch_device)
309+
pipe.set_progress_bar_config(disable=None)
310+
pipe.enable_attention_slicing()
311+
312+
inputs = self.get_inputs(torch_device)
313+
# change input image to a random size (one that would cause a tensor mismatch error)
314+
inputs['image'] = inputs['image'].resize((127,127))
315+
inputs['mask_image'] = inputs['mask_image'].resize((127,127))
316+
inputs['height'] = 128
317+
inputs['width'] = 128
318+
image = pipe(**inputs).images
319+
# verify that the returned image has the same height and width as the input height and width
320+
assert image.shape == (1, inputs['height'], inputs['width'], 3)
321+
303322

304323
@nightly
305324
@require_torch_gpu

0 commit comments

Comments
 (0)