Skip to content

Commit 95c59b3

Browse files
mvpatel2000pytorchmergebot
authored andcommitted
Update fully_sharded_data_parallel to fix typing (#110545)
Fixes typing so that linter does not complain when using CustomPolicy. Pull Request resolved: #110545 Approved by: https://github.com/awgu, https://github.com/Skylion007
1 parent 0daa7d4 commit 95c59b3

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
_unshard_params,
100100
_unshard_params_recurse,
101101
)
102-
from .wrap import ModuleWrapPolicy
102+
from .wrap import CustomPolicy, ModuleWrapPolicy
103103

104104

105105
__all__ = [
@@ -261,7 +261,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
261261
This configures CPU offloading. If this is set to ``None``, then
262262
no CPU offloading happens. See :class:`CPUOffload` for details.
263263
(Default: ``None``)
264-
auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy]]):
264+
auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
265265
This specifies a policy to apply FSDP to submodules of ``module``,
266266
which is needed for communication and computation overlap and thus
267267
affects performance. If ``None``, then FSDP only applies to
@@ -411,7 +411,9 @@ def __init__(
411411
process_group: ProcessGroupType = None,
412412
sharding_strategy: Optional[ShardingStrategy] = None,
413413
cpu_offload: Optional[CPUOffload] = None,
414-
auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy]] = None,
414+
auto_wrap_policy: Optional[
415+
Union[Callable, ModuleWrapPolicy, CustomPolicy]
416+
] = None,
415417
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
416418
mixed_precision: Optional[MixedPrecision] = None,
417419
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,

0 commit comments

Comments
 (0)