|
1 | 1 | import logging
|
2 | 2 | import re
|
3 |
| -from typing import List, Optional |
| 3 | +from typing import Any, List, Optional |
4 | 4 |
|
5 | 5 | from langchain.chains import LLMChain
|
6 | 6 | from langchain.chains.prompt_selector import ConditionalPromptSelector
|
@@ -81,6 +81,35 @@ class WebResearchRetriever(BaseRetriever):
|
81 | 81 | "check .netrc for proxy configuration",
|
82 | 82 | )
|
83 | 83 |
|
| 84 | + allow_dangerous_requests: bool = False |
| 85 | + """A flag to force users to acknowledge the risks of SSRF attacks when using |
| 86 | + this retriever. |
| 87 | + |
| 88 | + Users should set this flag to `True` if they have taken the necessary precautions |
| 89 | + to prevent SSRF attacks when using this retriever. |
| 90 | + |
| 91 | + For example, users can run the requests through a properly configured |
| 92 | + proxy and prevent the crawler from accidentally crawling internal resources. |
| 93 | + """ |
| 94 | + |
| 95 | + def __init__(self, **kwargs: Any) -> None: |
| 96 | + """Initialize the retriever.""" |
| 97 | + allow_dangerous_requests = kwargs.get("allow_dangerous_requests", False) |
| 98 | + if not allow_dangerous_requests: |
| 99 | + raise ValueError( |
| 100 | + "WebResearchRetriever crawls URLs surfaced through " |
| 101 | + "the provided search engine. It is possible that some of those URLs " |
| 102 | + "will end up pointing to machines residing on an internal network, " |
| 103 | + "leading " |
| 104 | + "to an SSRF (Server-Side Request Forgery) attack. " |
| 105 | + "To protect yourself against that risk, you can run the requests " |
| 106 | + "through a proxy and prevent the crawler from accidentally crawling " |
| 107 | + "internal resources. " |
| 108 | + "If you've taken the necessary precautions, you can set " |
| 109 | + "`allow_dangerous_requests` to `True`." |
| 110 | + ) |
| 111 | + super().__init__(**kwargs) |
| 112 | + |
84 | 113 | @classmethod
|
85 | 114 | def from_llm(
|
86 | 115 | cls,
|
|
0 commit comments