From bd3481896b7ac7cfcbeba43336158f841734aa70 Mon Sep 17 00:00:00 2001
From: Roei Bahumi
Date: Mon, 20 Mar 2023 14:59:01 +0200
Subject: [PATCH] Added notebook torchscriptable_t5_with_torchtext.ipynb

This notebook is an example of a (working) "Hacky" solution for wrapping
the full 'generate' functionality inside a "forward" function. The purpose
of this is to start a discussion and to suggest a way to make this
functionality TorchScriptable.

To do so, I:
1. Added T5TorchGenerative, a subclass of T5Model, in which I:
- extracted the decoding code from the t5.forward() function into a
  standalone 'decode' function that returns a specific type.
- added GenerationUtils's 'generate' functionality as a class method
  (similar to HuggingFace).
2. Added TorchScriptableT5, a module that implements the full generative
logic in the forward method.
3. Added helper functions that build a JIT (TorchScript) model from a
predefined T5 bundle.
---
 .../torchscriptable_t5_with_torchtext.ipynb | 768 ++++++++++++++++++
 1 file changed, 768 insertions(+)
 create mode 100644 notebooks/torchscriptable_t5_with_torchtext.ipynb

diff --git a/notebooks/torchscriptable_t5_with_torchtext.ipynb b/notebooks/torchscriptable_t5_with_torchtext.ipynb
new file mode 100644
index 0000000000..d7e89cd3b6
--- /dev/null
+++ b/notebooks/torchscriptable_t5_with_torchtext.ipynb
@@ -0,0 +1,768 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "customInput": null,
+    "originalKey": "c5568aa5-5a8b-410f-bbab-bf6069d4c461",
+    "showInput": true
+   },
+   "source": [
+    "# TorchScriptable T5 with TorchText\n",
+    "## Motivation\n",
+    "\n",
+    "[TorchScript](https://pytorch.org/docs/stable/jit.html) is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency, such as a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.\n",
+    "\n",
+    "TorchText recently introduced the [GenerationUtils](https://github.com/pytorch/text/blob/1b72eba0a07295d74d168c99fd8a5586a0943aa3/torchtext/prototype/generate.py#L13) functionality. It allows wrapping TorchText's [T5Model](https://github.com/pytorch/text/blob/670e52a3df658f6332f2904cfed67308f3f5adce/torchtext/models/t5/model.py#L67) and using it to generate text in a similar way to the [HuggingFace 'generate'](https://huggingface.co/docs/transformers/v4.27.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) function. However, although both T5Model and its tokenizer are TorchScriptable on their own, this property is not preserved after wrapping the model with GenerationUtils.\n",
+    "\n",
+    "\n",
+    "## Technical details\n",
+    "We've implemented a \"Hacky\" solution for wrapping the full 'generate' functionality inside a \"forward\" function. We will work with the PyTorch team, and **hopefully this code (with some modifications) can later be added to TorchText's T5Model.**\n",
+    "\n",
+    "To do so, we:\n",
+    "1. Added T5TorchGenerative, a subclass of T5Model, in which we:\n",
+    "- extracted the decoding code from the t5.forward() function into a standalone 'decode' function that returns a specific type.\n",
+    "- added GenerationUtils's 'generate' functionality as a class method (similar to HuggingFace).\n",
+    "2. 
Added TorchScriptableT5, a module that implements the full generative logic in the forward method.\n", + "3. Helper classes that build a jit (TorchScript) model from a predefined T5 Bundle\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "originalKey": "01366f6a-342d-469a-bbc5-8b21535d768e", + "showInput": false + }, + "source": [ + "## Issue Example\n", + "Currently, this code raises an exception." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679307290669, + "executionStopTime": 1679307294521, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "a415e457-2a5d-4ded-9260-4abe5389c627", + "requestMsgId": "0cef735b-b207-469d-9ff5-f39c4b0b9806", + "showInput": true + }, + "outputs": [ + { + "ename": "NotSupportedError", + "evalue": "Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:\n File \"/Users/rbahumi/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torchtext/prototype/generate.py\", line 37\n def __init__(self, model: nn.Module, **kwargs) -> None:\n ~~~~~~~ <--- HERE\n self.model = model\n self.is_encoder_decoder = kwargs.pop(\"is_encoder_decoder\", True)\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotSupportedError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 19\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# But after wrapping with GenerationUtils, the model is no longer torchscriptable\u001b[39;00m\n\u001b[1;32m 18\u001b[0m generative_model \u001b[38;5;241m=\u001b[39m GenerationUtils(model)\n\u001b[0;32m---> 19\u001b[0m generative_model_jit \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscript\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgenerative_model\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/_script.py:1351\u001b[0m, in \u001b[0;36mscript\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m 1349\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn\n\u001b[1;32m 1350\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_recursive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_script_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/_recursive.py:448\u001b[0m, in \u001b[0;36mcreate_script_class\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m 446\u001b[0m rcb \u001b[38;5;241m=\u001b[39m _jit_internal\u001b[38;5;241m.\u001b[39mcreateResolutionCallbackForClassMethods(\u001b[38;5;28mtype\u001b[39m(obj))\n\u001b[1;32m 447\u001b[0m \u001b[38;5;66;03m# Script the type of obj if it hasn't already been scripted.\u001b[39;00m\n\u001b[0;32m--> 448\u001b[0m \u001b[43m_compile_and_register_class\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mrcb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mqualified_class_name\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 449\u001b[0m class_ty \u001b[38;5;241m=\u001b[39m _python_cu\u001b[38;5;241m.\u001b[39mget_class(qualified_class_name)\n\u001b[1;32m 450\u001b[0m \u001b[38;5;66;03m# Create an empty torch._C.ScriptObject with the scripted type.\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/_recursive.py:49\u001b[0m, in \u001b[0;36m_compile_and_register_class\u001b[0;34m(obj, rcb, qualified_name)\u001b[0m\n\u001b[1;32m 46\u001b[0m script_class \u001b[38;5;241m=\u001b[39m _get_script_class(obj)\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m script_class:\n\u001b[0;32m---> 49\u001b[0m ast \u001b[38;5;241m=\u001b[39m \u001b[43mget_jit_class_def\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m defaults \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mfrontend\u001b[38;5;241m.\u001b[39mget_default_args_for_class(obj)\n\u001b[1;32m 51\u001b[0m script_class \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_jit_script_class_compile(qualified_name, ast, defaults, rcb)\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/frontend.py:234\u001b[0m, in \u001b[0;36mget_jit_class_def\u001b[0;34m(cls, self_name)\u001b[0m\n\u001b[1;32m 231\u001b[0m func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name)\n\u001b[1;32m 232\u001b[0m _jit_internal\u001b[38;5;241m.\u001b[39mloader\u001b[38;5;241m.\u001b[39mcache(func, parsed_def\u001b[38;5;241m.\u001b[39msource)\n\u001b[0;32m--> 234\u001b[0m method_defs \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 235\u001b[0m get_jit_def(obj, name, self_name\u001b[38;5;241m=\u001b[39mself_name, is_classmethod\u001b[38;5;241m=\u001b[39mis_classmethod(obj))\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (name, obj) \u001b[38;5;129;01min\u001b[39;00m methods\n\u001b[1;32m 237\u001b[0m ]\n\u001b[1;32m 238\u001b[0m properties \u001b[38;5;241m=\u001b[39m get_class_properties(\u001b[38;5;28mcls\u001b[39m, self_name)\n\u001b[1;32m 240\u001b[0m leading_whitespace_len \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(source\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m]) \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mlen\u001b[39m(dedent_src\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m])\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/frontend.py:235\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 231\u001b[0m func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name)\n\u001b[1;32m 232\u001b[0m _jit_internal\u001b[38;5;241m.\u001b[39mloader\u001b[38;5;241m.\u001b[39mcache(func, parsed_def\u001b[38;5;241m.\u001b[39msource)\n\u001b[1;32m 234\u001b[0m method_defs \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 
235\u001b[0m \u001b[43mget_jit_def\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mself_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_classmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_classmethod\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (name, obj) \u001b[38;5;129;01min\u001b[39;00m methods\n\u001b[1;32m 237\u001b[0m ]\n\u001b[1;32m 238\u001b[0m properties \u001b[38;5;241m=\u001b[39m get_class_properties(\u001b[38;5;28mcls\u001b[39m, self_name)\n\u001b[1;32m 240\u001b[0m leading_whitespace_len \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(source\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m]) \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mlen\u001b[39m(dedent_src\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m0\u001b[39m])\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/frontend.py:297\u001b[0m, in \u001b[0;36mget_jit_def\u001b[0;34m(fn, def_name, self_name, is_classmethod)\u001b[0m\n\u001b[1;32m 294\u001b[0m qualname \u001b[38;5;241m=\u001b[39m get_qualified_name(fn)\n\u001b[1;32m 295\u001b[0m pdt_arg_types \u001b[38;5;241m=\u001b[39m type_trace_db\u001b[38;5;241m.\u001b[39mget_args_types(qualname)\n\u001b[0;32m--> 297\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbuild_def\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparsed_def\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfn_def\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtype_line\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdef_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mself_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpdt_arg_types\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpdt_arg_types\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/frontend.py:335\u001b[0m, in \u001b[0;36mbuild_def\u001b[0;34m(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)\u001b[0m\n\u001b[1;32m 330\u001b[0m body \u001b[38;5;241m=\u001b[39m py_def\u001b[38;5;241m.\u001b[39mbody\n\u001b[1;32m 331\u001b[0m r \u001b[38;5;241m=\u001b[39m ctx\u001b[38;5;241m.\u001b[39mmake_range(py_def\u001b[38;5;241m.\u001b[39mlineno,\n\u001b[1;32m 332\u001b[0m py_def\u001b[38;5;241m.\u001b[39mcol_offset,\n\u001b[1;32m 333\u001b[0m py_def\u001b[38;5;241m.\u001b[39mcol_offset \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdef\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m--> 335\u001b[0m param_list \u001b[38;5;241m=\u001b[39m \u001b[43mbuild_param_list\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpy_def\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mself_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpdt_arg_types\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 336\u001b[0m return_type \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(py_def, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreturns\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+    "File \u001b[0;32m~/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torch/jit/frontend.py:359\u001b[0m, in \u001b[0;36mbuild_param_list\u001b[0;34m(ctx, py_args, self_name, pdt_arg_types)\u001b[0m\n\u001b[1;32m 357\u001b[0m expr \u001b[38;5;241m=\u001b[39m py_args\u001b[38;5;241m.\u001b[39mkwarg\n\u001b[1;32m 358\u001b[0m ctx_range \u001b[38;5;241m=\u001b[39m ctx\u001b[38;5;241m.\u001b[39mmake_range(expr\u001b[38;5;241m.\u001b[39mlineno, expr\u001b[38;5;241m.\u001b[39mcol_offset \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m, expr\u001b[38;5;241m.\u001b[39mcol_offset \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlen\u001b[39m(expr\u001b[38;5;241m.\u001b[39marg))\n\u001b[0;32m--> 359\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m NotSupportedError(ctx_range, _vararg_kwarg_err)\n\u001b[1;32m 360\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m py_args\u001b[38;5;241m.\u001b[39mvararg \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 361\u001b[0m expr \u001b[38;5;241m=\u001b[39m py_args\u001b[38;5;241m.\u001b[39mvararg\n",
+    "\u001b[0;31mNotSupportedError\u001b[0m: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:\n File \"/Users/rbahumi/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torchtext/prototype/generate.py\", line 37\n def __init__(self, model: nn.Module, **kwargs) -> None:\n ~~~~~~~ <--- HERE\n self.model = model\n self.is_encoder_decoder = kwargs.pop(\"is_encoder_decoder\", True)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "import torch\n",
+    "from torchtext.prototype.generate import GenerationUtils\n",
+    "from torchtext.models import T5_SMALL_GENERATION\n",
+    "\n",
+    "# The tokenizer object is torchscriptable\n",
+    "tokenizer = T5_SMALL_GENERATION.transform()\n",
+    "tokenizer_jit = torch.jit.script(tokenizer)\n",
+    "\n",
+    "# The T5 model is also torchscriptable\n",
+    "model = T5_SMALL_GENERATION.get_model()\n",
+    "model_jit = torch.jit.script(model)\n",
+    "\n",
+    "\n",
+    "# But after wrapping with GenerationUtils, the model is no longer torchscriptable\n",
+    "generative_model = GenerationUtils(model)\n",
+    "generative_model_jit = torch.jit.script(generative_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "customInput": null,
+    "originalKey": "e6d9ecd2-3354-4425-bdbe-7a1cb5681933",
+    "showInput": false
+   },
+   "source": [
+    "This failure is caused by:\n",
+    "1. The use of keyword arguments taken from a dictionary (`**kwargs`)\n",
+    "2. Functions that can accept Optional values\n",
+    "3. Multiple options for return types.\n",
+    "\n",
+    "\n",
+    "In the next section, we suggest a (currently) hacky solution to these problems."
+   ]
+  },
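+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The toy class below is an added sketch (independent of TorchText) that reproduces the first limitation in isolation: when `torch.jit.script` compiles a plain Python class, it compiles `__init__` as well, so a `**kwargs` signature like the one in `GenerationUtils.__init__` is rejected. This is illustrative code, not library code.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "\n",
+    "class KwargsWrapper:\n",
+    "    # A toy stand-in for GenerationUtils: a plain class configured via **kwargs\n",
+    "    def __init__(self, model: nn.Module, **kwargs) -> None:\n",
+    "        self.model = model\n",
+    "        self.is_encoder_decoder = kwargs.pop(\"is_encoder_decoder\", True)\n",
+    "\n",
+    "    def generate(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        return self.model(x)\n",
+    "\n",
+    "\n",
+    "try:\n",
+    "    torch.jit.script(KwargsWrapper(nn.Linear(2, 2)))\n",
+    "except Exception as e:\n",
+    "    # TorchScript rejects signatures that use *args/**kwargs\n",
+    "    print(type(e).__name__)"
+   ]
+  },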
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "originalKey": "26118f6c-ccbc-4ac3-8922-ba1a8a0497e8", + "showInput": false + }, + "source": [ + "# Generate results using T5TorchGenerative\n", + "We'll define a new class called T5TorchGenerative (subclass of T5Model) that will \n", + "\n", + "We've implemented a \"Hacky\" solution for wrapping the full 'generate' functionality inside a \"forward\" function. We will work with the Pytorch team to make it an appropriate pull request.\n", + "\n", + "To do so, we:\n", + "1. T5TorchGenerative: inherited from T5Model:\n", + "- extracted the decoding code from t5.forward() function to a standalone 'decode' function that returns a specific type. \n", + "- added the GenerationUtils's 'generate' functionality as a class method (similar to HuggingFace).\n", + "2. Added TorchScriptableT5, a module that implements the full generative logic in the forward method.\n", + "3. Helper classes that build a jit (TorchScript) model from a predefined T5 Bundle" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679307299442, + "executionStopTime": 1679307299550, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "9aec1d2d-3eb3-4843-a25b-8ef0faa0b62f", + "requestMsgId": "d25f4230-cd85-4d1c-9fbc-295f1cf497f3", + "showInput": true + }, + "outputs": [], + "source": [ + "from typing import List, Optional\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor\n", + "from torchtext.models import T5Model, T5Conf\n", + "from torchtext.models.t5.modules import PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder, ENCODER_OUTPUTS_TYPE\n", + "\n", + "\n", + "DEFAULT_MAX_SEQ_LEN = 256\n", + "\n", + "\n", + "class T5TorchGenerative(T5Model):\n", + " \"\"\"\n", + " This is a quick and dirty implementation for the T5Model model which encapsulates the GenerationUtils functionality\n", + " inside the instance.\n", + "\n", + " Motivation: the ability to make a generate functionality TorchScriptable.\n", + "\n", + " TODO: implement beam search once it is added to GenerationUtils.\n", + "\n", + " \"\"\"\n", + " @torch.jit.export\n", + " def _prepare_decoder_ids_for_generation(\n", + " self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None\n", + " ):\n", + " return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx\n", + "\n", + " @torch.jit.export\n", + " def decode(\n", + " self,\n", + " encoder_outputs: ENCODER_OUTPUTS_TYPE,\n", + " decoder_tokens: Optional[Tensor] = None,\n", + " encoder_mask: Optional[Tensor] = None,\n", + " decoder_mask: Optional[Tensor] = None,\n", + " encoder_padding_mask: Optional[Tensor] = None,\n", + " decoder_padding_mask: Optional[Tensor] = None,\n", + " past_key_values: Optional[List[PAST_KEY_VALUES_TYPE]] = None,\n", + " return_past_key_values: bool = False,\n", + " ) -> Tensor:\n", + " \"\"\"\n", + " This method's code was copied from the T5Model::forward() function. \n", + " It only does the decoder part, and returns a tensor instead of multiple return values wrapped in a dictionary type.\n", + "\n", + " In the future, it might be helpful if we can call this function from forward, and remove the duplicate code. 
\n", + " \"\"\"\n", + "\n", + " assert self.decoder is not None\n", + " assert encoder_outputs is not None\n", + "\n", + " encoder_output = encoder_outputs.get(\"encoder_output\")\n", + " assert torch.jit.isinstance(encoder_output, Tensor)\n", + "\n", + " batch_size = encoder_output.size(0)\n", + " encoder_output_device = encoder_output.device\n", + "\n", + " # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.\n", + " if decoder_tokens is None:\n", + " decoder_tokens = (\n", + " torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx\n", + " )\n", + "\n", + " if decoder_padding_mask is None:\n", + " decoder_padding_mask = decoder_tokens.eq(self.padding_idx)\n", + " # T5 implemention uses padding idx to start sequence. Want to ignore this when masking\n", + " decoder_padding_mask[:, 0] = False\n", + "\n", + " decoder_embeddings = self.token_embeddings(decoder_tokens)\n", + " decoder_outputs = self.decoder(\n", + " decoder_embeddings,\n", + " memory=encoder_output,\n", + " tgt_mask=decoder_mask,\n", + " memory_mask=encoder_mask,\n", + " tgt_key_padding_mask=decoder_padding_mask,\n", + " memory_key_padding_mask=encoder_padding_mask,\n", + " past_key_values=past_key_values,\n", + " return_past_key_values=return_past_key_values,\n", + " )\n", + "\n", + " decoder_output = decoder_outputs.get(\"decoder_output\")\n", + " assert torch.jit.isinstance(decoder_output, Tensor)\n", + "\n", + " if self.linear_head:\n", + " assert self.lm_head is not None\n", + " # Rescale output before projecting on vocab. This happens when the encoder and decoder share the\n", + " # same word embeddings, which is always the case in our t5 implementation.\n", + " # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661\n", + " decoder_output = decoder_output * (self.embedding_dim ** -0.5)\n", + " decoder_output = self.lm_head(decoder_output)\n", + "\n", + " return decoder_output\n", + "\n", + " @torch.jit.export\n", + " def greedy_search(\n", + " self, input_ids: torch.Tensor, max_length: int, eos_idx: int, encoder_outputs: ENCODER_OUTPUTS_TYPE, pad_idx: Optional[int] = None,\n", + " ) -> torch.Tensor:\n", + " \"\"\"Greedy search decoding for text generation. 
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "customInput": null,
+    "originalKey": "6e4c730f-d5e0-46bc-aae2-462a639dd4a1",
+    "showInput": true
+   },
+   "outputs": [],
+   "source": [
+    "from typing import Optional, Union, Dict, Any\n",
+    "from torchtext import _TEXT_BUCKET\n",
+    "from urllib.parse import urljoin\n",
+    "from torchtext._download_hooks import load_state_dict_from_url\n",
+    "\n",
+    "\n",
+    "def build_model(\n",
+    "    config: T5Conf,\n",
+    "    T5Class=T5Model,\n",
+    "    freeze_model: bool = False,\n",
+    "    checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,\n",
+    "    strict: bool = False,\n",
+    "    dl_kwargs: Optional[Dict[str, Any]] = None,\n",
+    ") -> T5Model:\n",
+    "    \"\"\"Builder method that can override the default T5Model model class.\n",
+    "\n",
+    "    (reference: https://github.com/pytorch/text/blob/a1dc61b8e80df70fe7a35b9f5f5cc7e19c7dd8a3/torchtext/models/t5/bundler.py#L113)\n",
+    "\n",
+    "    Args:\n",
+    "        config (T5Conf): An instance of T5Conf that defines the model configuration\n",
+    "        T5Class: The model class to instantiate. (Default: ``T5Model``)\n",
+    "        freeze_model (bool): Indicates whether to freeze the model weights. (Default: ``False``)\n",
+    "        checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)\n",
+    "        strict (bool): Passed to :func:`torch.nn.Module.load_state_dict` method. (Default: ``False``)\n",
+    "        dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)\n",
+    "    \"\"\"\n",
+    "    model = T5Class(config, freeze_model)\n",
+    "    if checkpoint is not None:\n",
+    "        if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):\n",
+    "            state_dict = checkpoint\n",
+    "        elif isinstance(checkpoint, str):\n",
+    "            dl_kwargs = {} if dl_kwargs is None else dl_kwargs\n",
+    "            state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)\n",
+    "        else:\n",
+    "            raise TypeError(\n",
+    "                \"checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}\".format(type(checkpoint))\n",
+    "            )\n",
+    "\n",
+    "        model.load_state_dict(state_dict, strict=strict)\n",
+    "\n",
+    "    return model\n",
+    "\n",
+    "\n",
+    "def load_model(bundle, T5Class=T5TorchGenerative):\n",
+    "    \"\"\"Build a model of type T5Class from a predefined bundle's config and checkpoint.\n",
+    "\n",
+    "    Example usage:\n",
+    "    >>> model = load_model(bundle=T5_SMALL_GENERATION, T5Class=T5TorchGenerative)\n",
+    "    \"\"\"\n",
+    "    return build_model(config=bundle.config, T5Class=T5Class, checkpoint=bundle._path)\n",
+    "\n",
+    "\n",
+    "def get_model_from_bundle(bundle):\n",
+    "    model = load_model(bundle=bundle, T5Class=T5TorchGenerative)\n",
+    "    tokenizer = bundle.transform()\n",
+    "    full_model = TorchScriptableT5(model=model, transform=tokenizer)\n",
+    "    return full_model\n",
+    "\n",
+    "\n",
+    "def get_jit_from_bundle(bundle):\n",
+    "    full_model = get_model_from_bundle(bundle)\n",
+    "    full_model_jit = torch.jit.script(full_model)\n",
+    "    return full_model_jit"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "customInput": null,
+    "originalKey": "aaae0f95-a9c8-4704-8d5d-924aa7084829",
+    "showInput": true
+   },
+   "outputs": [],
+   "source": [
+    "from typing import List, Union\n",
+    "\n",
+    "DEFAULT_MAX_LENGTH: int = 100\n",
+    "\n",
+    "\n",
+    "class TorchScriptableT5(torch.nn.Module):\n",
+    "    def __init__(self, model, transform, cuda: bool = False):\n",
+    "        super(TorchScriptableT5, self).__init__()\n",
+    "        # Store the flag under a different name to avoid shadowing nn.Module.cuda()\n",
+    "        self.use_cuda = cuda\n",
+    "        self.transform = transform\n",
+    "\n",
+    "        if cuda:\n",
+    "            model = model.cuda()\n",
+    "\n",
+    "        self.model = model\n",
+    "        self.model.eval()\n",
+    "\n",
+    "    def forward(self, texts: List[str], max_length: int = DEFAULT_MAX_LENGTH) -> Union[List[str], str]:\n",
+    "        input_ids = self.transform(texts)\n",
+    "        if self.use_cuda:\n",
+    "            input_ids = input_ids.cuda()\n",
+    "        raw_outputs = self.model.generate(input_ids, max_length=max_length)\n",
+    "\n",
+    "        if raw_outputs.dim() == 1:\n",
+    "            raw_outputs_list: List[List[int]] = raw_outputs[None, :].tolist()\n",
+    "        else:\n",
+    "            raw_outputs_list: List[List[int]] = raw_outputs.tolist()\n",
+    "\n",
+    "        res = self.transform.decode(raw_outputs_list)\n",
+    "        return res"
+   ]
+  },
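+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, the added cell below (not part of the original run) scripts the small bundle end to end. It assumes the T5_SMALL_GENERATION weights can be downloaded in your environment; the exact output text may differ.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchtext.models import T5_SMALL_GENERATION\n",
+    "\n",
+    "# Script the small bundle and run a single prompt through the full pipeline\n",
+    "t5_small_jit = get_jit_from_bundle(T5_SMALL_GENERATION)\n",
+    "print(t5_small_jit([\"translate English to German: The house is wonderful.\"], max_length=32))"
+   ]
+  },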
"from typing import Optional, Union, Dict, Any\n", + "from torchtext import _TEXT_BUCKET\n", + "from urllib.parse import urljoin\n", + "from torchtext._download_hooks import load_state_dict_from_url\n", + "\n", + "\n", + "def build_model(\n", + " config: T5Conf,\n", + " T5Class=T5Model,\n", + " freeze_model: bool = False,\n", + " checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,\n", + " strict: bool = False,\n", + " dl_kwargs: Optional[Dict[str, Any]] = None,\n", + ") -> T5Model:\n", + " \"\"\"Class builder method that can overide the default T5Model model class \n", + " \n", + " (reference: https://github.com/pytorch/text/blob/a1dc61b8e80df70fe7a35b9f5f5cc7e19c7dd8a3/torchtext/models/t5/bundler.py#L113)\n", + " \n", + " Args:\n", + " config (T5Conf): An instance of classT5Conf that defined the model configuration\n", + " freeze_model (bool): Indicates whether to freeze the model weights. (Default: `False`)\n", + " checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``)\n", + " strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: `False`)\n", + " dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`)\n", + " \"\"\"\n", + " model = T5Class(config, freeze_model)\n", + " if checkpoint is not None:\n", + " if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):\n", + " state_dict = checkpoint\n", + " elif isinstance(checkpoint, str):\n", + " dl_kwargs = {} if dl_kwargs is None else dl_kwargs\n", + " state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)\n", + " else:\n", + " raise TypeError(\n", + " \"checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}\".format(type(checkpoint))\n", + " )\n", + "\n", + " model.load_state_dict(state_dict, strict=strict)\n", + "\n", + " return model\n", + "\n", + "\n", + "def load_model(bundle, T5Class=T5TorchGenerative):\n", + " \"\"\"\n", + " \n", + " Example usage:\n", + " >> model = load_model(bundle=T5_SMALL_GENERATION, T5Class=T5TorchGenerative)\n", + " \"\"\"\n", + " return build_model(config=bundle.config, T5Class=T5Class, checkpoint=bundle._path)\n", + "\n", + "\n", + "def get_model_from_bundle(bundle, cuda=False):\n", + " model = load_model(bundle=bundle, T5Class=T5TorchGenerative)\n", + " tokenizer = bundle.transform()\n", + " full_model = TorchScriptableT5(model=model, transform=tokenizer, cuda=cuda)\n", + " return full_model\n", + "\n", + "def get_jit_from_bundle(bundle, cuda=False):\n", + " full_model = get_model_from_bundle(bundle, cuda=cuda)\n", + " full_model_jit = torch.jit.script(full_model)\n", + " return full_model_jit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "originalKey": "e8fcc480-0dc3-4081-b5d1-dc4f74f1eb9e", + "showInput": false + }, + "source": [ + "# The new model is an E2E generation model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679307307093, + "executionStopTime": 1679307307174, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "07417dd1-6387-49bc-b815-69f6fffb9cff", + "requestMsgId": "b48eafaa-9445-4047-92dc-5094efa9408c", + "showInput": true + }, + "outputs": [], + "source": [ + "SUMMERIZE_PROMP = \"summarize\"\n", + "TRANSLATE_TO_GERMAN = \"translate English to German\"\n", + "QUESTION_PROMPS = \"question\"\n", + 
"CONTEXT_PROMPT = \"context\"\n", + "\n", + "\n", + "def summarize_text(text):\n", + " return f\"{SUMMERIZE_PROMP}: {text}\"\n", + "\n", + "\n", + "def en_to_german_text(text):\n", + " return f\"{TRANSLATE_TO_GERMAN}: {text}\"\n", + "\n", + "\n", + "def qa_text(context, question):\n", + " return f\"{QUESTION_PROMPS}: {question}? {CONTEXT_PROMPT}: {context}\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679307307671, + "executionStopTime": 1679307307685, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "36fb048b-7e42-4286-9fe4-35ac9d642574", + "requestMsgId": "e8a145b0-293b-4a21-ab8e-9877d57321cd", + "showInput": true + }, + "outputs": [], + "source": [ + "from torchtext.models import T5_SMALL_GENERATION, T5_LARGE_GENERATION, T5_3B_GENERATION, T5_11B_GENERATION" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679307311187, + "executionStopTime": 1679307355677, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "1be15ada-4e98-4302-ae01-cb6dd795aa99", + "requestMsgId": "5d59d649-02a2-4087-a7ce-95a705cf591d", + "showInput": true + }, + "outputs": [], + "source": [ + "EXAMPLE_INPUT = [\n", + " 'question: What does Nir likes to eat? context: Nir is a PM on the Care AI team. Nir only eats vegeterian food and he loves Pizza',\n", + " 'question: Who likes to eat pizza? context: Nir is a PM on the Care AI team. Nir only eats vegeterian food and he loves Pizza',\n", + " \"summarize: studies say that owning a dog is good for you\",\n", + "]\n", + "\n", + "t5_large = get_jit_from_bundle(T5_LARGE_GENERATION)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679308141879, + "executionStopTime": 1679308149446, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "4a8e3147-4281-4835-b0a9-32b88863e351", + "requestMsgId": "f9d1bd7b-763f-43db-a0c9-6b6579955a32", + "showInput": true + }, + "outputs": [], + "source": [ + "%time t5_large(EXAMPLE_INPUT, max_length=100)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1679308151928, + "executionStopTime": 1679308154907, + "jupyter": { + "outputs_hidden": false + }, + "originalKey": "382ff62f-6c73-4774-b046-cf68212048ae", + "requestMsgId": "76d1674d-14e5-400e-a4ab-bf48b510ac98", + "showInput": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.93 s, sys: 47.5 ms, total: 2.98 s\n", + "Wall time: 2.95 s\n" + ] + }, + { + "data": { + "text/plain": [ + "['Pizza',\n", + " 'Nir',\n", + " 'studies say owning a dog is good for you . 
+ ],
+ "metadata": {
+  "bento_stylesheets": {
+   "bento/extensions/flow/main.css": true,
+   "bento/extensions/kernel_selector/main.css": true,
+   "bento/extensions/kernel_ui/main.css": true,
+   "bento/extensions/new_kernel/main.css": true,
+   "bento/extensions/system_usage/main.css": true,
+   "bento/extensions/theme/main.css": true
+  },
+  "dataExplorerConfig": {},
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  },
+  "last_base_url": "https://bento.edge.x2p.facebook.net/",
+  "last_kernel_id": "2e805df5-4331-4d4d-bcf5-5867cccba280",
+  "last_msg_id": "fb91432f-260f2394c90f5f85446bf740_1027",
+  "last_server_session_id": "fdf51b3a-d901-4add-92aa-395a7bc782bd",
+  "outputWidgetContext": {}
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}