diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 7cf19e870e..4985788808 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -39,7 +39,6 @@ from torch_tensorrt.dynamo.utils import ( CPU_DEVICE, check_module_output, - deallocate_module, get_model_device, get_torch_inputs, to_torch_device, @@ -300,7 +299,7 @@ def refit_module_weights( # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( - new_gm, settings.debug, settings.torch_executed_ops + new_gm, settings.torch_executed_ops ) if num_supported_ops == 0 or ( @@ -363,7 +362,6 @@ def refit_module_weights( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - new_weight_module.module().to(CPU_DEVICE) for name, new_submodule in new_partitioned_module.named_children(): # Refit each submodule # Extract engine from the submodule @@ -466,7 +464,6 @@ def refit_module_weights( settings=settings, weight_name_map=None, ) - deallocate_module(new_submodule) # clear EXCLUDE_WEIGHTS flag serialization_config = engine.create_serialization_config() @@ -489,8 +486,6 @@ def refit_module_weights( gc.collect() torch.cuda.empty_cache() - deallocate_module(new_partitioned_module) - if verify_output and arg_inputs is not None: new_gm.to(to_torch_device(settings.device)) if check_module_output( diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index dfdc9e1c69..c0d29c41f0 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -10,7 +10,6 @@ from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import aot_export_joint_simple -from torch.distributed.tensor import DTensor from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._compiler import compile_module from torch_tensorrt.dynamo.lowering import ( @@ -89,6 +88,11 @@ def aot_torch_tensorrt_aten_backend( logger.warning( "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple" ) + + if settings.offload_module_to_cpu: + logger.warning( + "The offload_module_to_cpu option is set, but it is being ignored since the torch_compile backend does not support this feature" + ) return _pretraced_backend(gm, sample_inputs, settings, engine_cache) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 6166e16949..5b8dd90d92 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -320,11 +320,12 @@ def test_resnet18_cpu_offload(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - assertions.assertTrue( - get_model_device(model).type == "cpu", - msg="Model should be offloaded to CPU", - ) - model.cuda() + if ir == "dynamo": + assertions.assertTrue( + get_model_device(model).type == "cpu", + msg="Model should be offloaded to CPU", + ) + model.cuda() torchtrt.save(trt_module, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 359044a2b2..90d3cc637b 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -79,11 +79,12 @@ def test_resnet18_cpu_offload(ir): } trt_mod = torchtrt.compile(model, **compile_spec) - assertions.assertTrue( - get_model_device(model).type == "cpu", - msg="Model should be offloaded to CPU", - ) - model.cuda() + if ir == "dynamo": + assertions.assertTrue( + get_model_device(model).type == "cpu", + msg="Model should be offloaded to CPU", + ) + model.cuda() cos_sim = cosine_similarity(model(input), trt_mod(input)) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, @@ -286,11 +287,12 @@ def test_bert_base_uncased_cpu_offload(ir): "offload_module_to_cpu": True, } trt_mod = torchtrt.compile(model, **compile_spec) - assertions.assertTrue( - get_model_device(model).type == "cpu", - msg="Model should be offloaded to CPU", - ) - model.cuda() + if ir == "dynamo": + assertions.assertTrue( + get_model_device(model).type == "cpu", + msg="Model should be offloaded to CPU", + ) + model.cuda() model_outputs = model(input, input2) trt_model_outputs = trt_mod(input, input2)