diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e38a722..bdd0f54 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,12 @@ Changelog .. This document is user facing. Please word the changes in such a way .. that users understand how the changes affect the new version. +version 0.5.1 +----------------- ++ Fix a bug where ``gzip_ng_threaded.open`` could + cause a hang when the program exited and the program was not used with a + context manager. + version 0.5.0 ----------------- + Wheels are now build for MacOS arm64 architectures. diff --git a/src/zlib_ng/gzip_ng_threaded.py b/src/zlib_ng/gzip_ng_threaded.py index 5b8a9ff..8a7e54a 100644 --- a/src/zlib_ng/gzip_ng_threaded.py +++ b/src/zlib_ng/gzip_ng_threaded.py @@ -100,8 +100,7 @@ def __init__(self, filename, queue_size=2, block_size=1024 * 1024): self.block_size = block_size self.worker = threading.Thread(target=self._decompress) self._closed = False - self.running = True - self.worker.start() + self.running = False def _check_closed(self, msg=None): if self._closed: @@ -125,8 +124,19 @@ def _decompress(self): except queue.Full: pass + def _start(self): + if not self.running: + self.running = True + self.worker.start() + + def _stop(self): + if self.running: + self.running = False + self.worker.join() + def readinto(self, b): self._check_closed() + self._start() result = self.buffer.readinto(b) if result == 0: while True: @@ -154,8 +164,7 @@ def tell(self) -> int: def close(self) -> None: if self._closed: return - self.running = False - self.worker.join() + self._stop() self.fileobj.close() if self.closefd: self.raw.close() @@ -252,7 +261,6 @@ def __init__(self, self.raw, self.closefd = open_as_binary_stream(filename, mode) self._closed = False self._write_gzip_header() - self.start() def _check_closed(self, msg=None): if self._closed: @@ -275,21 +283,24 @@ def _write_gzip_header(self): self.raw.write(struct.pack( "BBBBIBB", magic1, magic2, method, flags, mtime, os, xfl)) - def start(self): - self.running = True - self.output_worker.start() - for worker in self.compression_workers: - worker.start() + def _start(self): + if not self.running: + self.running = True + self.output_worker.start() + for worker in self.compression_workers: + worker.start() def stop(self): """Stop, but do not care for remaining work""" - self.running = False - for worker in self.compression_workers: - worker.join() - self.output_worker.join() + if self.running: + self.running = False + for worker in self.compression_workers: + worker.join() + self.output_worker.join() def write(self, b) -> int: self._check_closed() + self._start() with self.lock: if self.exception: raise self.exception diff --git a/tests/test_gzip_ng_threaded.py b/tests/test_gzip_ng_threaded.py index 1032a67..1a0a5a8 100644 --- a/tests/test_gzip_ng_threaded.py +++ b/tests/test_gzip_ng_threaded.py @@ -9,6 +9,8 @@ import io import itertools import os +import subprocess +import sys import tempfile from pathlib import Path @@ -103,6 +105,7 @@ def test_threaded_write_error(threads): threads=threads, block_size=8 * 1024) # Bypass the write method which should not allow blocks larger than # block_size. + f._start() f.input_queues[0].put((os.urandom(1024 * 64), b"")) with pytest.raises(OverflowError) as error: f.close() @@ -209,3 +212,22 @@ def test_threaded_writer_does_not_close_stream(): assert not test_stream.closed test_stream.seek(0) assert gzip.decompress(test_stream.read()) == b"thisisatest" + + +@pytest.mark.timeout(5) +@pytest.mark.parametrize( + ["mode", "threads"], itertools.product(["rb", "wb"], [1, 2])) +def test_threaded_program_can_exit_on_error(tmp_path, mode, threads): + program = tmp_path / "no_context_manager.py" + test_file = tmp_path / "output.gz" + # Write 40 mb input data to saturate read buffer. Because of the repetitive + # nature the resulting gzip file is very small (~40 KiB). + test_file.write_bytes(gzip.compress(b"test" * (10 * 1024 * 1024))) + with open(program, "wt") as f: + f.write("from zlib_ng import gzip_ng_threaded\n") + f.write( + f"f = gzip_ng_threaded.open('{test_file}', " + f"mode='{mode}', threads={threads})\n" + ) + f.write("raise Exception('Error')\n") + subprocess.run([sys.executable, str(program)])