diff --git a/pdf2zh/arcfour.py b/pdf2zh/arcfour.py index cc78e3612..5e524dbf9 100644 --- a/pdf2zh/arcfour.py +++ b/pdf2zh/arcfour.py @@ -13,6 +13,8 @@ def __init__(self, key: Sequence[int]) -> None: s = [i for i in range(256)] j = 0 klen = len(key) + if not (0 < klen < 256): + raise ValueError("key must be non-empty and less than 256 bytes") for i in range(256): j = (j + s[i] + key[i % klen]) % 256 (s[i], s[j]) = (s[j], s[i]) diff --git a/pdf2zh/ascii85.py b/pdf2zh/ascii85.py index 233bc744a..2091aa47f 100644 --- a/pdf2zh/ascii85.py +++ b/pdf2zh/ascii85.py @@ -4,7 +4,6 @@ """ -import re import struct @@ -38,14 +37,13 @@ def ascii85decode(data: bytes) -> bytes: b = b * 85 + 84 out += struct.pack(">L", b)[: n - 1] break + elif c.isspace(): + continue + else: + raise ValueError("Bad character in ASCII85Decode") return out -# asciihexdecode(data) -hex_re = re.compile(rb"([a-f\d]{2})", re.IGNORECASE) -trail_re = re.compile(rb"^(?:[a-f\d]{2}|\s)*([a-f\d])[\s>]*$", re.IGNORECASE) - - def asciihexdecode(data: bytes) -> bytes: """ASCIIHexDecode filter: PDFReference v1.4 section 3.3.1 For each pair of ASCII hexadecimal digits (0-9 and A-F or a-f), the @@ -56,15 +54,17 @@ def asciihexdecode(data: bytes) -> bytes: will behave as if a 0 followed the last digit. """ - def decode(x: bytes) -> bytes: - i = int(x, 16) - return bytes((i,)) - - out = b"" - for x in hex_re.findall(data): - out += decode(x) - - m = trail_re.search(data) - if m: - out += decode(m.group(1) + b"0") - return out + hex_str = b"" + for i in data: + c = bytes((i,)) + if b"0" <= c <= b"9" or b"a" <= c <= b"f" or b"A" <= c <= b"F": + hex_str += c + elif c == b">": + break + elif c in b" \n\r\t": + continue + else: + raise ValueError("Bad character in ASCIIHexDecode") + if len(hex_str) % 2 == 1: + hex_str += b"0" + return bytes.fromhex(hex_str.decode()) diff --git a/pyproject.toml b/pyproject.toml index e95c3f2c5..28fce2115 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ torch = [ dev = [ "black", "flake8", - "pre-commit" + "pre-commit", + "pytest", ] [project.urls] diff --git a/tests/test_arcfour.py b/tests/test_arcfour.py new file mode 100644 index 000000000..39559fbc9 --- /dev/null +++ b/tests/test_arcfour.py @@ -0,0 +1,60 @@ +import pytest +from pdf2zh.arcfour import Arcfour + + +def arcfour_encrypt(plaintext: bytes, key: bytes) -> bytes: + cipher = Arcfour(key) + return cipher.encrypt(plaintext) + + +def arcfour_decrypt(ciphertext: bytes, key: bytes) -> bytes: + cipher = Arcfour(key) + return cipher.decrypt(ciphertext) + + +def test_basic_functionality(): + plaintext = b"Hello, World!" + key = b"mysecretkey" + ciphertext = arcfour_encrypt(plaintext, key) + decrypted_text = arcfour_decrypt(ciphertext, key) + assert decrypted_text == plaintext + +def test_ciphertext(): + plaintext = b"Hello, World!" + key = b"mysecretkey" + ciphertext = arcfour_encrypt(plaintext, key) + assert ciphertext.hex() == "a2b614d4af651ec7af7f5259db" + +def test_empty_plaintext(): + plaintext = b"" + key = b"testkey" + ciphertext = arcfour_encrypt(plaintext, key) + decrypted_text = arcfour_decrypt(ciphertext, key) + assert decrypted_text == plaintext + + +def test_empty_key(): + plaintext = b"Test data" + key = b"" + with pytest.raises(ValueError): # Key must be non-empty + arcfour_encrypt(plaintext, key) + + +def test_large_key(): + plaintext = b"Test data" + key = b"X" * 256 + with pytest.raises(ValueError): # Key must be less than 256 bytes + arcfour_encrypt(plaintext, key) + +def test_large_key2(): + plaintext = b"Test data" + key = b"X" * 255 + ciphertext = arcfour_encrypt(plaintext, key) + decrypted_text = arcfour_decrypt(ciphertext, key) + assert decrypted_text == plaintext + +def test_randomness(): + plaintext = b"AAAAA" + key = b"randomkey" + ciphertexts = [arcfour_encrypt(plaintext, key) for _ in range(10)] + assert len(set(ciphertexts)) <= 1 # All ciphertexts should be the same diff --git a/tests/test_ascii85.py b/tests/test_ascii85.py new file mode 100644 index 000000000..d1208a1e6 --- /dev/null +++ b/tests/test_ascii85.py @@ -0,0 +1,62 @@ +import pytest +from pdf2zh.ascii85 import ascii85decode, asciihexdecode + + +def test_ascii85decode_basic(): + encoded = b"87cURD_*#TDfTZ)+T" + decoded = ascii85decode(encoded + b"~>") + assert decoded == b"Hello, world!" + + +def test_ascii85decode_empty(): + encoded = b"" + decoded = ascii85decode(encoded) + assert decoded == b"" + + +def test_ascii85decode_with_z(): + encoded = b"z" + decoded = ascii85decode(encoded) + assert decoded == b"\x00\x00\x00\x00" + + +def test_ascii85decode_partial_group(): + encoded = b"9jqo^BlbD-BleB1DJ+*+F(f,q" # Encodes 'Man is distinguished' + decoded = ascii85decode(encoded + b"~>") + assert decoded.startswith(b"Man is distinguished") + + +def test_ascii85decode_with_termination(): + encoded = b"9jqo^~>" + decoded = ascii85decode(encoded) + assert decoded == b"Man " + + +def test_asciihexdecode_basic(): + encoded = b"48656C6C6F>" + decoded = asciihexdecode(encoded) + assert decoded == b"Hello" + + +def test_asciihexdecode_with_whitespace(): + encoded = b"48 65 6C 6C 6F >" + decoded = asciihexdecode(encoded) + assert decoded == b"Hello" + + +def test_asciihexdecode_odd_length(): + encoded = b"48656C6C6F3" + decoded = asciihexdecode(encoded) + assert decoded == b"Hello0" + + +def test_asciihexdecode_empty(): + encoded = b"" + decoded = asciihexdecode(encoded) + assert decoded == b"" + + +def test_asciihexdecode_invalid_characters(): + encoded = b"ZZ>" + with pytest.raises(ValueError): + asciihexdecode(encoded) diff --git a/tests/test_casting.py b/tests/test_casting.py new file mode 100644 index 000000000..610310c06 --- /dev/null +++ b/tests/test_casting.py @@ -0,0 +1,59 @@ +from pdf2zh.casting import safe_int, safe_float + + +def test_safe_int_valid_string(): + assert safe_int("123") == 123 + + +def test_safe_int_valid_integer(): + assert safe_int(123) == 123 + + +def test_safe_int_valid_float_string(): + assert safe_int("123.45") is None + + +def test_safe_int_invalid_string(): + assert safe_int("abc") is None + + +def test_safe_int_none(): + assert safe_int(None) is None + + +def test_safe_int_boolean(): + assert safe_int(True) == 1 + assert safe_int(False) == 0 + + +def test_safe_int_special_chars(): + assert safe_int("!@#") is None + + +def test_safe_float_valid_float(): + assert safe_float("123.45") == 123.45 + + +def test_safe_float_valid_integer_string(): + assert safe_float("123") == 123.0 + + +def test_safe_float_scientific_notation(): + assert safe_float("1.23e-4") == 0.000123 + + +def test_safe_float_invalid_string(): + assert safe_float("abc") is None + + +def test_safe_float_none(): + assert safe_float(None) is None + + +def test_safe_float_boolean(): + assert safe_float(True) == 1.0 + assert safe_float(False) == 0.0 + + +def test_safe_float_special_chars(): + assert safe_float("!@#") is None diff --git a/tests/test_ccitt.py b/tests/test_ccitt.py new file mode 100644 index 000000000..ebdc634fc --- /dev/null +++ b/tests/test_ccitt.py @@ -0,0 +1,103 @@ +import pytest +from pdf2zh.ccitt import BitParser, CCITTG4Parser, CCITTFaxDecoder, ccittfaxdecode + + +class TestBitParser: + def test_init(self): + parser = BitParser() + assert parser._pos == 0 + + def test_add_bits(self): + root = [None, None] + BitParser.add(root, 1, "1") + assert root[1] == 1 + assert root[0] is None + + def test_feedbytes_simple(self): + parser = BitParser() + parser._state = [None, None] + parser._accept = lambda x: [None, None] + parser.feedbytes(b"\x80") # 10000000 + assert parser._pos == 8 + + +class TestCCITTG4Parser: + def test_init(self): + parser = CCITTG4Parser(width=8) + assert parser.width == 8 + assert parser.bytealign is False + assert len(parser._curline) == 8 + + def test_mode_parsing(self): + parser = CCITTG4Parser(width=8) + # Test pass mode + new_state = parser._parse_mode("p") + assert new_state == parser.MODE + + # Test horizontal mode + new_state = parser._parse_mode("h") + assert new_state in (parser.WHITE, parser.BLACK) + + def test_vertical_coding(self): + parser = CCITTG4Parser(width=8) + parser._curpos = 0 + parser._color = 1 + parser._do_vertical(0) # No offset + assert parser._curpos >= 0 + + @pytest.mark.parametrize("n1,n2", [(1, 1), (2, 2), (3, 3)]) + def test_horizontal_coding(self, n1, n2): + parser = CCITTG4Parser(width=8) + parser._curpos = 0 + parser._color = 1 + parser._do_horizontal(n1, n2) + assert parser._curpos == n1 + n2 + + +class TestCCITTFaxDecoder: + def test_init(self): + decoder = CCITTFaxDecoder(width=8) + assert decoder.width == 8 + assert decoder.reversed is False + assert decoder._buf == b"" + + def test_decode_empty(self): + decoder = CCITTFaxDecoder(width=8) + result = decoder.close() + assert result == b"" + + def test_basic_decoding(self): + decoder = CCITTFaxDecoder(width=8) + decoder.feedbytes(b"\x00") + result = decoder.close() + assert isinstance(result, bytes) + + +def test_ccittfaxdecode(): + params = {"K": -1, "Columns": 8, "EncodedByteAlign": False, "BlackIs1": False} + result = ccittfaxdecode(b"\x00", params) + assert isinstance(result, bytes) + + with pytest.raises(Exception): + ccittfaxdecode(b"\x00", {"K": 0}) # Should raise for unsupported K value + + +@pytest.fixture +def basic_parser(): + return CCITTG4Parser(width=8) + + +class TestErrorHandling: + def test_invalid_mode(self, basic_parser): + with pytest.raises(CCITTG4Parser.InvalidData): + basic_parser._parse_mode("invalid") + + def test_eofb_detection(self, basic_parser): + with pytest.raises(CCITTG4Parser.EOFB): + basic_parser._parse_mode("e") + + def test_byte_alignment(self): + parser = CCITTG4Parser(width=8, bytealign=True) + parser._curpos = 8 # Simulate end of line + with pytest.raises(CCITTG4Parser.ByteSkip): + parser._flush_line() diff --git a/tests/test_cmapdb.py b/tests/test_cmapdb.py new file mode 100644 index 000000000..e622d460e --- /dev/null +++ b/tests/test_cmapdb.py @@ -0,0 +1,118 @@ +import pytest + +from pdf2zh.cmapdb import ( + CMapBase, + CMap, + IdentityCMap, + IdentityCMapByte, + UnicodeMap, + IdentityUnicodeMap, + FileCMap, + FileUnicodeMap, + PSLiteral, + PDFTypeError, +) + + +def test_cmap_base(): + cmap = CMapBase(WMode=0, Name="TestMap") + assert cmap.attrs["WMode"] == 0 + assert cmap.attrs["Name"] == "TestMap" + assert not cmap.is_vertical() + + cmap.set_attr("WMode", 1) + assert cmap.is_vertical() + + +def test_cmap(): + cmap = CMap(CMapName="TestCMap") + assert repr(cmap) == "" + + # Test code2cid mapping + cmap.code2cid = {1: 100, 2: {3: 200}} + decoded = list(cmap.decode(bytes([1]))) + assert decoded == [100] + + decoded = list(cmap.decode(bytes([2, 3]))) + assert decoded == [200] + + +def test_cmap_use_cmap(): + cmap1 = CMap() + cmap1.code2cid = {1: 100, 2: {3: 200}} + + cmap2 = CMap() + cmap2.use_cmap(cmap1) + assert cmap2.code2cid == {1: 100, 2: {3: 200}} + + +def test_identity_cmap(): + cmap = IdentityCMap() + result = cmap.decode(b"\x00\x01\x00\x02") + assert result == (1, 2) + + # Test empty input + assert cmap.decode(b"") == () + + +def test_identity_cmap_byte(): + cmap = IdentityCMapByte() + result = cmap.decode(b"\x01\x02\x03") + assert result == (1, 2, 3) + + # Test empty input + assert cmap.decode(b"") == () + + +def test_unicode_map(): + umap = UnicodeMap(CMapName="TestUnicode") + assert repr(umap) == "" + + umap.cid2unichr[100] = "A" + assert umap.get_unichr(100) == "A" + + +def test_identity_unicode_map(): + umap = IdentityUnicodeMap() + assert umap.get_unichr(65) == "A" # ASCII 65 = 'A' + assert umap.get_unichr(0x4E00) == "δΈ€" # CJK Unified Ideograph + + +def test_file_cmap(): + cmap = FileCMap() + cmap.add_code2cid("AB", 100) + + result = list(cmap.decode(bytes([ord("A"), ord("B")]))) + assert result == [100] + + +def test_file_unicode_map(): + umap = FileUnicodeMap() + + # Test with PSLiteral + umap.add_cid2unichr(100, PSLiteral("A")) + assert umap.get_unichr(100) == "A" + + # Test with bytes (UTF-16BE) + umap.add_cid2unichr(101, b"\x00A") + assert umap.get_unichr(101) == "A" + + # Test with integer + umap.add_cid2unichr(102, 65) # ASCII 65 = 'A' + assert umap.get_unichr(102) == "A" + + # Test invalid input + with pytest.raises(PDFTypeError): + umap.add_cid2unichr(103, 1.5) # type: ignore + + +def test_file_unicode_map_space_handling(): + umap = FileUnicodeMap() + + # Add regular space + umap.add_cid2unichr(100, 32) # space character + assert umap.get_unichr(100) == " " + + # Try to add non-breaking space - should be ignored + umap.add_cid2unichr(100, 0xA0) # non-breaking space + assert umap.get_unichr(100) == " " # should still be regular space diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py new file mode 100644 index 000000000..58ba408a5 --- /dev/null +++ b/tests/test_data_structures.py @@ -0,0 +1,35 @@ +import pytest +from pdf2zh.data_structures import NumberTree +from pdf2zh import settings + + +def test_number_tree_empty(): + tree = NumberTree({}) + assert tree.values == [] + + +def test_number_tree_leaf(): + tree = NumberTree({"Nums": [1, "a", 2, "b", 3, "c"]}) + assert tree.values == [(1, "a"), (2, "b"), (3, "c")] + + +def test_number_tree_kids(): + child1 = {"Nums": [1, "a", 2, "b"]} + child2 = {"Nums": [3, "c", 4, "d"]} + tree = NumberTree({"Kids": [child1, child2]}) + assert tree.values == [(1, "a"), (2, "b"), (3, "c"), (4, "d")] + + +def test_number_tree_unordered(): + tree = NumberTree({"Nums": [3, "c", 1, "a", 2, "b"]}) + if settings.STRICT: + with pytest.raises(Exception): + tree.values + else: + assert tree.values == [(1, "a"), (2, "b"), (3, "c")] + + +def test_number_tree_limits(): + tree = NumberTree({"Nums": [1, "a"], "Limits": [1, 1]}) + assert tree.values == [(1, "a")] + assert tree.limits == [1, 1] diff --git a/tests/test_encodingdb.py b/tests/test_encodingdb.py new file mode 100644 index 000000000..95d72b777 --- /dev/null +++ b/tests/test_encodingdb.py @@ -0,0 +1,64 @@ +import pytest +from pdf2zh.encodingdb import name2unicode, EncodingDB, PDFKeyError +from pdf2zh.psparser import PSLiteral + + +def test_name2unicode_basic(): + assert name2unicode("A") == "A" + assert name2unicode("dollar") == "$" + + +def test_name2unicode_composite(): + assert name2unicode("A_B") == "AB" + assert name2unicode("A_B_C") == "ABC" + + +def test_name2unicode_uni(): + assert name2unicode("uni0041") == "A" # Unicode A + assert name2unicode("uni00410042") == "AB" # Unicode AB + + +def test_name2unicode_u(): + assert name2unicode("u0041") == "A" # Unicode A + assert name2unicode("u1F600") == "πŸ˜€" # Unicode emoji + + +def test_name2unicode_invalid_type(): + with pytest.raises(PDFKeyError): + name2unicode(123) # type: ignore + + +def test_name2unicode_invalid_name(): + with pytest.raises(PDFKeyError): + name2unicode("invalid_name") + + +def test_name2unicode_invalid_unicode(): + with pytest.raises(PDFKeyError): + name2unicode("uniD800") # Surrogate pair range + + +def test_encoding_db_basic(): + assert "StandardEncoding" in EncodingDB.encodings + assert "MacRomanEncoding" in EncodingDB.encodings + assert "WinAnsiEncoding" in EncodingDB.encodings + assert "PDFDocEncoding" in EncodingDB.encodings + + +def test_encoding_db_get_encoding(): + encoding = EncodingDB.get_encoding("StandardEncoding") + assert isinstance(encoding, dict) + assert 65 in encoding # ASCII 'A' + + +def test_encoding_db_get_encoding_with_diff(): + diff = [0, PSLiteral("A"), PSLiteral("B")] + encoding = EncodingDB.get_encoding("StandardEncoding", diff) + assert encoding[0] == "A" + assert encoding[1] == "B" + + +def test_encoding_db_get_encoding_invalid(): + # Should return StandardEncoding for invalid encoding name + encoding = EncodingDB.get_encoding("InvalidEncoding") + assert encoding == EncodingDB.std2unicode