
Safety checker fails to trace #275

Closed
@TobyRoseman

Description


When using CompVis/stable-diffusion-v1-4, converting the safety checker fails during torch.jit.trace:

(coreml_stable_diffusion) toby@Tobys-MBP-4 ml-stable-diffusion % python -m python_coreml_stable_diffusion.torch2coreml --convert-safety-checker --model-version "CompVis/stable-diffusion-v1-4" -o out
scikit-learn version 1.3.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
INFO:__main__:Initializing DiffusionPipeline with CompVis/stable-diffusion-v1-4..
Loading pipeline components...:  43%|██████████████████████████████████████████████████▌                                                                   | 3/7 [00:00<00:00, 23.57it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.58it/s]
INFO:__main__:Done. Pipeline in effect: StableDiffusionPipeline
INFO:__main__:Attention implementation in effect: AttentionImplementations.SPLIT_EINSUM
INFO:__main__:Converting safety_checker
INFO:__main__:Sample inputs spec: {'clip_input': (torch.Size([1, 3, 224, 224]), torch.float32), 'images': (torch.Size([1, 512, 512, 3]), torch.float32), 'adjustment': (torch.Size([1]), torch.float32)}
INFO:__main__:JIT tracing..
Traceback (most recent call last):
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Volumes/DevData/workspace/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 1522, in <module>
    main(args)
  File "/Volumes/DevData/workspace/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 1349, in main
    convert_safety_checker(pipe, args)
  File "/Volumes/DevData/workspace/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 1045, in convert_safety_checker
    traced_safety_checker = torch.jit.trace(
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Volumes/DevData/workspace/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 976, in forward_coreml
    pooled_output = self.vision_model(clip_input)[1]  # pooled_output
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 958, in forward
    return self.vision_model(
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 883, in forward
    hidden_states = self.embeddings(pixel_values)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 196, in forward
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/Users/toby/miniforge3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
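The last frames show CLIP's patch embedding casting the float32 clip_input with pixel_values.to(dtype=target_dtype) before the convolution. The resulting 'Half' error suggests that target_dtype, and therefore the safety checker's conv weights, are float16, and the PyTorch build in use has no float16 conv2d kernel on CPU. The following is a minimal sketch of one possible workaround, not the fix that closed this issue: cast the module to float32 before tracing. TinyPatchEmbedding and ensure_float32 are hypothetical names that mimic the pattern in the traceback. In the real script, the assumed equivalent would be calling .float() on pipe.safety_checker before the torch.jit.trace call in convert_safety_checker.

```python
import torch
import torch.nn as nn


def ensure_float32(module: nn.Module) -> nn.Module:
    """Hypothetical helper: if any parameter is float16, cast the whole
    module (floating-point params and buffers) to float32 in place.

    Assumption: tracing happens on CPU, where some PyTorch builds lack a
    float16 conv2d kernel ("slow_conv2d_cpu").
    """
    if any(p.dtype == torch.float16 for p in module.parameters()):
        module.float()  # nn.Module.float() casts floating-point tensors to float32
    return module


class TinyPatchEmbedding(nn.Module):
    """Hypothetical stand-in for CLIP's patch embedding: the input is cast
    to the conv weight's dtype right before the convolution."""

    def __init__(self) -> None:
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, 8, kernel_size=32, stride=32, bias=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        return self.patch_embedding(pixel_values.to(dtype=target_dtype))


# Simulate a checker whose weights were stored in float16.
model = TinyPatchEmbedding().half().eval()
model = ensure_float32(model)

# Same shape and dtype as the 'clip_input' sample input in the log.
clip_input = torch.rand(1, 3, 224, 224, dtype=torch.float32)
with torch.no_grad():
    traced = torch.jit.trace(model, (clip_input,))
    out = traced(clip_input)

print(out.dtype, tuple(out.shape))  # torch.float32 (1, 8, 7, 7)
```

Casting up front keeps the traced graph entirely in float32, so the CPU trace no longer needs a half-precision conv kernel. Any float16 conversion for Core ML can then be requested at conversion time instead of being baked into the PyTorch weights.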
