diff --git a/spec/python/extra/decorators.py b/spec/python/extra/decorators.py new file mode 100644 index 000000000..0e0c4c099 --- /dev/null +++ b/spec/python/extra/decorators.py @@ -0,0 +1,88 @@ +import io +import itertools +from functools import wraps +from types import FunctionType + +from kaitaistruct import PY2 + +from helpers import FileOpenMode, FileOpenConfig, TemporaryFile, InMemoryStream + + +def stream_param_tests(cls): + # From https://github.com/adamchainz/unittest-parametrize/blob/58cf079/src/unittest_parametrize/__init__.py#L24-L28 + for name, func in list(cls.__dict__.items()): + if not isinstance(func, FunctionType): + continue + if not getattr(func, '_parametrized_by_write_stream', False): + continue + + delattr(cls, name) + + def test_builder(stream_builder, func=func): + @wraps(func) + def wrapper(self): + return func(self, stream_builder) + return wrapper + + for test_method in _generate_write_subtests(name, test_builder): + setattr(cls, test_method.__name__, test_method) + + return cls + + +def write_stream_param(func): + func._parametrized_by_write_stream = True + return func + + +def _generate_write_subtests(test_basename, test_builder): + use_builtin_open_options = (False,) + # in Python 3, built-in open() is an alias of io.open(); in Python 2, open() is + # a different function than io.open() + builtin_open_is_not_io_open = open is not io.open + assert builtin_open_is_not_io_open == PY2, ( + "expected the built-in open() to be different from io.open() only in Python 2, " + "but builtin_open_is_not_io_open={} and PY2={}" + .format(builtin_open_is_not_io_open, PY2) + ) + + if builtin_open_is_not_io_open: + use_builtin_open_options += (True,) + open_modes = [mode for mode in FileOpenMode if mode.writable] + buffering_options = (-1, 0) + + generated_test_methods = [] + + # NOTE: We're not using unittest.TestCase.subTest + # (https://docs.python.org/3/library/unittest.html#unittest.TestCase.subTest) for two + # reasons: + # 1. it's only available in Python 3, not Python 2 + # 2. the "unittest-xml-reporting" Python package that we use to generate XML test + # reports has a limited support for subtests, see + # https://github.com/xmlrunner/unittest-xml-reporting/tree/3.2.0#limited-support-for-unittesttestcasesubtest + for use_builtin_open, open_mode, buffering in itertools.product( + use_builtin_open_options, open_modes, buffering_options + ): + open_conf = FileOpenConfig(buffering, use_builtin_open) + + def stream_builder(orig_io_size, open_mode=open_mode, open_conf=open_conf): + return TemporaryFile(open_mode, open_conf, orig_io_size) + + test_method = test_builder(stream_builder) + test_method.__name__ = \ + '{}__TemporaryFile_{}open_{}{}'.format( + test_basename, + 'builtin_' if use_builtin_open else 'io_', + open_mode.ident, + '_nobuf' if buffering == 0 else '', + ) + generated_test_methods.append(test_method) + + def stream_builder(orig_io_size): + return InMemoryStream.from_size(orig_io_size) + + test_method = test_builder(stream_builder) + test_method.__name__ = '{}__InMemoryStream'.format(test_basename) + generated_test_methods.append(test_method) + + return generated_test_methods diff --git a/spec/python/extra/helpers.py b/spec/python/extra/helpers.py new file mode 100644 index 000000000..948d5de6f --- /dev/null +++ b/spec/python/extra/helpers.py @@ -0,0 +1,232 @@ +import abc +import io +import os +import shutil +import sys +import tempfile +from contextlib import contextmanager +from enum import Enum + +from kaitaistruct import KaitaiStream + +# See https://stackoverflow.com/a/41622155 +if sys.version_info >= (3, 4): + ABC = abc.ABC +else: + # See https://stackoverflow.com/a/38668373 + ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) + + +class FileOpenMode(Enum): + READ_ONLY = 1 + WRITE_ONLY = 2 + READ_WRITE = 3 + + @property + def readable(self): + return self in {FileOpenMode.READ_ONLY, FileOpenMode.READ_WRITE} + + @property + def writable(self): + return self in {FileOpenMode.WRITE_ONLY, FileOpenMode.READ_WRITE} + + @property + def mode(self): + if self == FileOpenMode.READ_ONLY: + return 'rb' + if self == FileOpenMode.WRITE_ONLY: + return 'wb' + if self == FileOpenMode.READ_WRITE: + return 'wb+' + + return None + + @property + def ident(self): + if self == FileOpenMode.READ_ONLY: + return 'rdonly' + if self == FileOpenMode.WRITE_ONLY: + return 'wronly' + if self == FileOpenMode.READ_WRITE: + return 'rdwr' + + return None + + +class FileOpenConfig(object): + def __init__(self, buffering, use_builtin_open): + self.buffering = buffering + self.use_builtin_open = use_builtin_open + + +class RegularFileOpener(object): + @staticmethod + def open_file_object(f, size): + try: + if size is not None: + f.truncate(size) + + return KaitaiStream(f) + except Exception: + f.close() + raise + + @classmethod + def open_path(cls, file_path, open_mode, open_conf, size=None): + if open_conf.use_builtin_open: + f = open(file_path, open_mode.mode, buffering=open_conf.buffering) + else: + f = io.open(file_path, open_mode.mode, buffering=open_conf.buffering) + + return cls.open_file_object(f, size) + + @classmethod + def open_fd(cls, fd, open_mode, open_conf, size=None): + try: + if open_conf.use_builtin_open: + # must use positional arguments (otherwise we get "TypeError: fdopen() takes no + # keyword arguments" in Python 2) + f = os.fdopen(fd, open_mode.mode, open_conf.buffering) + else: + f = io.open(fd, open_mode.mode, buffering=open_conf.buffering) + except Exception: + os.close(fd) + raise + + return cls.open_file_object(f, size) + + +class AbstractStream(ABC): + @abc.abstractmethod + def open(self): + pass + + @abc.abstractmethod + def open_as_read_only(self): + pass + + +class TemporaryFile(AbstractStream): + def __init__(self, open_mode, open_conf, size): + self.open_mode = open_mode + self.open_conf = open_conf + self.size = size + + self.tmp_path = None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() + + def open(self): + tmp_fd, self.tmp_path = tempfile.mkstemp() + return RegularFileOpener.open_fd(tmp_fd, self.open_mode, self.open_conf, self.size) + + def open_as_read_only(self): + if self.tmp_path is None: + raise ValueError( + "{}() must be called first" + .format(self.open.__name__) + ) + + return RegularFileOpener.open_path(self.tmp_path, FileOpenMode.READ_ONLY, self.open_conf) + + def destroy(self): + if self.tmp_path is not None: + tmp_path = self.tmp_path + self.tmp_path = None + os.remove(tmp_path) + + +class Pipe(AbstractStream): + def __init__(self, open_conf): + self.open_conf = open_conf + + self.r_fd = None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() + + def init_from_path(self, file_path): + self.r_fd, w_fd = os.pipe() + with io.open(w_fd, 'wb') as dst_f, io.open(file_path, 'rb') as src_f: + shutil.copyfileobj(src_f, dst_f) + + def open(self): + """ + At the time of writing, the KaitaiStream class doesn't support writing to non-seekable + streams (like pipes), but this method is kept in case it's supported in some form in the + future. + """ + self.r_fd, w_fd = os.pipe() + return RegularFileOpener.open_fd(w_fd, FileOpenMode.WRITE_ONLY, self.open_conf) + + def open_as_read_only(self): + if self.r_fd is None: + raise ValueError( + "{}() or {}() must be called first" + .format(self.init_from_path.__name__, self.open.__name__) + ) + + r_fd = self.r_fd + self.r_fd = None + return RegularFileOpener.open_fd(r_fd, FileOpenMode.READ_ONLY, self.open_conf) + + def destroy(self): + if self.r_fd is not None: + r_fd = self.r_fd + self.r_fd = None + os.close(r_fd) + + +class InMemoryStream(AbstractStream): + def __init__(self, f): + self.ks_io = KaitaiStream(f) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() + + @classmethod + def from_size(cls, size): + f = io.BytesIO() + if size > 0: + f.seek(size - 1, io.SEEK_SET) + f.write(b'\x00') + f.seek(0, io.SEEK_SET) + + return cls(f) + + @classmethod + def from_path(cls, file_path): + f = io.BytesIO() + with io.open(file_path, 'rb') as src_f: + shutil.copyfileobj(src_f, f) + f.seek(0, io.SEEK_SET) + + return cls(f) + + @contextmanager + def _open_contextmanager(self): + try: + yield self.ks_io + finally: + self.ks_io.seek(0) + + def open(self): + return self._open_contextmanager() + + def open_as_read_only(self): + return self.open() + + def destroy(self): + if self.ks_io is not None: + self.ks_io.close() + self.ks_io = None diff --git a/spec/python/extra/test_helpers.py b/spec/python/extra/test_helpers.py new file mode 100644 index 000000000..27c9a25d4 --- /dev/null +++ b/spec/python/extra/test_helpers.py @@ -0,0 +1,100 @@ +import os.path +import unittest + +from helpers import FileOpenConfig, FileOpenMode, InMemoryStream, Pipe, RegularFileOpener, \ + TemporaryFile + +TEST_FILE_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '../../../src/nav_parent_switch.bin' +) + + +class TestHelpers(unittest.TestCase): + def test_RegularFileOpener_open_path(self): + open_mode = FileOpenMode.READ_ONLY + open_conf = FileOpenConfig(-1, False) + with RegularFileOpener.open_path(TEST_FILE_PATH, open_mode, open_conf) as ks_io: + self.assertEqual(ks_io.read_bytes_full(), b'\x01\x42\xff') + + def test_TemporaryFile(self): + open_conf = FileOpenConfig(-1, False) + with TemporaryFile(FileOpenMode.WRITE_ONLY, open_conf, 6) as tf: + with tf.open() as ks_io: + self.assertEqual(ks_io.size(), 6) + ks_io.write_bytes(b'Hello\0') + + file_path = tf.tmp_path + self.assertTrue(os.path.isfile(file_path)) + + with tf.open_as_read_only() as ks_io: + self.assertEqual(ks_io.read_bytes_full(), b'Hello\0') + + self.assertFalse(os.path.lexists(file_path)) + + def test_TemporaryFile_no_usage(self): + """ Ensure that we don't get any errors (that could be potentially raised from __exit__() + which is called automatically) if we create a TemporaryFile but do nothing with it. + """ + open_conf = FileOpenConfig(-1, False) + with TemporaryFile(FileOpenMode.WRITE_ONLY, open_conf, 6) as _tf: + pass + + def test_TemporaryFile_uncaught_exception_before_open(self): + """ Ensure that we don't get any errors (that could be potentially raised from __exit__() + which is called automatically) if we create a TemporaryFile but then there's an uncaught + exception before we call open(). + """ + open_conf = FileOpenConfig(-1, False) + with self.assertRaises(ValueError) as cm: + with TemporaryFile(FileOpenMode.WRITE_ONLY, open_conf, 6) as _tf: + raise ValueError("fooBarQux happened") + # here we would call open() etc. + + self.assertEqual(str(cm.exception), "fooBarQux happened") + + def test_TemporaryFile_open_as_read_only_too_soon(self): + open_conf = FileOpenConfig(-1, False) + with TemporaryFile(FileOpenMode.WRITE_ONLY, open_conf, 6) as tf: + with self.assertRaises(ValueError) as cm: + with tf.open_as_read_only() as _ks_io: + pass + self.assertEqual(str(cm.exception), "open() must be called first") + + def test_InMemoryStream_from_size(self): + with InMemoryStream.from_size(6) as ims: + with ims.open() as ks_io: + self.assertEqual(ks_io.size(), 6) + ks_io.write_bytes(b'Hello\0') + with ims.open_as_read_only() as ks_io: + self.assertEqual(ks_io.read_bytes_full(), b'Hello\0') + + def test_InMemoryStream_from_path(self): + with InMemoryStream.from_path(TEST_FILE_PATH) as ims: + with ims.open() as ks_io: + self.assertEqual(ks_io.read_bytes_full(), b'\x01\x42\xff') + + def test_Pipe(self): + open_conf = FileOpenConfig(0, False) + with Pipe(open_conf) as p: + p.init_from_path(TEST_FILE_PATH) + with p.open_as_read_only() as ks_io: + self.assertEqual(ks_io.read_u1(), 0x01) + self.assertEqual(ks_io.read_bytes_full(), b'\x42\xff') + + def test_Pipe_no_usage(self): + open_conf = FileOpenConfig(0, False) + with Pipe(open_conf) as _p: + pass + + def test_Pipe_open_as_read_only_too_soon(self): + open_conf = FileOpenConfig(0, False) + with Pipe(open_conf) as p: + with self.assertRaises(ValueError) as cm: + with p.open_as_read_only() as _ks_io: + pass + self.assertEqual(str(cm.exception), "init_from_path() or open() must be called first") + + +if __name__ == '__main__': + unittest.main() diff --git a/spec/python/specwrite/common_spec.py b/spec/python/specwrite/common_spec.py index 3f0341f0b..f77253ba6 100644 --- a/spec/python/specwrite/common_spec.py +++ b/spec/python/specwrite/common_spec.py @@ -1,37 +1,40 @@ import unittest -import io + from kaitaistruct import KaitaiStream, KaitaiStruct, PY2 +from decorators import stream_param_tests, write_stream_param + + # A little hack from https://stackoverflow.com/a/25695512 to trick 'unittest' # into thinking that CommonSpec.Base is not a test by itself. -class CommonSpec: - +class CommonSpec(object): + @stream_param_tests class Base(unittest.TestCase): def __init__(self, *args, **kwargs): super(CommonSpec.Base, self).__init__(*args, **kwargs) self.maxDiff = None + self.skip_roundtrip_msg_reason = None - def test_read_write_roundtrip(self): - orig_f = io.open(self.src_filename, 'rb') + @write_stream_param + def test_read_write_roundtrip(self, stream_builder): + if self.skip_roundtrip_msg_reason is not None: + self.skipTest(self.skip_roundtrip_msg_reason) - try: - orig_ks = self.struct_class.from_io(orig_f) + with self.struct_class.from_file(self.src_filename) as orig_ks: orig_ks._read() - orig_dump = CommonSpec.Base.dump_struct(orig_ks) - orig_io_size = orig_ks._io.size() - finally: - orig_f.close() - with KaitaiStream(io.BytesIO(bytearray(orig_io_size))) as new_io: - orig_ks._write(new_io) - new_io.seek(0) + with stream_builder(orig_io_size) as obj: + with obj.open() as ks_io: + self.assertEqual(ks_io.size(), orig_io_size) + orig_ks._write(ks_io) - new_ks = self.struct_class(new_io) - new_ks._read() + with obj.open_as_read_only() as ks_io: + new_ks = self.struct_class(ks_io) + new_ks._read() - new_dump = CommonSpec.Base.dump_struct(new_ks) + new_dump = CommonSpec.Base.dump_struct(new_ks) self.assertEqual(orig_dump, new_dump) diff --git a/spec/python/specwrite/test_default_endian_expr_exception.py b/spec/python/specwrite/test_default_endian_expr_exception.py index 61412decb..d0091fa69 100644 --- a/spec/python/specwrite/test_default_endian_expr_exception.py +++ b/spec/python/specwrite/test_default_endian_expr_exception.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestDefaultEndianExprException, self).__init__(*args, **kwargs) self.struct_class = DefaultEndianExprException self.src_filename = 'src/endian_expr.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_eof_exception_bits_be.py b/spec/python/specwrite/test_eof_exception_bits_be.py index 772787486..28399fbd7 100644 --- a/spec/python/specwrite/test_eof_exception_bits_be.py +++ b/spec/python/specwrite/test_eof_exception_bits_be.py @@ -1,26 +1,25 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_bits_be import EofExceptionBitsBe +@stream_param_tests class TestEofExceptionBitsBe(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionBitsBe, self).__init__(*args, **kwargs) self.struct_class = EofExceptionBitsBe self.src_filename = 'src/nav_parent_switch.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_bits_be(self): + @write_stream_param + def test_eof_exception_bits_be(self, stream_builder): r = EofExceptionBitsBe() r.pre_bits = 0b0000000 # 0b0000_000 r.fail_bits = 0b101000010111111110 # 0b1_01000010_11111111_0 r._check() - with KaitaiStream(io.BytesIO(bytearray(3))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): - r._write__seq(out_io) - self.assertEqual(out_io.pos(), 1) + with stream_builder(3) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): + r._write__seq(ks_io) + self.assertEqual(ks_io.pos(), 1) diff --git a/spec/python/specwrite/test_eof_exception_bits_be2.py b/spec/python/specwrite/test_eof_exception_bits_be2.py index ae6868b9f..66305f4a2 100644 --- a/spec/python/specwrite/test_eof_exception_bits_be2.py +++ b/spec/python/specwrite/test_eof_exception_bits_be2.py @@ -1,26 +1,25 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_bits_be2 import EofExceptionBitsBe2 +@stream_param_tests class TestEofExceptionBitsBe2(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionBitsBe2, self).__init__(*args, **kwargs) self.struct_class = EofExceptionBitsBe2 self.src_filename = 'src/nav_parent_switch.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_bits_be2(self): + @write_stream_param + def test_eof_exception_bits_be2(self, stream_builder): r = EofExceptionBitsBe2() r.pre_bits = 0x01 r.fail_bits = 0b01000010111111110 # 0b01000010_11111111_0 r._check() - with KaitaiStream(io.BytesIO(bytearray(3))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): - r._write__seq(out_io) - self.assertEqual(out_io.pos(), 1) + with stream_builder(3) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): + r._write__seq(ks_io) + self.assertEqual(ks_io.pos(), 1) diff --git a/spec/python/specwrite/test_eof_exception_bits_le.py b/spec/python/specwrite/test_eof_exception_bits_le.py index b7fa601d5..8539a473d 100644 --- a/spec/python/specwrite/test_eof_exception_bits_le.py +++ b/spec/python/specwrite/test_eof_exception_bits_le.py @@ -1,26 +1,25 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_bits_le import EofExceptionBitsLe +@stream_param_tests class TestEofExceptionBitsLe(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionBitsLe, self).__init__(*args, **kwargs) self.struct_class = EofExceptionBitsLe self.src_filename = 'src/nav_parent_switch.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_bits_le(self): + @write_stream_param + def test_eof_exception_bits_le(self, stream_builder): r = EofExceptionBitsLe() r.pre_bits = 0b0000001 # 0b000_0001 r.fail_bits = 0b011111111010000100 # 0b0_11111111_01000010_0 r._check() - with KaitaiStream(io.BytesIO(bytearray(3))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): - r._write__seq(out_io) - self.assertEqual(out_io.pos(), 1) + with stream_builder(3) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): + r._write__seq(ks_io) + self.assertEqual(ks_io.pos(), 1) diff --git a/spec/python/specwrite/test_eof_exception_bits_le2.py b/spec/python/specwrite/test_eof_exception_bits_le2.py index 3e1b0cec7..82d19a058 100644 --- a/spec/python/specwrite/test_eof_exception_bits_le2.py +++ b/spec/python/specwrite/test_eof_exception_bits_le2.py @@ -1,26 +1,25 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_bits_le2 import EofExceptionBitsLe2 +@stream_param_tests class TestEofExceptionBitsLe2(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionBitsLe2, self).__init__(*args, **kwargs) self.struct_class = EofExceptionBitsLe2 self.src_filename = 'src/nav_parent_switch.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_bits_le2(self): + @write_stream_param + def test_eof_exception_bits_le2(self, stream_builder): r = EofExceptionBitsLe2() r.pre_bits = 0x01 r.fail_bits = 0b01111111101000010 # 0b0_11111111_01000010 r._check() - with KaitaiStream(io.BytesIO(bytearray(3))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): - r._write__seq(out_io) - self.assertEqual(out_io.pos(), 1) + with stream_builder(3) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 3 bytes, but only 2 bytes left in the stream$"): + r._write__seq(ks_io) + self.assertEqual(ks_io.pos(), 1) diff --git a/spec/python/specwrite/test_eof_exception_bytes.py b/spec/python/specwrite/test_eof_exception_bytes.py index 0d8d8280e..b2191a61e 100644 --- a/spec/python/specwrite/test_eof_exception_bytes.py +++ b/spec/python/specwrite/test_eof_exception_bytes.py @@ -1,24 +1,23 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_bytes import EofExceptionBytes +@stream_param_tests class TestEofExceptionBytes(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionBytes, self).__init__(*args, **kwargs) self.struct_class = EofExceptionBytes self.src_filename = 'src/term_strz.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_bytes(self): + @write_stream_param + def test_eof_exception_bytes(self, stream_builder): r = EofExceptionBytes() r.buf = b"\x78\x79\x7A\x7B\x7C\x7D\x7E\x7F\xFF\xFE\xFD\xFC\xFB" r._check() - with KaitaiStream(io.BytesIO(bytearray(12))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 13 bytes, but only 12 bytes left in the stream$"): - r._write(out_io) + with stream_builder(12) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 13 bytes, but only 12 bytes left in the stream$"): + r._write(ks_io) diff --git a/spec/python/specwrite/test_eof_exception_u4.py b/spec/python/specwrite/test_eof_exception_u4.py index cd1921bac..2f4a55038 100644 --- a/spec/python/specwrite/test_eof_exception_u4.py +++ b/spec/python/specwrite/test_eof_exception_u4.py @@ -1,25 +1,24 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eof_exception_u4 import EofExceptionU4 +@stream_param_tests class TestEofExceptionU4(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEofExceptionU4, self).__init__(*args, **kwargs) self.struct_class = EofExceptionU4 self.src_filename = 'src/term_strz.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eof_exception_u4(self): + @write_stream_param + def test_eof_exception_u4(self, stream_builder): r = EofExceptionU4() r.prebuf = b"\x78\x79\x7A\x7B\x7C\x7D\x7E\x7F\x80" r.fail_int = 3000500200 r._check() - with KaitaiStream(io.BytesIO(bytearray(12))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 4 bytes, but only 3 bytes left in the stream$"): - r._write(out_io) + with stream_builder(12) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 4 bytes, but only 3 bytes left in the stream$"): + r._write(ks_io) diff --git a/spec/python/specwrite/test_eos_exception_bytes.py b/spec/python/specwrite/test_eos_exception_bytes.py index dd84f731e..0a5f82f10 100644 --- a/spec/python/specwrite/test_eos_exception_bytes.py +++ b/spec/python/specwrite/test_eos_exception_bytes.py @@ -1,20 +1,18 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eos_exception_bytes import EosExceptionBytes +@stream_param_tests class TestEosExceptionBytes(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEosExceptionBytes, self).__init__(*args, **kwargs) self.struct_class = EosExceptionBytes self.src_filename = 'src/term_strz.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eos_exception_bytes(self): + @write_stream_param + def test_eos_exception_bytes(self, stream_builder): r = EosExceptionBytes() data = EosExceptionBytes.Data(None, r, r._root) @@ -24,6 +22,7 @@ def test_eos_exception_bytes(self): r.envelope = data r._check() - with KaitaiStream(io.BytesIO(bytearray(12))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 7 bytes, but only 6 bytes left in the stream$"): - r._write(out_io) + with stream_builder(12) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 7 bytes, but only 6 bytes left in the stream$"): + r._write(ks_io) diff --git a/spec/python/specwrite/test_eos_exception_u4.py b/spec/python/specwrite/test_eos_exception_u4.py index 0da977f84..619b7be86 100644 --- a/spec/python/specwrite/test_eos_exception_u4.py +++ b/spec/python/specwrite/test_eos_exception_u4.py @@ -1,20 +1,18 @@ -import unittest -import io -from kaitaistruct import KaitaiStream +from decorators import stream_param_tests, write_stream_param from specwrite.common_spec import CommonSpec from testwrite.eos_exception_u4 import EosExceptionU4 +@stream_param_tests class TestEosExceptionU4(CommonSpec.Base): def __init__(self, *args, **kwargs): super(TestEosExceptionU4, self).__init__(*args, **kwargs) self.struct_class = EosExceptionU4 self.src_filename = 'src/term_strz.bin' + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") - - def test_eos_exception_u4(self): + @write_stream_param + def test_eos_exception_u4(self, stream_builder): r = EosExceptionU4() data = EosExceptionU4.Data(None, r, r._root) @@ -25,6 +23,7 @@ def test_eos_exception_u4(self): r.envelope = data r._check() - with KaitaiStream(io.BytesIO(bytearray(12))) as out_io: - with self.assertRaisesRegexp(EOFError, u"^requested to write 4 bytes, but only 3 bytes left in the stream$"): - r._write(out_io) + with stream_builder(12) as obj: + with obj.open() as ks_io: + with self.assertRaisesRegexp(EOFError, u"^requested to write 4 bytes, but only 3 bytes left in the stream$"): + r._write(ks_io) diff --git a/spec/python/specwrite/test_valid_fail_anyof_int.py b/spec/python/specwrite/test_valid_fail_anyof_int.py index 2c3ea0339..4740aac04 100644 --- a/spec/python/specwrite/test_valid_fail_anyof_int.py +++ b/spec/python/specwrite/test_valid_fail_anyof_int.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailAnyofInt, self).__init__(*args, **kwargs) self.struct_class = ValidFailAnyofInt self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_contents.py b/spec/python/specwrite/test_valid_fail_contents.py index a9b3bf330..080951764 100644 --- a/spec/python/specwrite/test_valid_fail_contents.py +++ b/spec/python/specwrite/test_valid_fail_contents.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailContents, self).__init__(*args, **kwargs) self.struct_class = ValidFailContents self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_eq_bytes.py b/spec/python/specwrite/test_valid_fail_eq_bytes.py index 1d47d14be..93e7995e9 100644 --- a/spec/python/specwrite/test_valid_fail_eq_bytes.py +++ b/spec/python/specwrite/test_valid_fail_eq_bytes.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailEqBytes, self).__init__(*args, **kwargs) self.struct_class = ValidFailEqBytes self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_eq_int.py b/spec/python/specwrite/test_valid_fail_eq_int.py index 465243ff1..9a4f9f0f1 100644 --- a/spec/python/specwrite/test_valid_fail_eq_int.py +++ b/spec/python/specwrite/test_valid_fail_eq_int.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailEqInt, self).__init__(*args, **kwargs) self.struct_class = ValidFailEqInt self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_eq_str.py b/spec/python/specwrite/test_valid_fail_eq_str.py index 0c781ce4a..9c00184f7 100644 --- a/spec/python/specwrite/test_valid_fail_eq_str.py +++ b/spec/python/specwrite/test_valid_fail_eq_str.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailEqStr, self).__init__(*args, **kwargs) self.struct_class = ValidFailEqStr self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_expr.py b/spec/python/specwrite/test_valid_fail_expr.py index 1786cc74a..c4c6e5482 100644 --- a/spec/python/specwrite/test_valid_fail_expr.py +++ b/spec/python/specwrite/test_valid_fail_expr.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailExpr, self).__init__(*args, **kwargs) self.struct_class = ValidFailExpr self.src_filename = 'src/nav_parent_switch.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_inst.py b/spec/python/specwrite/test_valid_fail_inst.py index ef5b9ff52..5f4e71d60 100644 --- a/spec/python/specwrite/test_valid_fail_inst.py +++ b/spec/python/specwrite/test_valid_fail_inst.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailInst, self).__init__(*args, **kwargs) self.struct_class = ValidFailInst self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_max_int.py b/spec/python/specwrite/test_valid_fail_max_int.py index ecd3d25fe..fadd5daf4 100644 --- a/spec/python/specwrite/test_valid_fail_max_int.py +++ b/spec/python/specwrite/test_valid_fail_max_int.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailMaxInt, self).__init__(*args, **kwargs) self.struct_class = ValidFailMaxInt self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_min_int.py b/spec/python/specwrite/test_valid_fail_min_int.py index 9e607db12..9d1512392 100644 --- a/spec/python/specwrite/test_valid_fail_min_int.py +++ b/spec/python/specwrite/test_valid_fail_min_int.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailMinInt, self).__init__(*args, **kwargs) self.struct_class = ValidFailMinInt self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_range_bytes.py b/spec/python/specwrite/test_valid_fail_range_bytes.py index c3c3f94d4..5e8e24cd9 100644 --- a/spec/python/specwrite/test_valid_fail_range_bytes.py +++ b/spec/python/specwrite/test_valid_fail_range_bytes.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailRangeBytes, self).__init__(*args, **kwargs) self.struct_class = ValidFailRangeBytes self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_range_float.py b/spec/python/specwrite/test_valid_fail_range_float.py index 644646ccd..ddca6bbdb 100644 --- a/spec/python/specwrite/test_valid_fail_range_float.py +++ b/spec/python/specwrite/test_valid_fail_range_float.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailRangeFloat, self).__init__(*args, **kwargs) self.struct_class = ValidFailRangeFloat self.src_filename = 'src/floating_points.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_range_int.py b/spec/python/specwrite/test_valid_fail_range_int.py index 86c2c40ee..e81648737 100644 --- a/spec/python/specwrite/test_valid_fail_range_int.py +++ b/spec/python/specwrite/test_valid_fail_range_int.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailRangeInt, self).__init__(*args, **kwargs) self.struct_class = ValidFailRangeInt self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/spec/python/specwrite/test_valid_fail_range_str.py b/spec/python/specwrite/test_valid_fail_range_str.py index 4611e660f..187993238 100644 --- a/spec/python/specwrite/test_valid_fail_range_str.py +++ b/spec/python/specwrite/test_valid_fail_range_str.py @@ -10,6 +10,4 @@ def __init__(self, *args, **kwargs): super(TestValidFailRangeStr, self).__init__(*args, **kwargs) self.struct_class = ValidFailRangeStr self.src_filename = 'src/fixed_struct.bin' - - def test_read_write_roundtrip(self): - self.skipTest("cannot use roundtrip because parsing is expected to fail") + self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail" diff --git a/translator/src/main/scala/io/kaitai/struct/testtranslator/TestTranslator.scala b/translator/src/main/scala/io/kaitai/struct/testtranslator/TestTranslator.scala index 53674817e..626553b14 100644 --- a/translator/src/main/scala/io/kaitai/struct/testtranslator/TestTranslator.scala +++ b/translator/src/main/scala/io/kaitai/struct/testtranslator/TestTranslator.scala @@ -106,7 +106,7 @@ class TestTranslator(options: CLIOptions) { origSpecs } - def getSG(lang: String, testSpec: TestSpec, provider: ClassTypeProvider): BaseGenerator = lang match { + def getSG(lang: String, testSpec: TestSpec, provider: ClassTypeProvider): SpecGenerator = lang match { case "construct" => new ConstructSG(testSpec, provider) case "cpp_stl_98" => new CppStlSG(testSpec, provider, CppRuntimeConfig().copyAsCpp98()) case "cpp_stl_11" => new CppStlSG(testSpec, provider, CppRuntimeConfig().copyAsCpp11()) diff --git a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/PythonWriteSG.scala b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/PythonWriteSG.scala index c6e73c824..d209a5d9c 100644 --- a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/PythonWriteSG.scala +++ b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/PythonWriteSG.scala @@ -7,7 +7,7 @@ import _root_.io.kaitai.struct.languages.PythonCompiler import _root_.io.kaitai.struct.testtranslator.{Main, TestAssert, TestSpec} import _root_.io.kaitai.struct.translators.PythonTranslator -class PythonWriteSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(spec) { +class PythonWriteSG(spec: TestSpec, provider: ClassTypeProvider) extends SpecGenerator { importList.add("import unittest") importList.add("from specwrite.common_spec import CommonSpec") @@ -18,7 +18,7 @@ class PythonWriteSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGen override def indentStr: String = " " - override def header(): Unit = { + override def run(): Unit = { out.puts out.puts(s"from testwrite.${spec.id} import $className") out.puts @@ -31,27 +31,16 @@ class PythonWriteSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGen out.puts(s"super($testClassName, self).__init__(*args, **kwargs)") out.puts(s"self.struct_class = $className") out.puts(s"self.src_filename = 'src/${spec.data}'") - out.dec + spec.exception match { + case None => + out.dec + out.puts + case Some(_) => + out.puts("""self.skip_roundtrip_msg_reason = "cannot use roundtrip because parsing is expected to fail"""") + out.dec + } } - override def runParse(): Unit = {} - - override def runParseExpectError(exception: KSError): Unit = { - out.puts - out.puts("def test_read_write_roundtrip(self):") - out.inc - out.puts("""self.skipTest("cannot use roundtrip because parsing is expected to fail")""") - out.dec - } - - override def footer(): Unit = {} - - override def runAsserts(): Unit = {} - - override def simpleAssert(check: TestAssert): Unit = ??? - override def nullAssert(actual: Ast.expr): Unit = ??? - override def trueArrayAssert(check: TestAssert, elType: DataType, elts: Seq[Ast.expr]): Unit = ??? - override def results: String = "# " + AUTOGEN_COMMENT + "\n\n" + super.results }