@@ -12,7 +12,7 @@
 
 import pytest
 import torch
-from tensordict import lazy_stack, set_list_to_stack, TensorDict
+from tensordict import assert_close, lazy_stack, set_list_to_stack, TensorDict
 
 from tensordict.utils import _zip_strict
 from torchrl.data.llm import History
@@ -163,6 +163,23 @@ def sample_tokens(vllm_instance):
     return tokenized["input_ids"], tokenized["attention_mask"]
 
 
+@pytest.fixture
+def sample_tokens_unpadded(vllm_instance):
+    """Create unpadded sample tokens as jagged nested tensors for testing."""
+    model, tokenizer = vllm_instance
+    text = [
+        "Are you happy? Say yes or no.",
+        "Explain the difference between a cat and a dog. Be very detailed.",
+    ]
+    tokenized = tokenizer(text, padding=False)
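+    # Jagged-layout nested tensors keep each sequence unpadded at its own length.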
+    return torch.nested.nested_tensor(
+        [torch.tensor(t) for t in tokenized["input_ids"]], layout=torch.jagged
+    ), torch.nested.nested_tensor(
+        [torch.tensor(t) for t in tokenized["attention_mask"]], layout=torch.jagged
+    )
+
+
 def check_output_shapes(out, pad_output, requested_log_probs=False):
     if pad_output or not out.ndim:
         # We can get all tensors or they are none
@@ -1656,8 +1672,6 @@ def test_log_probs_consistency(
         vllm_lp_result = vllm_lp_wrapper(new_data.copy())
         tf_lp_result = tf_lp_wrapper(new_data.copy())
 
-        from tensordict import assert_close
-
         assert_close(
             vllm_lp_result, tf_lp_result, atol=1e-1, rtol=1e-1, intersection=True
         )
@@ -1825,6 +1839,105 @@ def test_transformers_custom_masking(
         assert hasattr(dist, "log_prob")
 
 
+@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+@pytest.mark.parametrize("pad_output", [False, True])
+class TestPacking:
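+    """Packed and padded model inputs should yield the same log-probs."""
+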
+    def test_packing_history(
+        self, transformers_instance, sample_history_assistant, pad_output
+    ):
+        model, tokenizer = transformers_instance
+
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="history",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapper_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="history",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+
+        td = TensorDict(
+            {"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,)
+        ).to_lazystack(0)
+
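+        # Log-probs should not depend on whether the model input was packed or padded.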
+        result_padded = wrapper_padded(td)
+        result_packed = wrapper_packed(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+    def test_packing_text(self, transformers_instance, sample_text, pad_output):
+        model, tokenizer = transformers_instance
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapper_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="text",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+        td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,))
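+        # Same texts through both wrappers; packed and padded log-probs should agree.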
+        result_packed = wrapper_packed(td)
+        result_padded = wrapper_padded(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+    def test_packing_tokens(
+        self, transformers_instance, sample_tokens_unpadded, pad_output
+    ):
+        model, tokenizer = transformers_instance
+        wrapper_packed = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="tokens",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=False,
+        )
+        wrapper_padded = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            input_mode="tokens",
+            generate=False,
+            return_log_probs=True,
+            pad_output=pad_output,
+            pad_model_input=True,
+        )
+        td = TensorDict(
+            {
+                "tokens": Tokens(full=sample_tokens_unpadded[0]),
+                "masks": Masks(all_attention_mask=sample_tokens_unpadded[1]),
+            },
+            batch_size=(2,),
+        ).to_lazystack(0)
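+        # Ragged token inputs should likewise give layout-independent log-probs.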
+        result_padded = wrapper_padded(td)
+        result_packed = wrapper_packed(td)
+        assert_close(result_packed["log_probs"], result_padded["log_probs"])
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)