Skip to content

Commit 0985ac1

Browse files
Retry retrieving the list of shards when a timeout occurs
GitOrigin-RevId: 90bf2e8603eee559ebd911a9246ec4421ca42da3
1 parent 8bbf9d6 commit 0985ac1

File tree

4 files changed

+123
-21
lines changed

4 files changed

+123
-21
lines changed

misk-vitess/api/misk-vitess.api

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ public final class misk/vitess/ShardsKt {
103103
public static final fun shards (Lmisk/jdbc/DataSourceService;)Lcom/google/common/base/Supplier;
104104
}
105105

106+
public final class misk/vitess/ShardsLoader {
107+
public static final field INSTANCE Lmisk/vitess/ShardsLoader;
108+
public final fun shards (Lmisk/jdbc/DataSourceService;)Lcom/google/common/base/Supplier;
109+
}
110+
106111
public final class misk/vitess/TabletType : java/lang/Enum {
107112
public static final field Companion Lmisk/vitess/TabletType$Companion;
108113
public static final field PRIMARY Lmisk/vitess/TabletType;

misk-vitess/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ dependencies {
1111
api(libs.dockerApi)
1212
api(libs.guava)
1313
api(project(":misk-jdbc"))
14+
implementation(project(":misk-logging"))
15+
implementation(libs.loggingApi)
1416
implementation(libs.okio)
1517

1618
testFixturesApi(libs.docker)
Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,7 @@
11
package misk.vitess
22

33
import com.google.common.base.Supplier
4-
import com.google.common.base.Suppliers
54
import misk.jdbc.DataSourceService
6-
import misk.jdbc.mapNotNull
7-
import java.sql.SQLRecoverableException
8-
import java.util.concurrent.TimeUnit
95

106
/**
 * Returns a supplier of the shard set for [dataSourceService].
 *
 * In this diff hunk, the lines following an `N-` marker are the removed inline implementation
 * (a `SHOW VITESS_SHARDS` query memoized for 5 minutes); the line following `7+` is its
 * replacement, which delegates to [ShardsLoader.shards].
 */
fun shards(dataSourceService: DataSourceService): Supplier<Set<Shard>> =
11-
Suppliers.memoizeWithExpiration({
12-
if (!dataSourceService.config().type.isVitess) {
13-
Shard.SINGLE_SHARD_SET
14-
} else {
15-
dataSourceService.dataSource.connection.use { connection ->
16-
connection.createStatement().use { s ->
17-
val shards = s.executeQuery("SHOW VITESS_SHARDS")
18-
.mapNotNull { rs -> Shard.parse(rs.getString(1)) }
19-
.toSet()
20-
if (shards.isEmpty()) {
21-
throw SQLRecoverableException("Failed to load list of shards")
22-
}
23-
shards
24-
}
25-
}
26-
}
27-
}, 5, TimeUnit.MINUTES)
7+
ShardsLoader.shards(dataSourceService)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package misk.vitess
2+
3+
import com.google.common.base.Supplier
4+
import com.google.common.base.Suppliers
5+
import misk.jdbc.DataSourceService
6+
import misk.jdbc.mapNotNull
7+
import misk.logging.getLogger
8+
import java.sql.SQLException
9+
import java.sql.SQLRecoverableException
10+
import java.sql.SQLTimeoutException
11+
import java.sql.SQLTransientException
12+
import java.util.concurrent.TimeUnit
13+
import kotlin.random.Random
14+
15+
/**
 * Loads the set of Vitess shards for a data source, retrying the lookup when it fails with a
 * timeout or connectivity-style error.
 */
object ShardsLoader {
  private val logger = getLogger<ShardsLoader>()

  /** Message fragments (matched case-insensitively) that mark a failure as retryable. */
  private val retryableErrorMessages = listOf(
    "timeout",
    "context canceled",
    "connection closed",
  )

  /**
   * Returns a supplier of the shards for [dataSourceService], memoized for 5 minutes.
   *
   * Non-Vitess data sources yield [Shard.SINGLE_SHARD_SET] without issuing a query; Vitess data
   * sources are queried with `SHOW VITESS_SHARDS`, retrying retryable failures.
   */
  fun shards(dataSourceService: DataSourceService): Supplier<Set<Shard>> =
    Suppliers.memoizeWithExpiration({
      if (!dataSourceService.config().type.isVitess) {
        Shard.SINGLE_SHARD_SET
      } else {
        loadVitessShards(dataSourceService)
      }
    }, 5, TimeUnit.MINUTES)

  /**
   * Queries the shard list, making up to three attempts.
   *
   * A failed attempt is retried (after a backoff sleep) only when [isRetryable] accepts the
   * exception and attempts remain; otherwise the exception is rethrown as-is.
   *
   * @throws InterruptedException if the thread is interrupted while backing off; the interrupt
   *   flag is restored before rethrowing.
   */
  private fun loadVitessShards(dataSourceService: DataSourceService): Set<Shard> {
    val maxRetries = 3
    var lastException: Exception? = null

    repeat(maxRetries) { attempt ->
      try {
        return queryShards(dataSourceService)
      } catch (e: Exception) {
        lastException = e

        val isRetryable = isRetryable(e)

        logger.warn(e) {
          "Failed to load Vitess shards on attempt ${attempt + 1}/$maxRetries" +
            if (isRetryable) " (retryable exception detected)" else ""
        }

        // Only retry on timeouts/connectivity issues, and not on the last attempt
        if (isRetryable && attempt < maxRetries - 1) {
          val backoffMs = calculateBackoff(attempt)
          logger.info { "Retrying shard loading in ${backoffMs}ms..." }

          try {
            Thread.sleep(backoffMs)
          } catch (interrupted: InterruptedException) {
            Thread.currentThread().interrupt()
            logger.error { "Shard loading interrupted" }
            throw interrupted
          }
        } else {
          throw e // immediately rethrow without retrying
        }
      }
    }

    // The final attempt always returns or rethrows above, so this is a defensive fallback that
    // also satisfies the compiler's need for a terminating expression.
    throw SQLException(
      "Failed to load Vitess shards after $maxRetries attempts. " +
        "Check Vitess connectivity and query timeout settings.",
      lastException,
    )
  }

  /**
   * Runs `SHOW VITESS_SHARDS` once with a 5-second statement timeout and parses the first column
   * of every row into a [Shard]. Rows that fail to parse are logged and skipped.
   *
   * @throws SQLRecoverableException if no shard could be read from the result.
   */
  private fun queryShards(dataSourceService: DataSourceService): Set<Shard> =
    dataSourceService.dataSource.connection.use { connection ->
      connection.createStatement().use { statement ->
        // Set aggressive 5-second timeout for fast failure
        statement.queryTimeout = 5

        val resultSet = statement.executeQuery("SHOW VITESS_SHARDS")

        val shards = resultSet.mapNotNull { rs ->
          try {
            val shardName = rs.getString(1)
            Shard.parse(shardName)
          } catch (e: Exception) {
            logger.warn(e) { "Failed to parse shard from result: ${rs.getString(1)}" }
            null
          }
        }.toSet()

        if (shards.isEmpty()) {
          throw SQLRecoverableException("SHOW VITESS_SHARDS returned empty result set")
        }

        shards
      }
    }

  /**
   * Decides whether [e] is worth retrying.
   *
   * It is retryable when [e] or its direct cause is a [SQLTimeoutException],
   * [SQLRecoverableException] or [SQLTransientException], or when its message contains one of
   * [retryableErrorMessages].
   */
  private fun isRetryable(e: Exception): Boolean {
    // Every alternative must be chained with `||` inside a single expression: a line that merely
    // starts a new boolean expression is evaluated and discarded, silently dropping that check.
    val hasRetryableType = listOf(e, e.cause).any { candidate ->
      candidate is SQLTimeoutException ||
        candidate is SQLRecoverableException ||
        candidate is SQLTransientException
    }

    val hasRetryableMessage = retryableErrorMessages.any { errorMessage ->
      e.message?.contains(errorMessage, ignoreCase = true) == true
    }

    return hasRetryableType || hasRetryableMessage
  }

  /**
   * Returns the sleep, in milliseconds, before the retry that follows zero-based [attempt]:
   * 500ms doubled per attempt, plus 0-199ms of random jitter, capped at 5 seconds.
   */
  private fun calculateBackoff(attempt: Int): Long {
    val baseDelayMs = 500L // Start with 500ms
    val maxDelayMs = 5000L // Cap at 5 seconds
    val exponentialDelay = baseDelayMs * (1L shl attempt) // 500ms, 1s, 2s, 4s...
    val jitter = Random.nextLong(0, 200) // Add up to 200ms jitter
    return minOf(exponentialDelay + jitter, maxDelayMs)
  }
}

0 commit comments

Comments
 (0)