
Commit 4766b68

tugsbayasgalan authored and pytorchmergebot committed
Add out variants for softmax and log_softmax
Pull Request resolved: #75833 Approved by: https://github.com/ngimel
1 parent 6b6d09c commit 4766b68

File tree

3 files changed (+88, -4 lines)

aten/src/ATen/native/SoftMax.cpp

Lines changed: 74 additions & 0 deletions
@@ -438,6 +438,43 @@ Tensor softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) {
   return result;
 }
 
+Tensor& softmax_out(
+    const Tensor& input_,
+    const int64_t dim_,
+    c10::optional<ScalarType> dtype,
+    Tensor& output_) {
+  Tensor output_temp;
+  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+      dtype == ScalarType::Float) {
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_softmax_out(output_temp, input_, dim_, true);
+    } else {
+      at::_softmax_out(output_, input_, dim_, true);
+    }
+  } else {
+    Tensor converted =
+        dtype.has_value() ? input_.toType(dtype.value()) : input_;
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_softmax_out(output_temp, converted, dim_, false);
+    } else {
+      at::_softmax_out(output_, converted, dim_, false);
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.resize_(output_temp.sizes());
+    output_.copy_(output_temp);
+  }
+
+  return output_;
+}
+
 // special_softmax, alias for softmax
 Tensor special_softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) {
   return at::softmax(input_, dim_, dtype);
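The composite above distinguishes two layouts for the destination: a contiguous out tensor is handed directly to at::_softmax_out, while a non-contiguous one is computed into a contiguous temporary and copied back at the end. A minimal Python sketch of what the new overload enables at the torch level; the shapes and tensor names here are illustrative, not taken from the PR:

import torch

x = torch.randn(4, 6)

# Contiguous destination: filled in place by the underlying _softmax kernel.
out = torch.empty(4, 6)
torch.softmax(x, dim=1, out=out)

# Non-contiguous destination (a transposed view): the composite computes into a
# contiguous temporary and copies the result back into the destination.
out_t = torch.empty(6, 4).t()
assert not out_t.is_contiguous()
torch.softmax(x, dim=1, out=out_t)

torch.testing.assert_close(out_t, torch.softmax(x, dim=1))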
@@ -466,6 +503,43 @@ Tensor log_softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) {
   return result;
 }
 
+Tensor& log_softmax_out(
+    const Tensor& input_,
+    const int64_t dim_,
+    c10::optional<ScalarType> dtype,
+    Tensor& output_) {
+  Tensor output_temp;
+  if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+      dtype == ScalarType::Float) {
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_log_softmax_out(output_temp, input_, dim_, true);
+    } else {
+      at::_log_softmax_out(output_, input_, dim_, true);
+    }
+  } else {
+    Tensor converted =
+        dtype.has_value() ? input_.toType(dtype.value()) : input_;
+    if (!output_.is_contiguous()) {
+      auto options =
+          TensorOptions().dtype(output_.dtype()).device(output_.device());
+      output_temp = at::empty(output_.sizes(), options);
+      at::_log_softmax_out(output_temp, converted, dim_, false);
+    } else {
+      at::_log_softmax_out(output_, converted, dim_, false);
+    }
+  }
+
+  if (!output_.is_contiguous()) {
+    output_.resize_(output_temp.sizes());
+    output_.copy_(output_temp);
+  }
+
+  return output_;
+}
+
 Tensor special_log_softmax(const Tensor& input, const int64_t dim, c10::optional<ScalarType> dtype) {
   return at::log_softmax(input, dim, dtype);
 }
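log_softmax_out mirrors softmax_out, including the fast path for a Half input on CUDA with dtype=Float, where the kernel is invoked with half_to_float=true so the conversion happens inside _log_softmax instead of materializing a Float copy of the input; in every other case the input is converted up front via toType. A small sketch of the dtype plus out combination on CPU, again with illustrative names and shapes:

import torch

x = torch.randn(5, 7)
out = torch.empty(5, 7, dtype=torch.float64)
# The composite converts x to float64 and writes the result into out.
torch.log_softmax(x, dim=1, dtype=torch.float64, out=out)
torch.testing.assert_close(out, torch.log_softmax(x, dim=1, dtype=torch.float64))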

aten/src/ATen/native/native_functions.yaml

Lines changed: 10 additions & 0 deletions
@@ -2821,6 +2821,11 @@
 - func: log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
   variants: function, method
 
+- func: log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: log_softmax_out
+
 - func: log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
   variants: function, method
 
@@ -4131,6 +4136,11 @@
 - func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
   variants: function, method
 
+- func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: softmax_out
+
 - func: softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
   variants: function, method
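Each entry above registers a new .int_out overload with a keyword-only out argument and dispatches it to the corresponding composite in SoftMax.cpp under CompositeExplicitAutograd, so the single C++ implementation backs every backend. The new schemas should also be reachable directly through the dispatcher; the overload names below come from the yaml entries, and the call itself is a sketch rather than something exercised in this PR:

import torch

x = torch.randn(3, 4)
out = torch.empty(3, 4)
# Dispatcher-level calls to the newly registered out overloads (assumed calling convention).
torch.ops.aten.softmax.int_out(x, 1, None, out=out)
torch.ops.aten.log_softmax.int_out(x, 1, None, out=out)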

torch/testing/_internal/common_methods_invocations.py

Lines changed: 4 additions & 4 deletions
@@ -11275,7 +11275,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            assert_jit_shape_analysis=True,
            assert_autodiffed=True,
            supports_forward_ad=True,
-           supports_out=False),
+           supports_out=True),
     OpInfo('softmax',
            aliases=('special.softmax', 'nn.functional.softmax',),
            variant_test_name="with_dtype",
@@ -11284,7 +11284,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
            assert_autodiffed=True,
            supports_forward_ad=True,
-           supports_out=False),
+           supports_out=True),
     # `softmin` supports different dtypes based on whether `dtype` argument,
     # is passed or not. Hence two OpInfo entries, one with dtype and other without.
     # https://github.com/pytorch/pytorch/issues/68752
@@ -15445,7 +15445,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
     OpInfo(
         'log_softmax',
         aliases=('special.log_softmax', 'nn.functional.log_softmax'),
-        supports_out=False,
+        supports_out=True,
         dtypes=floating_types_and(torch.bfloat16),
         dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_softmax_variant,
@@ -15455,7 +15455,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
         'log_softmax',
         variant_test_name='dtype',
         aliases=('special.log_softmax', 'nn.functional.log_softmax'),
-        supports_out=False,
+        supports_out=True,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
         supports_forward_ad=True,
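Flipping supports_out to True lets the generic out= checks in the OpInfo-based test suite run against these entries instead of being skipped. Roughly the property those checks exercise, as a simplified sketch rather than the actual test logic:

import torch

x = torch.randn(3, 5)
out = torch.empty(0)
# out is resized to the result shape and filled in place.
torch.log_softmax(x, dim=1, out=out)
torch.testing.assert_close(out, torch.log_softmax(x, dim=1))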
