Skip to content

Support generated keys from INSERT query #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/org/tarantool/Key.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public enum Key implements Callable<Integer> {
SQL_BIND(0x41),
SQL_OPTIONS(0x42),
SQL_INFO(0x42),
SQL_ROW_COUNT(0);
SQL_ROW_COUNT(0x00),
SQL_INFO_AUTOINCREMENT_IDS(0x01);

int id;

Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/tarantool/SqlProtoUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.tarantool.protocol.TarantoolPacket;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -41,7 +42,7 @@ public static List<SQLMetaData> getSQLMetadata(TarantoolPacket pack) {
return values;
}

public static Long getSqlRowCount(TarantoolPacket pack) {
public static Long getSQLRowCount(TarantoolPacket pack) {
Map<Key, Object> info = (Map<Key, Object>) pack.getBody().get(Key.SQL_INFO.getId());
Number rowCount;
if (info != null && (rowCount = ((Number) info.get(Key.SQL_ROW_COUNT.getId()))) != null) {
Expand All @@ -50,6 +51,15 @@ public static Long getSqlRowCount(TarantoolPacket pack) {
return null;
}

public static List<Integer> getSQLAutoIncrementIds(TarantoolPacket pack) {
Map<Key, Object> info = (Map<Key, Object>) pack.getBody().get(Key.SQL_INFO.getId());
if (info != null) {
List<Integer> generatedIds = (List<Integer>) info.get(Key.SQL_INFO_AUTOINCREMENT_IDS.getId());
return generatedIds == null ? Collections.emptyList() : generatedIds;
}
return Collections.emptyList();
}

public static class SQLMetaData {
private String name;
private TarantoolSqlType type;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/TarantoolClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ protected void complete(TarantoolPacket packet, TarantoolOp<?> future) {
}

protected void completeSql(TarantoolOp<?> future, TarantoolPacket pack) {
Long rowCount = SqlProtoUtils.getSqlRowCount(pack);
Long rowCount = SqlProtoUtils.getSQLRowCount(pack);
if (rowCount != null) {
((TarantoolOp) future).complete(rowCount);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/TarantoolConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void close() {
@Override
public Long update(String sql, Object... bind) {
TarantoolPacket pack = sql(sql, bind);
return SqlProtoUtils.getSqlRowCount(pack);
return SqlProtoUtils.getSQLRowCount(pack);
}

@Override
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/org/tarantool/jdbc/SQLConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* <p>
* Supports creating {@link Statement} and {@link PreparedStatement} instances
*/
public class SQLConnection implements Connection {
public class SQLConnection implements TarantoolConnection {

private static final int UNSET_HOLDABILITY = 0;
private static final String PING_QUERY = "SELECT 1";
Expand Down Expand Up @@ -148,10 +148,7 @@ public PreparedStatement prepareStatement(String sql,
public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
}
return prepareStatement(sql);
return new SQLPreparedStatement(this, sql, autoGeneratedKeys);
}

@Override
Expand Down Expand Up @@ -527,14 +524,17 @@ public int getNetworkTimeout() throws SQLException {
return (int) client.getOperationTimeout();
}

protected SQLResultHolder execute(long timeout, SQLQueryHolder query) throws SQLException {
@Override
public SQLResultHolder execute(long timeout, SQLQueryHolder query) throws SQLException {
checkNotClosed();
return (useNetworkTimeout(timeout))
? executeWithNetworkTimeout(query)
: executeWithQueryTimeout(timeout, query);
}

protected SQLBatchResultHolder executeBatch(long timeout, List<SQLQueryHolder> queries) throws SQLException {
@Override
public SQLBatchResultHolder executeBatch(long timeout, List<SQLQueryHolder> queries)
throws SQLException {
checkNotClosed();
SQLTarantoolClientImpl.SQLRawOps sqlOps = client.sqlRawOps();
SQLBatchResultHolder batchResult = useNetworkTimeout(timeout)
Expand Down Expand Up @@ -810,10 +810,10 @@ SQLRawOps sqlRawOps() {

@Override
protected void completeSql(TarantoolOp<?> future, TarantoolPacket pack) {
Long rowCount = SqlProtoUtils.getSqlRowCount(pack);
Long rowCount = SqlProtoUtils.getSQLRowCount(pack);
SQLResultHolder result = (rowCount == null)
? SQLResultHolder.ofQuery(SqlProtoUtils.getSQLMetadata(pack), SqlProtoUtils.getSQLData(pack))
: SQLResultHolder.ofUpdate(rowCount.intValue());
: SQLResultHolder.ofUpdate(rowCount.intValue(), SqlProtoUtils.getSQLAutoIncrementIds(pack));
((TarantoolOp) future).complete(result);
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/tarantool/jdbc/SQLDatabaseMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ public boolean supportsMultipleOpenResults() throws SQLException {

@Override
public boolean supportsGetGeneratedKeys() throws SQLException {
return false;
return true;
}

@Override
Expand Down Expand Up @@ -1104,7 +1104,7 @@ private ResultSet asEmptyMetadataResultSet(List<TupleTwo<String, TarantoolSqlTyp

@Override
public boolean generatedKeyAlwaysReturned() throws SQLException {
return false;
return true;
}

@Override
Expand Down
12 changes: 7 additions & 5 deletions src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ public class SQLPreparedStatement extends SQLStatement implements PreparedStatem

private final String sql;
private final Map<Integer, Object> parameters;

private final int autoGeneratedKeys;
private List<Map<Integer, Object>> batchParameters = new ArrayList<>();

public SQLPreparedStatement(SQLConnection connection, String sql) throws SQLException {
public SQLPreparedStatement(SQLConnection connection, String sql, int autoGeneratedKeys) throws SQLException {
super(connection);
this.sql = sql;
this.parameters = new HashMap<>();
this.autoGeneratedKeys = autoGeneratedKeys;
setPoolable(true);
}

Expand All @@ -52,13 +53,14 @@ public SQLPreparedStatement(SQLConnection connection,
super(connection, resultSetType, resultSetConcurrency, resultSetHoldability);
this.sql = sql;
this.parameters = new HashMap<>();
this.autoGeneratedKeys = NO_GENERATED_KEYS;
setPoolable(true);
}

@Override
public ResultSet executeQuery() throws SQLException {
checkNotClosed();
if (!executeInternal(sql, toParametersList(parameters))) {
if (!executeInternal(autoGeneratedKeys, sql, toParametersList(parameters))) {
throw new SQLException("No results were returned", SQLStates.NO_DATA.getSqlState());
}
return resultSet;
Expand All @@ -73,7 +75,7 @@ public ResultSet executeQuery(String sql) throws SQLException {
@Override
public int executeUpdate() throws SQLException {
checkNotClosed();
if (executeInternal(sql, toParametersList(parameters))) {
if (executeInternal(autoGeneratedKeys, sql, toParametersList(parameters))) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
Expand Down Expand Up @@ -244,7 +246,7 @@ private void setParameter(int parameterIndex, Object value) throws SQLException
@Override
public boolean execute() throws SQLException {
checkNotClosed();
return executeInternal(sql, toParametersList(parameters));
return executeInternal(autoGeneratedKeys, sql, toParametersList(parameters));
}

@Override
Expand Down
17 changes: 13 additions & 4 deletions src/main/java/org/tarantool/jdbc/SQLResultHolder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,29 @@ public class SQLResultHolder {
private final List<SqlProtoUtils.SQLMetaData> sqlMetadata;
private final List<List<Object>> rows;
private final int updateCount;
private final List<Integer> generatedIds;

public SQLResultHolder(List<SqlProtoUtils.SQLMetaData> sqlMetadata, List<List<Object>> rows, int updateCount) {
public SQLResultHolder(List<SqlProtoUtils.SQLMetaData> sqlMetadata,
List<List<Object>> rows,
int updateCount,
List<Integer> generatedIds) {
this.sqlMetadata = sqlMetadata;
this.rows = rows;
this.updateCount = updateCount;
this.generatedIds = generatedIds;
}

public static SQLResultHolder ofQuery(final List<SqlProtoUtils.SQLMetaData> sqlMetadata,
final List<List<Object>> rows) {
return new SQLResultHolder(sqlMetadata, rows, NO_UPDATE_COUNT);
return new SQLResultHolder(sqlMetadata, rows, NO_UPDATE_COUNT, Collections.emptyList());
}

public static SQLResultHolder ofEmptyQuery() {
return ofQuery(Collections.emptyList(), Collections.emptyList());
}

public static SQLResultHolder ofUpdate(int updateCount) {
return new SQLResultHolder(null, null, updateCount);
public static SQLResultHolder ofUpdate(int updateCount, List<Integer> generatedIds) {
return new SQLResultHolder(null, null, updateCount, generatedIds);
}

public List<SqlProtoUtils.SQLMetaData> getSqlMetadata() {
Expand All @@ -48,6 +53,10 @@ public int getUpdateCount() {
return updateCount;
}

public List<Integer> getGeneratedIds() {
return generatedIds;
}

public boolean isQueryResult() {
return sqlMetadata != null && rows != null;
}
Expand Down
64 changes: 40 additions & 24 deletions src/main/java/org/tarantool/jdbc/SQLStatement.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.tarantool.jdbc;

import org.tarantool.SqlProtoUtils;
import org.tarantool.jdbc.type.TarantoolSqlType;
import org.tarantool.util.JdbcConstants;
import org.tarantool.util.SQLStates;

Expand All @@ -13,6 +15,7 @@
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -27,13 +30,17 @@
*/
public class SQLStatement implements TarantoolStatement {

protected final SQLConnection connection;
private static final String GENERATED_KEY_COLUMN_NAME = "GENERATED_KEY";

protected final TarantoolConnection connection;
private final SQLResultSet emptyGeneratedKeys;

/**
* Current result set / update count associated to this statement.
*/
protected SQLResultSet resultSet;
protected int updateCount;
protected SQLResultSet generatedKeys;

private List<String> batchQueries = new ArrayList<>();

Expand Down Expand Up @@ -61,10 +68,12 @@ public class SQLStatement implements TarantoolStatement {
private final AtomicBoolean isClosed = new AtomicBoolean(false);

protected SQLStatement(SQLConnection sqlConnection) throws SQLException {
this.connection = sqlConnection;
this.resultSetType = ResultSet.TYPE_FORWARD_ONLY;
this.resultSetConcurrency = ResultSet.CONCUR_READ_ONLY;
this.resultSetHoldability = sqlConnection.getHoldability();
this(
sqlConnection,
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY,
sqlConnection.getHoldability()
);
}

protected SQLStatement(SQLConnection sqlConnection,
Expand All @@ -75,37 +84,34 @@ protected SQLStatement(SQLConnection sqlConnection,
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
this.emptyGeneratedKeys = this.generatedKeys = executeGeneratedKeys(Collections.emptyList());
}

@Override
public ResultSet executeQuery(String sql) throws SQLException {
checkNotClosed();
if (!executeInternal(sql)) {
if (!executeInternal(NO_GENERATED_KEYS, sql)) {
throw new SQLException("No results were returned", SQLStates.NO_DATA.getSqlState());
}
return resultSet;
}

@Override
public int executeUpdate(String sql) throws SQLException {
checkNotClosed();
if (executeInternal(sql)) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
);
}
return updateCount;
return executeUpdate(sql, NO_GENERATED_KEYS);
}

@Override
public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
if (executeInternal(autoGeneratedKeys, sql)) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
);
}
return executeUpdate(sql);
return updateCount;
}

@Override
Expand Down Expand Up @@ -195,17 +201,14 @@ public void setCursorName(String name) throws SQLException {
@Override
public boolean execute(String sql) throws SQLException {
checkNotClosed();
return executeInternal(sql);
return executeInternal(NO_GENERATED_KEYS, sql);
}

@Override
public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
}
return execute(sql);
return executeInternal(autoGeneratedKeys, sql);
}

@Override
Expand Down Expand Up @@ -321,7 +324,7 @@ public Connection getConnection() throws SQLException {
@Override
public ResultSet getGeneratedKeys() throws SQLException {
checkNotClosed();
return new SQLResultSet(SQLResultHolder.ofEmptyQuery(), this);
return generatedKeys;
}

@Override
Expand Down Expand Up @@ -401,6 +404,7 @@ protected void discardLastResults() throws SQLException {
clearWarnings();
updateCount = -1;
resultSet = null;
generatedKeys = emptyGeneratedKeys;

if (lastResultSet != null) {
try {
Expand All @@ -419,7 +423,7 @@ protected void discardLastResults() throws SQLException {
*
* @return {@code true}, if the result is a ResultSet object;
*/
protected boolean executeInternal(String sql, Object... params) throws SQLException {
protected boolean executeInternal(int autoGeneratedKeys, String sql, Object... params) throws SQLException {
discardLastResults();
SQLResultHolder holder;
try {
Expand All @@ -433,6 +437,9 @@ protected boolean executeInternal(String sql, Object... params) throws SQLExcept
resultSet = new SQLResultSet(holder, this);
}
updateCount = holder.getUpdateCount();
if (autoGeneratedKeys == Statement.RETURN_GENERATED_KEYS) {
generatedKeys = executeGeneratedKeys(holder.getGeneratedIds());
}
return holder.isQueryResult();
}

Expand Down Expand Up @@ -474,4 +481,13 @@ protected void checkNotClosed() throws SQLException {
}
}

protected SQLResultSet executeGeneratedKeys(List<Integer> generatedKeys) throws SQLException {
SqlProtoUtils.SQLMetaData sqlMetaData =
new SqlProtoUtils.SQLMetaData(GENERATED_KEY_COLUMN_NAME, TarantoolSqlType.INTEGER);
List<List<Object>> rows = generatedKeys.stream()
.map(Collections::<Object>singletonList)
.collect(Collectors.toList());
return createResultSet(SQLResultHolder.ofQuery(Collections.singletonList(sqlMetaData), rows));
}

}
Loading