diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 50153f8d..10352bd2 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -916,7 +916,8 @@ def create_datagram_endpoint(self, protocol_factory, if local_addr: sock.bind(local_address) if remote_addr: - yield from self.sock_connect(sock, remote_address) + if not allow_broadcast: + yield from self.sock_connect(sock, remote_address) r_addr = remote_address except OSError as exc: if sock is not None: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 12d357b5..f1044431 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -553,7 +553,10 @@ class _SelectorTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, extra=None, server=None): super().__init__(extra, loop) self._extra['socket'] = sock - self._extra['sockname'] = sock.getsockname() + try: + self._extra['sockname'] = sock.getsockname() + except socket.error: + self._extra['sockname'] = None if 'peername' not in self._extra: try: self._extra['peername'] = sock.getpeername() @@ -1083,9 +1086,11 @@ def sendto(self, data, addr=None): if not data: return - if self._address and addr not in (None, self._address): - raise ValueError('Invalid address: must be None or %s' % - (self._address,)) + if self._address: + if addr not in (None, self._address): + raise ValueError( + 'Invalid address: must be None or %s' % (self._address,)) + addr = self._address if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: @@ -1096,10 +1101,7 @@ def sendto(self, data, addr=None): if not self._buffer: # Attempt to send it right away first. try: - if self._address: - self._sock.send(data) - else: - self._sock.sendto(data, addr) + self._sock.sendto(data, addr) return except (BlockingIOError, InterruptedError): self._loop._add_writer(self._sock_fd, self._sendto_ready) @@ -1119,10 +1121,7 @@ def _sendto_ready(self): while self._buffer: data, addr = self._buffer.popleft() try: - if self._address: - self._sock.send(data) - else: - self._sock.sendto(data, addr) + self._sock.sendto(data, addr) except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 3f1ec651..9f293ba9 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -1463,6 +1463,23 @@ def test_create_datagram_endpoint_connect_err(self): self.assertRaises( OSError, self.loop.run_until_complete, coro) + def test_create_datagram_endpoint_allow_broadcast(self): + protocol = MyDatagramProto(create_future=True, loop=self.loop) + self.loop.sock_connect = sock_connect = mock.Mock() + sock_connect.return_value = [] + + coro = self.loop.create_datagram_endpoint( + lambda: protocol, + remote_addr=('127.0.0.1', 0), + allow_broadcast=True) + + transport, _ = self.loop.run_until_complete(coro) + self.assertFalse(sock_connect.called) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + @patch_socket def test_create_datagram_endpoint_socket_err(self, m_socket): m_socket.getaddrinfo = socket.getaddrinfo diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 6bf7862e..be1efa92 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1673,7 +1673,7 @@ def test_sendto_error_received(self): def test_sendto_error_received_connected(self): data = b'data' - self.sock.send.side_effect = ConnectionRefusedError + self.sock.sendto.side_effect = ConnectionRefusedError transport = self.datagram_transport(address=('0.0.0.0', 1)) transport._fatal_error = mock.Mock() @@ -1768,7 +1768,7 @@ def test_sendto_ready_error_received(self): self.assertFalse(transport._fatal_error.called) def test_sendto_ready_error_received_connection(self): - self.sock.send.side_effect = ConnectionRefusedError + self.sock.sendto.side_effect = ConnectionRefusedError transport = self.datagram_transport(address=('0.0.0.0', 1)) transport._fatal_error = mock.Mock()