
FLUX.1-dev dreambooth train on multigpu #9278

Closed
@MehmetcanTozlu


Discussed in #9277

Originally posted by MehmetcanTozlu August 26, 2024
I have 4 RTX 6000 Ada Generation cards, but I get an error when training FLUX.1-dev with DreamBooth (train_dreambooth_flux.py). Can anyone help?

pip list:
absl-py 2.1.0
accelerate 0.33.0
bitsandbytes 0.43.3
certifi 2024.7.4
charset-normalizer 3.3.2
click 8.1.7
cmake 3.25.0
diffusers 0.31.0.dev0
docker-pycreds 0.4.0
filelock 3.15.4
fsspec 2024.6.1
ftfy 6.2.3
gitdb 4.0.11
GitPython 3.1.43
grpcio 1.66.0
huggingface-hub 0.24.6
idna 3.8
importlib_metadata 8.4.0
Jinja2 3.1.4
lit 15.0.7
Markdown 3.7
MarkupSafe 2.1.5
mpmath 1.3.0
networkx 3.3
numpy 1.26.4
nvidia-cublas-cu11 11.11.3.6
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.7.0.84
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.3.0.86
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.5.86
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.19.3
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.20
nvidia-nvtx-cu11 11.8.86
nvidia-nvtx-cu12 12.1.105
packaging 24.1
peft 0.12.0
pillow 10.4.0
pip 22.0.2
platformdirs 4.2.2
protobuf 5.27.3
psutil 6.0.0
PyYAML 6.0.2
regex 2024.7.24
requests 2.32.3
safetensors 0.4.4
sentencepiece 0.2.0
sentry-sdk 2.13.0
setproctitle 1.3.3
setuptools 59.6.0
six 1.16.0
smmap 5.0.1
sympy 1.13.2
tensorboard 2.17.1
tensorboard-data-server 0.7.2
tokenizers 0.19.1
torch 2.2.1+cu118
torchaudio 2.2.1+cu118
torchvision 0.17.1+cu118
tqdm 4.66.5
transformers 4.44.2
triton 2.2.0
typing_extensions 4.12.2
urllib3 2.2.2
wandb 0.17.7
wcwidth 0.2.13
Werkzeug 3.0.4
zipp 3.20.0

nvidia-smi:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08 Driver Version: 535.161.08 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA RTX 6000 Ada Gene... On | 00000000:00:10.0 Off | Off |
| 30% 32C P8 28W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA RTX 6000 Ada Gene... On | 00000000:00:11.0 Off | Off |
| 30% 32C P8 24W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA RTX 6000 Ada Gene... On | 00000000:00:1B.0 Off | Off |
| 30% 30C P8 22W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA RTX 6000 Ada Gene... On | 00000000:00:1C.0 Off | Off |
| 30% 28C P8 24W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+

run command:
accelerate launch --num_processes=4 train_dreambooth_flux.py \
  --pretrained_model_name_or_path="/mnt/s1/flux_dev/FLUX.1-dev" \
  --instance_data_dir="/home/wrusr/flux_train_env/flux_dreambooth_train/dreambooth-datasets_/style_spiderverse" \
  --caption_column="/home/wrusr/flux_train_env/flux_dreambooth_train/dreambooth-datasets_/atext" \
  --instance_prompt="A photo of spdrvrs" \
  --output_dir="/home/wrusr/flux_train_env/flux_dreambooth_train" \
  --resolution=024 \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=500 \
  --revision="fp16" \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --seed="0" \
  --push_to_hub

error:
08/26/2024 11:46:54 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

08/26/2024 11:46:54 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 4
Process index: 3
Local process index: 3
Device: cuda:3

Mixed precision type: fp16

08/26/2024 11:46:54 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 4
Process index: 2
Local process index: 2
Device: cuda:2

Mixed precision type: fp16

08/26/2024 11:46:54 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 4
Process index: 1
Local process index: 1
Device: cuda:1

Mixed precision type: fp16

You set add_prefix_space. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.75s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.91s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.92s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.92s/it]
{'axes_dims_rope'} was not found in config. Values will be initialized to default values.
[2024-08-26 11:48:25,610] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 32699 closing signal SIGTERM
[2024-08-26 11:48:25,733] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 32701 closing signal SIGTERM
[2024-08-26 11:48:25,733] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 32702 closing signal SIGTERM
[2024-08-26 11:48:28,939] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -9) local_rank: 1 (pid: 32700) of binary: /home/wrusr/flux_train_env/bin/python3.10
Traceback (most recent call last):
File "/home/wrusr/flux_train_env/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
args.func(args)
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1097, in launch_command
multi_gpu_launcher(args)
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/accelerate/commands/launch.py", line 734, in multi_gpu_launcher
distrib_run.run(args)
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
elastic_launch(
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/wrusr/flux_train_env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

train_dreambooth_flux.py FAILED

Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2024-08-26_11:48:25
host : wr-dev-train-1
rank : 1 (local_rank: 1)
exitcode : -9 (pid: 32700)
error_file: <N/A>
traceback : Signal 9 (SIGKILL) received by PID 32700
======================================================
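
A note on reading the failure above: the launcher reports `exitcode: -9` for rank 1, and the root-cause block names Signal 9 (SIGKILL). A negative exit code means the child was terminated by that signal rather than exiting on its own. SIGKILL cannot be caught by the process, which is why no Python traceback from train_dreambooth_flux.py appears. The log does not say who sent the signal; on Linux one common sender is the kernel's out-of-memory killer, but that is an assumption here and would need to be confirmed from the kernel log. The helper below is a minimal sketch for decoding such exit codes; `describe_exit_code` is a hypothetical name, not part of accelerate or torch.

```python
import signal


def describe_exit_code(exitcode: int) -> str:
    """Describe a child-process exit code in plain words.

    Uses the convention followed by Python's subprocess and
    multiprocessing modules: a negative value -N means the child
    was terminated by signal number N.
    """
    if exitcode < 0:
        signum = -exitcode
        try:
            # Map the signal number to its symbolic name (platform dependent).
            name = signal.Signals(signum).name
        except ValueError:
            # Unknown signal number on this platform.
            name = f"signal {signum}"
        return f"terminated by {name}"
    if exitcode == 0:
        return "exited normally"
    return f"exited with status {exitcode}"


if __name__ == "__main__":
    # The exit code reported for rank 1 in the log above.
    print(describe_exit_code(-9))
```

On Linux, the call with `-9` reports SIGKILL, matching the launcher's own summary. The SIGTERM lines for the other three PIDs are the launcher shutting down the remaining ranks after the first one died.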
