diff --git a/src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java b/src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java index 9a70ca6d..46c21b69 100644 --- a/src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java +++ b/src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java @@ -2,8 +2,11 @@ import org.tarantool.util.SQLStates; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.InputStream; import java.io.Reader; +import java.io.UnsupportedEncodingException; import java.math.BigDecimal; import java.net.URL; import java.sql.Array; @@ -31,6 +34,7 @@ public class SQLPreparedStatement extends SQLStatement implements PreparedStatement { private static final String INVALID_CALL_MESSAGE = "The method cannot be called on a PreparedStatement."; + private static final int STREAM_WRITE_CHUNK_SIZE = 4096; private final String sql; private final Map parameters; @@ -182,37 +186,40 @@ public void setTimestamp(int parameterIndex, Timestamp parameterValue, Calendar @Override public void setAsciiStream(int parameterIndex, InputStream parameterValue, int length) throws SQLException { - setParameter(parameterIndex, parameterValue); + setAsciiStream(parameterIndex, parameterValue, (long) length); } @Override - public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { - throw new SQLFeatureNotSupportedException(); + public void setAsciiStream(int parameterIndex, InputStream parameterValue) throws SQLException { + setCharStream(parameterIndex, parameterValue, Integer.MAX_VALUE, "ASCII"); } @Override - public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { - throw new SQLFeatureNotSupportedException(); + public void setAsciiStream(int parameterIndex, InputStream parameterValue, long length) throws SQLException { + ensureLengthLowerBound(length); + setCharStream(parameterIndex, parameterValue, length, "ASCII"); } @Override public void setUnicodeStream(int parameterIndex, InputStream parameterValue, int length) throws SQLException { - setParameter(parameterIndex, parameterValue); + ensureLengthLowerBound(length); + setCharStream(parameterIndex, parameterValue, length, "UTF-8"); } @Override public void setBinaryStream(int parameterIndex, InputStream parameterValue, int length) throws SQLException { - setParameter(parameterIndex, parameterValue); + setBinaryStream(parameterIndex, parameterValue, (long) length); } @Override - public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { - throw new SQLFeatureNotSupportedException(); + public void setBinaryStream(int parameterIndex, InputStream parameterValue, long length) throws SQLException { + ensureLengthLowerBound(length); + setBinStream(parameterIndex, parameterValue, length); } @Override - public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { - throw new SQLFeatureNotSupportedException(); + public void setBinaryStream(int parameterIndex, InputStream parameterValue) throws SQLException { + setBinStream(parameterIndex, parameterValue, Integer.MAX_VALUE); } @Override @@ -257,17 +264,18 @@ public boolean execute(String sql) throws SQLException { @Override public void setCharacterStream(int parameterIndex, Reader reader, int length) throws SQLException { - throw new SQLFeatureNotSupportedException(); + setCharacterStream(parameterIndex, reader, (long) length); } @Override public void setCharacterStream(int parameterIndex, Reader reader, long length) throws SQLException { - throw new SQLFeatureNotSupportedException(); + ensureLengthLowerBound(length); + setCharStream(parameterIndex, reader, length); } @Override public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { - throw new SQLFeatureNotSupportedException(); + setCharStream(parameterIndex, reader, Integer.MAX_VALUE); } @Override @@ -343,12 +351,12 @@ public void setNString(int parameterIndex, String parameterValue) throws SQLExce @Override public void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { - throw new SQLFeatureNotSupportedException(); + setCharacterStream(parameterIndex, value, length); } @Override public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { - throw new SQLFeatureNotSupportedException(); + setCharacterStream(parameterIndex, value); } @Override @@ -417,4 +425,71 @@ private Object[] toParametersList(Map parameters) throws SQLExc return objects; } + private void ensureLengthLowerBound(long length) throws SQLException { + if (length < 0) { + throw new SQLException("Stream size cannot be negative", SQLStates.INVALID_PARAMETER_VALUE.getSqlState()); + } + } + + private void ensureLengthUpperBound(long length) throws SQLException { + if (length > Integer.MAX_VALUE) { + throw new SQLException("Stream size is too large", SQLStates.INVALID_PARAMETER_VALUE.getSqlState()); + } + } + + private void setCharStream(int parameterIndex, + InputStream parameterValue, + long length, + String encoding) throws SQLException { + ensureLengthUpperBound(length); + try { + byte[] bytes = convertToBytes(parameterValue, length); + setParameter(parameterIndex, new String(bytes, 0, bytes.length, encoding)); + } catch (UnsupportedEncodingException e) { + throw new SQLException("Unsupported encoding", SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), e); + } + } + + private void setCharStream(int parameterIndex, Reader reader, long length) throws SQLException { + ensureLengthUpperBound(length); + try { + StringBuilder value = new StringBuilder(STREAM_WRITE_CHUNK_SIZE); + char[] buffer = new char[STREAM_WRITE_CHUNK_SIZE]; + int totalRead = 0; + int charsRead; + while (totalRead < length && + (charsRead = reader.read(buffer, 0, (int) Math.min(length - totalRead, STREAM_WRITE_CHUNK_SIZE))) != -1) { + value.append(buffer, 0, charsRead); + totalRead += charsRead; + } + setParameter(parameterIndex, value.toString()); + } catch (IOException e) { + throw new SQLException("Cannot read from the reader", SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), e); + } + } + + private void setBinStream(int parameterIndex, + InputStream parameterValue, + long length) throws SQLException { + ensureLengthUpperBound(length); + setBytes(parameterIndex, convertToBytes(parameterValue, length)); + } + + private byte[] convertToBytes(InputStream parameterValue, long length) throws SQLException { + try { + int bytesRead; + int totalRead = 0; + byte[] buffer = new byte[STREAM_WRITE_CHUNK_SIZE]; + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(STREAM_WRITE_CHUNK_SIZE); + while (totalRead < length && + (bytesRead = parameterValue.read(buffer, 0, (int) Math.min(length - totalRead, STREAM_WRITE_CHUNK_SIZE))) != -1) { + outputStream.write(buffer, 0, bytesRead); + totalRead += bytesRead; + } + return outputStream.toByteArray(); + } catch (IOException e) { + throw new SQLException("Cannot read stream", SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), e); + } + } + } diff --git a/src/main/java/org/tarantool/jdbc/SQLResultSet.java b/src/main/java/org/tarantool/jdbc/SQLResultSet.java index ac7b5db6..232ee9ec 100644 --- a/src/main/java/org/tarantool/jdbc/SQLResultSet.java +++ b/src/main/java/org/tarantool/jdbc/SQLResultSet.java @@ -336,7 +336,7 @@ public InputStream getAsciiStream(String columnLabel) throws SQLException { @Override public InputStream getUnicodeStream(int columnIndex) throws SQLException { String string = getString(columnIndex); - return string == null ? null : new ByteArrayInputStream(string.getBytes(Charset.forName("UTF-16"))); + return string == null ? null : new ByteArrayInputStream(string.getBytes(Charset.forName("UTF-8"))); } @Override diff --git a/src/test/java/org/tarantool/jdbc/JdbcPreparedStatementIT.java b/src/test/java/org/tarantool/jdbc/JdbcPreparedStatementIT.java index 4aad887f..46adb74b 100644 --- a/src/test/java/org/tarantool/jdbc/JdbcPreparedStatementIT.java +++ b/src/test/java/org/tarantool/jdbc/JdbcPreparedStatementIT.java @@ -1,5 +1,6 @@ package org.tarantool.jdbc; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -7,6 +8,10 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyObject; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.tarantool.TestAssumptions.assumeMinimalServerVersion; import org.tarantool.ServerVersion; @@ -21,6 +26,12 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.io.StringReader; +import java.nio.charset.StandardCharsets; import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.DriverManager; @@ -30,13 +41,14 @@ import java.sql.SQLFeatureNotSupportedException; import java.sql.Statement; import java.sql.Types; +import java.util.Arrays; import java.util.Collections; import java.util.List; public class JdbcPreparedStatementIT { private static final String[] INIT_SQL = new String[] { - "CREATE TABLE test(id INT PRIMARY KEY, val VARCHAR(100))", + "CREATE TABLE test(id INT PRIMARY KEY, val VARCHAR(100), bin_val SCALAR)", }; private static final String[] CLEAN_SQL = new String[] { @@ -108,7 +120,7 @@ public void testExecuteQuery() throws SQLException { @Test public void testExecuteWrongQuery() throws SQLException { - prep = conn.prepareStatement("INSERT INTO test VALUES (?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); prep.setInt(1, 200); prep.setString(2, "two hundred"); @@ -118,7 +130,7 @@ public void testExecuteWrongQuery() throws SQLException { @Test public void testExecuteUpdate() throws Exception { - prep = conn.prepareStatement("INSERT INTO test VALUES(?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES(?, ?)"); assertNotNull(prep); prep.setInt(1, 100); @@ -166,7 +178,7 @@ public void testExecuteReturnsResultSet() throws SQLException { @Test public void testExecuteReturnsUpdateCount() throws Exception { - prep = conn.prepareStatement("INSERT INTO test VALUES(?, ?), (?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES(?, ?), (?, ?)"); assertNotNull(prep); prep.setInt(1, 10); @@ -230,7 +242,7 @@ public void testIsWrapperFor() throws SQLException { @Test public void testSupportGeneratedKeys() throws SQLException { - prep = conn.prepareStatement("INSERT INTO test values (50, 'fifty')", Statement.NO_GENERATED_KEYS); + prep = conn.prepareStatement("INSERT INTO test(id, val) values (50, 'fifty')", Statement.NO_GENERATED_KEYS); assertFalse(prep.execute()); assertEquals(1, prep.getUpdateCount()); @@ -336,7 +348,7 @@ public void testMoreResultsWithResultSet() throws SQLException { @Test public void testMoreResultsWithUpdateCount() throws SQLException { - prep = conn.prepareStatement("INSERT INTO test VALUES (?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); prep.setInt(1, 9); prep.setString(2, "nine"); @@ -370,7 +382,7 @@ public void testMoreResultsButCloseAll() throws SQLException { assertThrows(SQLFeatureNotSupportedException.class, () -> prep.getMoreResults(Statement.CLOSE_ALL_RESULTS)); - prep = conn.prepareStatement("INSERT INTO test VALUES (?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); prep.setInt(1, 21); prep.setString(2, "twenty one"); prep.execute(); @@ -388,7 +400,7 @@ public void testMoreResultsButKeepCurrent() throws SQLException { assertThrows(SQLFeatureNotSupportedException.class, () -> prep.getMoreResults(Statement.KEEP_CURRENT_RESULT)); - prep = conn.prepareStatement("INSERT INTO test VALUES (?, ?)"); + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); prep.setInt(1, 22); prep.setString(2, "twenty two"); prep.execute(); @@ -560,6 +572,153 @@ void testPoolableStatus() throws SQLException { assertFalse(prep.isPoolable()); } + @Test + public void testSetAsciiStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 1); + InputStream asciiStream = new ByteArrayInputStream("one".getBytes("ASCII")); + prep.setAsciiStream(2, asciiStream); + + assertFalse(prep.execute()); + assertEquals("one", consoleSelect(1).get(1)); + } + + @Test + public void testSetAsciiLimitedStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 1); + InputStream asciiStream = new ByteArrayInputStream("one and two and even three".getBytes("ASCII")); + prep.setAsciiStream(2, asciiStream, 3); + + assertFalse(prep.execute()); + assertEquals("one", consoleSelect(1).get(1)); + } + + @Test + public void testSetNegativeAsciiStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 1); + InputStream asciiStream = new ByteArrayInputStream("one and two and even three".getBytes("ASCII")); + SQLException error = assertThrows(SQLException.class, () -> prep.setAsciiStream(2, asciiStream, -10)); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + } + + @Test + public void testSetBadStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + + InputStream throwingStream = mock(InputStream.class); + when(throwingStream.read(anyObject(), anyInt(), anyInt())).thenThrow(IOException.class); + + SQLException error = assertThrows( + SQLException.class, + () -> prep.setAsciiStream(2, throwingStream) + ); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + assertEquals(IOException.class, error.getCause().getClass()); + } + + @Test + public void testSetUnicodeLimitedStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 1); + InputStream unicodeStream = new ByteArrayInputStream("zéro one два みっつ 四 Fünf".getBytes("UTF-8")); + // zéro is 5 bytes length because é consists of tow bytes 0xC3 0xA9 + prep.setUnicodeStream(2, unicodeStream, 5); + + assertFalse(prep.execute()); + assertEquals("zéro", consoleSelect(1).get(1)); + } + + @Test + public void testSetNegativeUnicodeStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 1); + InputStream unicodeStream = new ByteArrayInputStream("one and two and even three".getBytes("UTF-8")); + SQLException error = assertThrows(SQLException.class, () -> prep.setUnicodeStream(2, unicodeStream, -9)); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + } + + @Test + public void testSetBinaryStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, bin_val) VALUES (?, ?)"); + prep.setInt(1, 1); + byte[] bytes = TestUtils.fromHex("00010203"); + prep.setBinaryStream(2, new ByteArrayInputStream(bytes)); + + assertFalse(prep.execute()); + assertArrayEquals(bytes, ((String) consoleSelect(1).get(2)).getBytes(StandardCharsets.US_ASCII)); + } + + @Test + public void testSetBinaryLimitedStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, bin_val) VALUES (?, ?)"); + prep.setInt(1, 1); + byte[] bytes = TestUtils.fromHex("00010203040506"); + prep.setBinaryStream(2, new ByteArrayInputStream(bytes), 2); + + assertFalse(prep.execute()); + assertArrayEquals( + Arrays.copyOf(bytes, 2), + ((String) consoleSelect(1).get(2)).getBytes(StandardCharsets.US_ASCII) + ); + } + + @Test + public void testSetNegativeBinaryStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, bin_val) VALUES (?, ?)"); + byte[] bytes = TestUtils.fromHex("00010203040506"); + SQLException error = assertThrows( + SQLException.class, + () -> prep.setBinaryStream(2, new ByteArrayInputStream(bytes), -4) + ); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + } + + @Test + public void testSetCharacterStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 2); + prep.setCharacterStream(2, new StringReader("two")); + + assertFalse(prep.execute()); + assertEquals("two", consoleSelect(2).get(1)); + } + + @Test + public void testSetCharacterLimitedStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + prep.setInt(1, 2); + prep.setCharacterStream(2, new StringReader("two or maybe four"), 3); + + assertFalse(prep.execute()); + assertEquals("two", consoleSelect(2).get(1)); + } + + @Test + public void testSetNegativeCharacterStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + SQLException error = assertThrows( + SQLException.class, + () -> prep.setCharacterStream(2, new StringReader("four"), -10) + ); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + } + + @Test + public void testSetBadCharacterStream() throws Exception { + prep = conn.prepareStatement("INSERT INTO test(id, val) VALUES (?, ?)"); + + Reader throwingReader = mock(Reader.class); + when(throwingReader.read(anyObject(), anyInt(), anyInt())).thenThrow(IOException.class); + + SQLException error = assertThrows( + SQLException.class, + () -> prep.setCharacterStream(2, throwingReader) + ); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), error.getSQLState()); + } + private List consoleSelect(Object key) { List list = testHelper.evaluate(TestUtils.toLuaSelect("TEST", key)); return list == null ? Collections.emptyList() : (List) list.get(0);