diff --git a/notebooks/02_how_to_generate.ipynb b/notebooks/02_how_to_generate.ipynb index 8f0bacb26b..478ea42a0a 100644 --- a/notebooks/02_how_to_generate.ipynb +++ b/notebooks/02_how_to_generate.ipynb @@ -1,26 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "02_how_to_generate.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" @@ -29,8 +13,8 @@ { "cell_type": "markdown", "metadata": { - "id": "Vp3XPuaTu9jl", - "colab_type": "text" + "colab_type": "text", + "id": "Vp3XPuaTu9jl" }, "source": [ "\n", @@ -40,8 +24,8 @@ { "cell_type": "markdown", "metadata": { - "id": "KxLvv6UaPa33", - "colab_type": "text" + "colab_type": "text", + "id": "KxLvv6UaPa33" }, "source": [ "### **Introduction**\n", @@ -64,8 +48,8 @@ { "cell_type": "markdown", "metadata": { - "id": "Si4GyYhOQMzi", - "colab_type": "text" + "colab_type": "text", + "id": "Si4GyYhOQMzi" }, "source": [ - "Let's quickly install transformers and load the model. We will use GPT2 in Tensorflow 2.1 for demonstration, but the API is 1-to-1 the same for PyTorch." + "Let's quickly install transformers and load the model. We will use GPT2 in TensorFlow 2 for demonstration, but the API is 1-to-1 the same for PyTorch."
@@ -73,25 +57,27 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "XbzZ_IVTtoQe", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "XbzZ_IVTtoQe" }, + "outputs": [], "source": [ "!pip install -q git+https://github.com/huggingface/transformers.git\n", - "!pip install -q tensorflow==2.1" - ], - "execution_count": 0, - "outputs": [] + "!pip install -q \"tensorflow>=2.12\"\n" ] }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "ue2kOQhXTAMU", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "ue2kOQhXTAMU" }, + "outputs": [], "source": [ "import tensorflow as tf\n", "from transformers import TFGPT2LMHeadModel, GPT2Tokenizer\n", @@ -101,15 +87,13 @@ "\n", "# add the EOS token as PAD token to avoid warnings\n", "model = TFGPT2LMHeadModel.from_pretrained(\"gpt2\", pad_token_id=tokenizer.eos_token_id)" - ], - "execution_count": 0, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "a8Y7cgu9ohXP", - "colab_type": "text" + "colab_type": "text", + "id": "a8Y7cgu9ohXP" }, "source": [ "### **Greedy Search**\n", @@ -126,28 +110,19 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "OWLd_J6lXz_t", - "colab_type": "code", - "outputId": "3b9dfd1e-21e6-44f4-f27f-8e975010f9af", "colab": { "base_uri": "https://localhost:8080/", "height": 122 - } + }, + "colab_type": "code", + "id": "OWLd_J6lXz_t", + "outputId": "3b9dfd1e-21e6-44f4-f27f-8e975010f9af" }, - "source": [ - "# encode context the generation is conditioned on\n", - "input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='tf')\n", - "\n", - "# generate text until the output length (which includes the context length) reaches 50\n", - "greedy_output = model.generate(input_ids, max_length=50)\n", - "\n", - "print(\"Output:\\n\" + 100 * '-')\n", - "print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": 
"stream", "text": [ "Output:\n", @@ -155,16 +130,25 @@ "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.\n", "\n", "I'm not sure if I'll\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "# encode context the generation is conditioned on\n", + "input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='tf')\n", + "\n", + "# generate text until the output length (which includes the context length) reaches 50\n", + "greedy_output = model.generate(input_ids, max_length=50)\n", + "\n", + "print(\"Output:\\n\" + 100 * '-')\n", + "print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))" ] }, { "cell_type": "markdown", "metadata": { - "id": "BBn1ePmJvhrl", - "colab_type": "text" + "colab_type": "text", + "id": "BBn1ePmJvhrl" }, "source": [ "Alright! We have generated our first short text with GPT2 😊. The generated words following the context are reasonable, but the model quickly starts repeating itself! 
This is a very common problem in language generation in general and seems to be even more so in greedy and beam search - check out [Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424) and [Shao et al., 2017](https://arxiv.org/abs/1701.03185).\n", @@ -179,8 +163,8 @@ { "cell_type": "markdown", "metadata": { - "id": "g8DnXZ1WiuNd", - "colab_type": "text" + "colab_type": "text", + "id": "g8DnXZ1WiuNd" }, "source": [ "### **Beam search**\n", @@ -198,15 +182,29 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "R1R5kx30Ynej", - "colab_type": "code", - "outputId": "574f068b-f418-48b5-8334-8451d2221032", "colab": { "base_uri": "https://localhost:8080/", "height": 102 - } + }, + "colab_type": "code", + "id": "R1R5kx30Ynej", + "outputId": "574f068b-f418-48b5-8334-8451d2221032" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.\n", + "\n", + "I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll\n" + ] + } + ], "source": [ "# activate beam search and early_stopping\n", "beam_output = model.generate(\n", @@ -218,27 +216,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(beam_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.\n", - "\n", - "I'm not sure if I'll ever be able to walk with him again. 
I'm not sure if I'll\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "AZ6xs-KLi9jT", - "colab_type": "text" + "colab_type": "text", + "id": "AZ6xs-KLi9jT" }, "source": [ "While the result is arguably more fluent, the output still includes repetitions of the same word sequences. \n", @@ -249,15 +233,29 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "jy3iVJgfnkMi", - "colab_type": "code", - "outputId": "4d3e6511-711a-4594-a715-aaeb6e48e1a9", "colab": { "base_uri": "https://localhost:8080/", "height": 102 - } + }, + "colab_type": "code", + "id": "jy3iVJgfnkMi", + "outputId": "4d3e6511-711a-4594-a715-aaeb6e48e1a9" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.\n", + "\n", + "I've been thinking about this for a while now, and I think it's time for me to take a break\n" + ] + } + ], "source": [ "# set no_repeat_ngram_size to 2\n", "beam_output = model.generate(\n", @@ -270,27 +268,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(beam_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.\n", - "\n", - "I've been thinking about this for a while now, and I think it's time for me to take a break\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "nxsksOGDpmA0", - "colab_type": "text" + "colab_type": "text", + "id": "nxsksOGDpmA0" }, "source": [ "Nice, that looks much better! 
We can see that the repetition does not appear anymore. Nevertheless, *n-gram* penalties have to be used with care. An article generated about the city *New York* should not use a *2-gram* penalty, or otherwise the name of the city would only appear once in the whole text!\n", @@ -302,34 +286,19 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "5ClO3VphqGp6", - "colab_type": "code", - "outputId": "2296891c-024f-4fd2-9071-bff7c11a3e04", "colab": { "base_uri": "https://localhost:8080/", "height": 306 - } + }, + "colab_type": "code", + "id": "5ClO3VphqGp6", + "outputId": "2296891c-024f-4fd2-9071-bff7c11a3e04" }, - "source": [ - "# set return_num_sequences > 1\n", - "beam_outputs = model.generate(\n", - " input_ids, \n", - " max_length=50, \n", - " num_beams=5, \n", - " no_repeat_ngram_size=2, \n", - " num_return_sequences=5, \n", - " early_stopping=True\n", - ")\n", - "\n", - "# now we have 3 output sequences\n", - "print(\"Output:\\n\" + 100 * '-')\n", - "for i, beam_output in enumerate(beam_outputs):\n", - " print(\"{}: {}\".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Output:\n", @@ -349,16 +318,31 @@ "4: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.\n", "\n", "I've been thinking about this for a while now, and I think it's time for me to take a step\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "# set num_return_sequences > 1\n", + "beam_outputs = model.generate(\n", + " input_ids, \n", + " max_length=50, \n", + " num_beams=5, \n", + " no_repeat_ngram_size=2, \n", + " num_return_sequences=5, \n", + " early_stopping=True\n", + ")\n", + "\n", + "# now we have 5 output sequences\n", + "print(\"Output:\\n\" + 100 * '-')\n", + "for i, beam_output in enumerate(beam_outputs):\n", + " print(\"{}: {}\".format(i, tokenizer.decode(beam_output, 
skip_special_tokens=True)))" ] }, { "cell_type": "markdown", "metadata": { - "id": "HhLKyfdbsjXc", - "colab_type": "text" + "colab_type": "text", + "id": "HhLKyfdbsjXc" }, "source": [ "As can be seen, the five beam hypotheses are only marginally different to each other - which should not be too surprising when using only 5 beams.\n", @@ -380,8 +364,8 @@ { "cell_type": "markdown", "metadata": { - "id": "XbbIyK84wHq6", - "colab_type": "text" + "colab_type": "text", + "id": "XbbIyK84wHq6" }, "source": [ "### **Sampling**\n", @@ -402,15 +386,31 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "colab_type": "code", - "outputId": "1b78d191-15f6-4cbe-e2b1-23c77366fc21", - "id": "aRAz4D-Ks0_4", "colab": { "base_uri": "https://localhost:8080/", "height": 136 - } + }, + "colab_type": "code", + "id": "aRAz4D-Ks0_4", + "outputId": "1b78d191-15f6-4cbe-e2b1-23c77366fc21" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog. He just gave me a whole new hand sense.\"\n", + "\n", + "But it seems that the dogs have learned a lot from teasing at the local batte harness once they take on the outside.\n", + "\n", + "\"I take\n" + ] + } + ], "source": [ "# set seed to reproduce results. Feel free to change the seed though to get different results\n", "tf.random.set_seed(0)\n", @@ -425,29 +425,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog. 
He just gave me a whole new hand sense.\"\n", - "\n", - "But it seems that the dogs have learned a lot from teasing at the local batte harness once they take on the outside.\n", - "\n", - "\"I take\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "mQHuo911wfT-", - "colab_type": "text" + "colab_type": "text", + "id": "mQHuo911wfT-" }, "source": [ "Interesting! The text seems alright - but when taking a closer look, it is not very coherent. The *3-grams* *new hand sense* and *local batte harness* are very weird and don't sound like they were written by a human. That is the big problem when sampling word sequences: The models often generate incoherent gibberish, *cf.* [Ari Holtzman et al. (2019)](https://arxiv.org/abs/1904.09751).\n", @@ -466,15 +450,27 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "WgJredc-0j0Z", - "colab_type": "code", - "outputId": "a4e79355-8e3c-4788-fa21-c4e28bf61c5b", "colab": { "base_uri": "https://localhost:8080/", "height": 88 - } + }, + "colab_type": "code", + "id": "WgJredc-0j0Z", + "outputId": "a4e79355-8e3c-4788-fa21-c4e28bf61c5b" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog, but I don't like to be at home too much. I also find it a bit weird when I'm out shopping. I am always away from my house a lot, but I do have a few friends\n" + ] + } + ], "source": [ "# set seed to reproduce results. 
Feel free to change the seed though to get different results\n", "tf.random.set_seed(0)\n", @@ -490,25 +486,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog, but I don't like to be at home too much. I also find it a bit weird when I'm out shopping. I am always away from my house a lot, but I do have a few friends\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "kzGuu24hZZnq", - "colab_type": "text" + "colab_type": "text", + "id": "kzGuu24hZZnq" }, "source": [ "OK. There are fewer weird n-grams and the output is a bit more coherent now! While applying temperature can make a distribution less random, in its limit, when setting `temperature` $ \to 0$, temperature scaled sampling becomes equal to greedy decoding and will suffer from the same problems as before. \n", @@ -518,8 +502,8 @@ { "cell_type": "markdown", "metadata": { - "id": "binNTroyzQBu", - "colab_type": "text" + "colab_type": "text", + "id": "binNTroyzQBu" }, "source": [ "### **Top-K Sampling**\n", @@ -540,15 +524,31 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "HBtDOdD0wx3l", - "colab_type": "code", - "outputId": "cfc97fac-0956-42ee-a6e5-cad14fc942d3", "colab": { "base_uri": "https://localhost:8080/", "height": 156 - } + }, + "colab_type": "code", + "id": "HBtDOdD0wx3l", + "outputId": "cfc97fac-0956-42ee-a6e5-cad14fc942d3" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog. 
It's so good to have an environment where your dog is available to share with you and we'll be taking care of you.\n", + "\n", + "We hope you'll find this story interesting!\n", + "\n", + "I am from\n" + ] + } + ], "source": [ "# set seed to reproduce results. Feel free to change the seed though to get different results\n", "tf.random.set_seed(0)\n", @@ -563,29 +563,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog. It's so good to have an environment where your dog is available to share with you and we'll be taking care of you.\n", - "\n", - "We hope you'll find this story interesting!\n", - "\n", - "I am from\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "Y77H5m4ZmhEX", - "colab_type": "text" + "colab_type": "text", + "id": "Y77H5m4ZmhEX" }, "source": [ "Not bad at all! The text is arguably the most *human-sounding* text so far. 
\n", @@ -601,8 +585,8 @@ { "cell_type": "markdown", "metadata": { - "id": "ki9LAaexzV3H", - "colab_type": "text" + "colab_type": "text", + "id": "ki9LAaexzV3H" }, "source": [ "### **Top-p (nucleus) sampling**\n", @@ -619,15 +603,33 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "EvwIc7YAx77F", - "colab_type": "code", - "outputId": "57e2b785-5dcb-4e06-9869-078b758b6a82", "colab": { "base_uri": "https://localhost:8080/", "height": 170 - } + }, + "colab_type": "code", + "id": "EvwIc7YAx77F", + "outputId": "57e2b785-5dcb-4e06-9869-078b758b6a82" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "I enjoy walking with my cute dog. He will never be the same. I watch him play.\n", + "\n", + "\n", + "Guys, my dog needs a name. Especially if he is found with wings.\n", + "\n", + "\n", + "What was that? I had a lot of\n" + ] + } + ], "source": [ "# set seed to reproduce results. Feel free to change the seed though to get different results\n", "tf.random.set_seed(0)\n", @@ -643,31 +645,13 @@ "\n", "print(\"Output:\\n\" + 100 * '-')\n", "print(tokenizer.decode(sample_output[0], skip_special_tokens=True))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "I enjoy walking with my cute dog. He will never be the same. I watch him play.\n", - "\n", - "\n", - "Guys, my dog needs a name. Especially if he is found with wings.\n", - "\n", - "\n", - "What was that? I had a lot of\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "tn-8gLaR4lat", - "colab_type": "text" + "colab_type": "text", + "id": "tn-8gLaR4lat" }, "source": [ "Great, that sounds like it could have been written by a human. 
Well, maybe not quite yet. \n", @@ -679,15 +663,33 @@ }, { "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "3kY8P9VG8Gi9", - "colab_type": "code", - "outputId": "6103051e-1681-4ab9-a9c1-1fad437c299d", "colab": { "base_uri": "https://localhost:8080/", "height": 190 - } + }, + "colab_type": "code", + "id": "3kY8P9VG8Gi9", + "outputId": "6103051e-1681-4ab9-a9c1-1fad437c299d" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output:\n", + "----------------------------------------------------------------------------------------------------\n", + "0: I enjoy walking with my cute dog. It's so good to have the chance to walk with a dog. But I have this problem with the dog and how he's always looking at us and always trying to make me see that I can do something\n", + "1: I enjoy walking with my cute dog, she loves taking trips to different places on the planet, even in the desert! The world isn't big enough for us to travel by the bus with our beloved pup, but that's where I find my love\n", + "2: I enjoy walking with my cute dog and playing with our kids,\" said David J. Smith, director of the Humane Society of the US.\n", + "\n", + "\"So as a result, I've got more work in my time,\" he said.\n", + "\n", + "\n" + ] + } + ], "source": [ "# set seed to reproduce results. Feel free to change the seed though to get different results\n", "tf.random.set_seed(0)\n", @@ -705,31 +707,13 @@ "print(\"Output:\\n\" + 100 * '-')\n", "for i, sample_output in enumerate(sample_outputs):\n", " print(\"{}: {}\".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Output:\n", - "----------------------------------------------------------------------------------------------------\n", - "0: I enjoy walking with my cute dog. It's so good to have the chance to walk with a dog. 
But I have this problem with the dog and how he's always looking at us and always trying to make me see that I can do something\n", - "1: I enjoy walking with my cute dog, she loves taking trips to different places on the planet, even in the desert! The world isn't big enough for us to travel by the bus with our beloved pup, but that's where I find my love\n", - "2: I enjoy walking with my cute dog and playing with our kids,\" said David J. Smith, director of the Humane Society of the US.\n", - "\n", - "\"So as a result, I've got more work in my time,\" he said.\n", - "\n", - "\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "-vRPfMl88rk0", - "colab_type": "text" + "colab_type": "text", + "id": "-vRPfMl88rk0" }, "source": [ "Cool, now you should have all the tools to let your model write your stories with `transformers`!" @@ -738,8 +722,8 @@ { "cell_type": "markdown", "metadata": { - "id": "NsWd7e98Vcs3", - "colab_type": "text" + "colab_type": "text", + "id": "NsWd7e98Vcs3" }, "source": [ "### **Conclusion**\n", @@ -765,8 +749,8 @@ { "cell_type": "markdown", "metadata": { - "id": "w4CYi91h11yd", - "colab_type": "text" + "colab_type": "text", + "id": "w4CYi91h11yd" }, "source": [ "### **Appendix**\n", @@ -782,5 +766,21 @@ "For more information please also look into the `generate` function [docstring](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.TFPreTrainedModel.generate)." ] } - ] + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "02_how_to_generate.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 }
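A side note on the patch above: the top-p (nucleus) filtering that `model.generate(..., do_sample=True, top_p=0.92)` performs can be sketched in a few lines of plain NumPy. This is a hedged toy illustration only — the helper name `top_p_filter`, the 5-token logits, and the NumPy dependency are assumptions made here, not code from the notebook or from the transformers implementation:

```python
import numpy as np

def top_p_filter(logits, top_p=0.92):
    """Keep the smallest set of tokens whose cumulative probability
    exceeds top_p; mask everything else to -inf (nucleus filtering)."""
    # softmax over the toy logits
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # sort tokens by probability, most likely first
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    # smallest prefix whose cumulative mass passes top_p
    cutoff = np.searchsorted(cum, top_p) + 1
    keep = order[:cutoff]
    filtered = np.full_like(logits, -np.inf)
    filtered[keep] = logits[keep]
    return filtered

# toy vocabulary of 5 tokens: the three most likely tokens already
# carry over 90% of the probability mass, so only they stay finite
logits = np.array([3.0, 2.0, 1.0, 0.5, 0.1])
print(np.isfinite(top_p_filter(logits, top_p=0.9)))
```

Sampling then draws from a softmax over the surviving logits, which is why top-p adapts the candidate-set size to the sharpness of the distribution, unlike a fixed top-k cutoff.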