diff --git a/repoze/xmliter/serializer.py b/repoze/xmliter/serializer.py index b7d909e..6ff5d5f 100644 --- a/repoze/xmliter/serializer.py +++ b/repoze/xmliter/serializer.py @@ -7,7 +7,7 @@ doctype_re_u = re.compile(u"^]+>\\s*", re.MULTILINE) class XMLSerializer(object): - + def __init__(self, tree, serializer=None, pretty_print=False, doctype=None): if serializer is None: serializer = lxml.etree.tostring @@ -21,7 +21,10 @@ def __init__(self, tree, serializer=None, pretty_print=False, doctype=None): def serialize(self, encoding=None): # Defer to the xsl:output settings if appropriate if isinstance(self.tree, lxml.etree._XSLTResultTree): - result = str(self.tree) + if encoding is str: + result = str(self.tree) + else: + result = bytes(self.tree) else: result = self.serializer(self.tree, encoding=encoding, pretty_print=self.pretty_print) if self.doctype is not None: @@ -36,7 +39,7 @@ def __iter__(self): def __str__(self): return self.serialize(str) - + def __bytes__(self): return self.serialize() diff --git a/repoze/xmliter/tests.py b/repoze/xmliter/tests.py index ea17e91..1226885 100644 --- a/repoze/xmliter/tests.py +++ b/repoze/xmliter/tests.py @@ -8,9 +8,6 @@ import lxml.html import lxml.etree -import sys -if sys.version_info > (3,): - unicode = str class TestIterator(unittest.TestCase): @@ -34,7 +31,7 @@ def test_html_serialization(self): """Test HTML serialization.""" @decorator.lazy(serializer=lxml.html.tostring) - def app(a, b, c=u""): + def app(a, b, c=""): tree = self.create_tree() tree.find('body').attrib['class'] = " ".join((a, b, c)) return tree @@ -48,13 +45,13 @@ def app(a, b, c=u""): # With Unicode encoding: self.assertEqual( lxml.html.tostring(result.tree, encoding='unicode'), - u"".join(result.serialize(encoding=unicode))) + "".join(result.serialize(encoding=str))) def test_xml_serialization(self): """Test XML serialization.""" @decorator.lazy - def app(a, b, c=u""): + def app(a, b, c=""): tree = self.create_tree() tree.find('body').attrib['class'] = " ".join((a, b, c)) return tree @@ -68,7 +65,7 @@ def app(a, b, c=u""): # With Unicode encoding: self.assertEqual( lxml.etree.tostring(result.tree, encoding='unicode'), - u"".join(result.serialize(encoding=unicode))) + "".join(result.serialize(encoding=str))) def test_decorator_instancemethod(self): class test(object): @@ -86,7 +83,7 @@ def __call__(self, tree): self.assertEqual( lxml.etree.tostring(result.tree, encoding='unicode'), - u"".join(result.serialize(encoding=unicode))) + "".join(result.serialize(encoding=str))) def test_getXMLSerializer(self): t = utils.getXMLSerializer(self.create_iterable()) @@ -100,8 +97,8 @@ def test_getXMLSerializer(self): b"".join(t2)) self.assertEqual( - u"My homepageHello, wörld!", - u"".join(t2.serialize(encoding=unicode))) + "My homepageHello, wörld!", + "".join(t2.serialize(encoding=str))) def test_length(self): t = utils.getXMLSerializer(self.create_iterable()) @@ -120,8 +117,8 @@ def test_getHTMLSerializer(self): b"".join(t2).strip()) self.assertEqual( - u'\n\nMy homepage\nHello, wörld!\n\n', - u"".join(t2.serialize(encoding=unicode)).strip()) + '\n\nMy homepage\nHello, wörld!\n\n', + "".join(t2.serialize(encoding=str)).strip()) def test_getHTMLSerializer_doctype_xhtml_serializes_to_xhtml(self): t = utils.getHTMLSerializer(self.create_iterable(preamble='\n', body=''), pretty_print=True) @@ -135,8 +132,8 @@ def test_getHTMLSerializer_doctype_xhtml_serializes_to_xhtml(self): b"".join(t2).strip()) self.assertEqual( - u'\n\n \n \n My homepage\n \n Hello, wörld!\n', - u"".join(t2.serialize(encoding=unicode)).strip()) + '\n\n \n \n My homepage\n \n Hello, wörld!\n', + "".join(t2.serialize(encoding=str)).strip()) def test_xsl(self): t = utils.getHTMLSerializer(self.create_iterable(body='
')) @@ -156,6 +153,7 @@ def test_xsl(self): ''')) t.tree = transform(t.tree) self.assertTrue('
' in str(t)) + self.assertTrue(b'
' in bytes(t)) def test_replace_doctype(self): t = utils.getHTMLSerializer(self.create_iterable(preamble='\n', body=''), pretty_print=True, doctype="") @@ -169,8 +167,8 @@ def test_replace_doctype(self): b"".join(t2).strip()) self.assertEqual( - u'\n\n \n \n My homepage\n \n Hello, wörld!\n', - u"".join(t2.serialize(encoding=unicode)).strip()) + '\n\n \n \n My homepage\n \n Hello, wörld!\n', + "".join(t2.serialize(encoding=str)).strip()) def test_replace_doctype_blank(self): t = utils.getHTMLSerializer(self.create_iterable(preamble='\n', body=''), pretty_print=True, doctype="") @@ -184,5 +182,5 @@ def test_replace_doctype_blank(self): b"".join(t2).strip()) self.assertEqual( - u'\n \n \n My homepage\n \n Hello, wörld!\n', - u"".join(t2.serialize(encoding=unicode)).strip()) + '\n \n \n My homepage\n \n Hello, wörld!\n', + "".join(t2.serialize(encoding=str)).strip())