@@ -99,7 +99,7 @@
    _unshard_params,
    _unshard_params_recurse,
)
-from .wrap import ModuleWrapPolicy
+from .wrap import CustomPolicy, ModuleWrapPolicy


__all__ = [
@@ -261,7 +261,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
            This configures CPU offloading. If this is set to ``None``, then
            no CPU offloading happens. See :class:`CPUOffload` for details.
            (Default: ``None``)
-        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy]]):
+        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
            This specifies a policy to apply FSDP to submodules of ``module``,
            which is needed for communication and computation overlap and thus
            affects performance. If ``None``, then FSDP only applies to
@@ -411,7 +411,9 @@ def __init__(
        process_group: ProcessGroupType = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        cpu_offload: Optional[CPUOffload] = None,
-        auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy]] = None,
+        auto_wrap_policy: Optional[
+            Union[Callable, ModuleWrapPolicy, CustomPolicy]
+        ] = None,
        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
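The change above widens ``auto_wrap_policy`` to also accept a ``CustomPolicy``. As a rough illustration of what that enables, the sketch below builds an FSDP-wrapped model with both a ``ModuleWrapPolicy`` and a ``CustomPolicy`` from ``torch.distributed.fsdp.wrap``. It assumes a PyTorch build that ships ``CustomPolicy``, an already-initialized process group (e.g. launched via ``torchrun``), and a toy Transformer model chosen purely for illustration.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy

# Toy model: each TransformerEncoderLayer is a natural candidate for its own FSDP unit.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=4), num_layers=4
)

# Class-based policy: wrap every instance of the given module classes.
class_policy = ModuleWrapPolicy({nn.TransformerEncoderLayer})

# Lambda-based policy: the callable runs on each submodule and returns
# False (do not wrap), True (wrap with the root FSDP kwargs), or a dict of
# per-module FSDP kwarg overrides.
lambda_policy = CustomPolicy(
    lambda module: isinstance(module, nn.TransformerEncoderLayer)
)

# Either policy can now be passed as ``auto_wrap_policy``.
fsdp_model = FSDP(model, auto_wrap_policy=lambda_policy)

The dict-return form is the main reason to reach for ``CustomPolicy`` over a bare callable: individual submodules can override FSDP constructor arguments (for example a different ``sharding_strategy``) without being wrapped manually.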