Status: Closed
Labels: question (Further information is requested)
Description
TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines 1030 to 1037 in a662411

```python
allowed_casts = {
    torch.float,
    torch.int32,
    torch.int64,
    torch.bool,
    torch.int8,
    torch.float16,
}
```
Is there a specific reason why `torch.bfloat16` is not included in the `allowed_casts` set within the `to_copy_dtype_validator` function?
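To make the question concrete, here is a minimal sketch of what a dtype check like this does. It is not the Torch-TensorRT source: `FakeNode` is a hypothetical stand-in for a `torch.fx.Node`, dtypes are plain strings so the sketch runs without `torch`, and `to_copy_dtype_validator_sketch` is a hypothetical name. It only shows that a cast whose target dtype is missing from the set is rejected.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Set

# Hypothetical string stand-ins for the torch dtypes in the quoted set
# ("float32" plays the role of torch.float).
ALLOWED_CASTS: Set[str] = {"float32", "int32", "int64", "bool", "int8", "float16"}


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; only kwargs are used."""
    kwargs: Dict[str, Any] = field(default_factory=dict)


def to_copy_dtype_validator_sketch(node: FakeNode, allowed: Set[str] = ALLOWED_CASTS) -> bool:
    # Accept the node only if its target dtype is in the allowed set.
    # A missing dtype yields None, which is not in the set, so it is rejected
    # here; the real validator's handling of that case is not shown in the snippet.
    return node.kwargs.get("dtype") in allowed


print(to_copy_dtype_validator_sketch(FakeNode({"dtype": "float16"})))   # True
print(to_copy_dtype_validator_sketch(FakeNode({"dtype": "bfloat16"})))  # False
# Widening the set is the change the question is asking about:
print(to_copy_dtype_validator_sketch(FakeNode({"dtype": "bfloat16"}), ALLOWED_CASTS | {"bfloat16"}))  # True
```

Whether simply adding `bfloat16` to the set is safe depends on BF16 support in the TensorRT version in use, which this sketch does not check.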
In addition, because `torch.bfloat16` is rejected, a `torch.ops.aten._to_copy` operation that casts to `torch.bfloat16` is not converted to TensorRT, and the graph is partitioned around it. I'm wondering whether this could hurt performance.
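A simplified model shows why a single rejected cast can matter. This sketch is an assumption-laden illustration, not Torch-TensorRT's partitioner (which works on the FX graph and applies extra rules such as a minimum block size). It treats the graph as a linear list of `(op_name, supported)` pairs and groups consecutive ops by where they would run; `partition` is a hypothetical helper.

```python
from typing import List, Tuple

Segment = Tuple[str, List[str]]


def partition(ops: List[Tuple[str, bool]]) -> List[Segment]:
    """Group consecutive ops into 'trt' (supported) or 'torch' (fallback) segments."""
    segments: List[Segment] = []
    for name, supported in ops:
        target = "trt" if supported else "torch"
        if segments and segments[-1][0] == target:
            # Same backend as the previous op: extend the current segment.
            segments[-1][1].append(name)
        else:
            # Backend changes: start a new segment (a partition boundary).
            segments.append((target, [name]))
    return segments


# One unsupported bf16 cast in the middle splits one engine into two.
graph = [("linear", True), ("_to_copy(bfloat16)", False), ("matmul", True)]
print(len(partition(graph)))  # 3

# If the cast were accepted, everything would fit in a single segment.
print(len(partition([(name, True) for name, _ in graph])))  # 1
```

In this model every boundary between a `trt` and a `torch` segment means handing tensors between a TensorRT engine and PyTorch, which is where the potential overhead comes from.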