Skip to content

Commit 1cc9ea8

Browse files
committed
scala.xml.XML allows overriding the SAXParser used via the withSAXParser() method. Some XML parsing customizations require changing the XMLReader contained inside every SAXParser (e.g., adding an XMLFilter). This pull request introduces an additional extension point XMLLoader.reader and a method XML.withXMLReader() for such a purpose.
Also, ErrorHandler and EntityResolver configured externally are no longer wiped out before parsing the XML.
1 parent 3f39bdd commit 1cc9ea8

File tree

5 files changed

+84
-32
lines changed

5 files changed

+84
-32
lines changed

jvm/src/test/scala/scala/xml/XMLTest.scala

+28
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,34 @@ class XMLTestJVM {
657657
def namespaceAware2: Unit =
658658
roundtrip(namespaceAware = true, """<book xmlns="http://docbook.org/ns/docbook" xmlns:xi="http://www.w3.org/2001/XInclude"><svg xmlns:svg="http://www.w3.org/2000/svg"/></book>""")
659659

660+
@UnitTest
661+
def useXMLReaderWithXMLFilter(): Unit = {
662+
val parent: org.xml.sax.XMLReader = javax.xml.parsers.SAXParserFactory.newInstance.newSAXParser.getXMLReader
663+
val filter: org.xml.sax.XMLFilter = new org.xml.sax.helpers.XMLFilterImpl(parent) {
664+
override def characters(ch: Array[Char], start: Int, length: Int): Unit = {
665+
for (i <- 0 until length) if (ch(start+i) == 'a') ch(start+i) = 'b'
666+
super.characters(ch, start, length)
667+
}
668+
}
669+
assertEquals(XML.withXMLReader(filter).loadString("<a>caffeeaaay</a>").toString, "<a>cbffeebbby</a>")
670+
}
671+
672+
@UnitTest
673+
def checkThatErrorHandlerIsNotOverwritten(): Unit = {
674+
var gotAnError: Boolean = false
675+
XML.reader.setErrorHandler(new org.xml.sax.ErrorHandler {
676+
override def warning(e: SAXParseException): Unit = gotAnError = true
677+
override def error(e: SAXParseException): Unit = gotAnError = true
678+
override def fatalError(e: SAXParseException): Unit = gotAnError = true
679+
})
680+
try {
681+
XML.loadString("<a>")
682+
} catch {
683+
case _: org.xml.sax.SAXParseException =>
684+
}
685+
assertTrue(gotAnError)
686+
}
687+
660688
@UnitTest
661689
def nodeSeqNs: Unit = {
662690
val x = {

shared/src/main/scala/scala/xml/XML.scala

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ package scala
1414
package xml
1515

1616
import factory.XMLLoader
17-
import java.io.{ File, FileDescriptor, FileInputStream, FileOutputStream }
18-
import java.io.{ InputStream, Reader, StringReader }
17+
import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream}
18+
import java.io.{InputStream, Reader, StringReader}
1919
import java.nio.channels.Channels
2020
import scala.util.control.Exception.ultimately
2121

@@ -72,6 +72,10 @@ object XML extends XMLLoader[Elem] {
7272
def withSAXParser(p: SAXParser): XMLLoader[Elem] =
7373
new XMLLoader[Elem] { override val parser: SAXParser = p }
7474

75+
/** Returns an XMLLoader whose load* methods will use the supplied XMLReader. */
76+
def withXMLReader(r: XMLReader): XMLLoader[Elem] =
77+
new XMLLoader[Elem] { override val reader: XMLReader = r }
78+
7579
/**
7680
* Saves a node to a file with given filename using given encoding
7781
* optionally with xmldecl and doctype declaration.

shared/src/main/scala/scala/xml/factory/XMLLoader.scala

+46-28
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ package scala
1414
package xml
1515
package factory
1616

17-
import org.xml.sax.SAXNotRecognizedException
17+
import org.xml.sax.{SAXNotRecognizedException, XMLReader}
1818
import javax.xml.parsers.SAXParserFactory
1919
import parsing.{FactoryAdapter, NoBindingFactoryAdapter}
2020
import java.io.{File, FileDescriptor, InputStream, Reader}
@@ -46,59 +46,77 @@ trait XMLLoader[T <: Node] {
4646
/* Override this to use a different SAXParser. */
4747
def parser: SAXParser = parserInstance.get
4848

49+
/* Override this to use a different XMLReader. */
50+
def reader: XMLReader = parser.getXMLReader
51+
4952
/**
5053
* Loads XML from the given InputSource, using the supplied parser.
5154
* The methods available in scala.xml.XML use the XML parser in the JDK.
5255
*/
53-
def loadXML(source: InputSource, parser: SAXParser): T = {
54-
val result: FactoryAdapter = parse(source, parser)
56+
def loadXML(source: InputSource, parser: SAXParser): T = loadXML(source, parser.getXMLReader)
57+
58+
def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(source, parser.getXMLReader)
59+
60+
private def loadXML(source: InputSource, reader: XMLReader): T = {
61+
val result: FactoryAdapter = parse(source, reader)
5562
result.rootElem.asInstanceOf[T]
5663
}
57-
58-
def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = {
59-
val result: FactoryAdapter = parse(source, parser)
64+
65+
private def loadXMLNodes(source: InputSource, reader: XMLReader): Seq[Node] = {
66+
val result: FactoryAdapter = parse(source, reader)
6067
result.prolog ++ (result.rootElem :: result.epilogue)
6168
}
6269

63-
private def parse(source: InputSource, parser: SAXParser): FactoryAdapter = {
70+
private def parse(source: InputSource, reader: XMLReader): FactoryAdapter = {
71+
if (source == null) throw new IllegalArgumentException("InputSource cannot be null")
72+
6473
val result: FactoryAdapter = adapter
6574

75+
reader.setContentHandler(result)
76+
reader.setDTDHandler(result)
77+
/* Do not overwrite pre-configured EntityResolver. */
78+
if (reader.getEntityResolver == null) reader.setEntityResolver(result)
79+
/* Do not overwrite pre-configured ErrorHandler. */
80+
if (reader.getErrorHandler == null) reader.setErrorHandler(result)
81+
6682
try {
67-
parser.setProperty("http://xml.org/sax/properties/lexical-handler", result)
83+
reader.setProperty("http://xml.org/sax/properties/lexical-handler", result)
6884
} catch {
6985
case _: SAXNotRecognizedException =>
7086
}
7187

7288
result.scopeStack = TopScope :: result.scopeStack
73-
parser.parse(source, result)
89+
reader.parse(source)
7490
result.scopeStack = result.scopeStack.tail
7591

7692
result
7793
}
7894

95+
/** Loads XML from the given InputSource. */
96+
def load(source: InputSource): T = loadXML(source, reader)
97+
7998
/** Loads XML from the given file, file descriptor, or filename. */
80-
def loadFile(file: File): T = loadXML(fromFile(file), parser)
81-
def loadFile(fd: FileDescriptor): T = loadXML(fromFile(fd), parser)
82-
def loadFile(name: String): T = loadXML(fromFile(name), parser)
99+
def loadFile(file: File): T = load(fromFile(file))
100+
def loadFile(fd: FileDescriptor): T = load(fromFile(fd))
101+
def loadFile(name: String): T = load(fromFile(name))
83102

84-
/** loads XML from given InputStream, Reader, sysID, InputSource, or URL. */
85-
def load(is: InputStream): T = loadXML(fromInputStream(is), parser)
86-
def load(reader: Reader): T = loadXML(fromReader(reader), parser)
87-
def load(sysID: String): T = loadXML(fromSysId(sysID), parser)
88-
def load(source: InputSource): T = loadXML(source, parser)
89-
def load(url: URL): T = loadXML(fromInputStream(url.openStream()), parser)
103+
/** Loads XML from the given InputStream, Reader, sysID, or URL. */
104+
def load(is: InputStream): T = load(fromInputStream(is))
105+
def load(reader: Reader): T = load(fromReader(reader))
106+
def load(sysID: String): T = load(fromSysId(sysID))
107+
def load(url: URL): T = load(fromInputStream(url.openStream()))
90108

91109
/** Loads XML from the given String. */
92-
def loadString(string: String): T = loadXML(fromString(string), parser)
110+
def loadString(string: String): T = load(fromString(string))
93111

94112
/** Load XML nodes, including comments and processing instructions that precede and follow the root element. */
95-
def loadFileNodes(file: File): Seq[Node] = loadXMLNodes(fromFile(file), parser)
96-
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadXMLNodes(fromFile(fd), parser)
97-
def loadFileNodes(name: String): Seq[Node] = loadXMLNodes(fromFile(name), parser)
98-
def loadNodes(is: InputStream): Seq[Node] = loadXMLNodes(fromInputStream(is), parser)
99-
def loadNodes(reader: Reader): Seq[Node] = loadXMLNodes(fromReader(reader), parser)
100-
def loadNodes(sysID: String): Seq[Node] = loadXMLNodes(fromSysId(sysID), parser)
101-
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, parser)
102-
def loadNodes(url: URL): Seq[Node] = loadXMLNodes(fromInputStream(url.openStream()), parser)
103-
def loadStringNodes(string: String): Seq[Node] = loadXMLNodes(fromString(string), parser)
113+
def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, reader)
114+
def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file))
115+
def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadNodes(fromFile(fd))
116+
def loadFileNodes(name: String): Seq[Node] = loadNodes(fromFile(name))
117+
def loadNodes(is: InputStream): Seq[Node] = loadNodes(fromInputStream(is))
118+
def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader))
119+
def loadNodes(sysID: String): Seq[Node] = loadNodes(fromSysId(sysID))
120+
def loadNodes(url: URL): Seq[Node] = loadNodes(fromInputStream(url.openStream()))
121+
def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string))
104122
}

shared/src/main/scala/scala/xml/package.scala

+1
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,6 @@ package object xml {
8080
type SAXParseException = org.xml.sax.SAXParseException
8181
type EntityResolver = org.xml.sax.EntityResolver
8282
type InputSource = org.xml.sax.InputSource
83+
type XMLReader = org.xml.sax.XMLReader
8384
type SAXParser = javax.xml.parsers.SAXParser
8485
}

shared/src/main/scala/scala/xml/parsing/MarkupParser.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ trait MarkupParser extends MarkupParserCommon with TokenTests {
9898
var extIndex = -1
9999

100100
/** holds temporary values of pos */
101-
// Note: this is clearly an override, but if marked as such it causes a "...cannot override a mutable variable"
102-
// error with Scala 3; does it work with Scala 3 if not explicitly marked as an override remains to be seen...
101+
// Note: if marked as an override, this causes a "...cannot override a mutable variable" error with Scala 3;
102+
// SethTisue noted on Oct 14, 2021 that lampepfl/dotty#13744 should fix it - and it probably did,
103+
// but Scala XML still builds against a Scala 3 version that has this bug, so this still cannot be marked as an override :(
103104
var tmppos: Int = _
104105

105106
/** holds the next character */

0 commit comments

Comments
 (0)