diff --git a/src/main/java/io/github/whitemagic2014/tts/TTS.java b/src/main/java/io/github/whitemagic2014/tts/TTS.java index eabc1b8..73d2b9a 100644 --- a/src/main/java/io/github/whitemagic2014/tts/TTS.java +++ b/src/main/java/io/github/whitemagic2014/tts/TTS.java @@ -3,6 +3,7 @@ import io.github.whitemagic2014.tts.bean.TransRecord; import io.github.whitemagic2014.tts.bean.Voice; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; import java.io.File; import java.net.URISyntaxException; @@ -17,7 +18,13 @@ import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy; +import java.util.concurrent.TimeUnit; public class TTS { @@ -66,6 +73,9 @@ public class TTS { */ private final Map websocketMap = new ConcurrentHashMap<>(); + + private ExecutorService executor; + public TTS(Voice voice) { this(voice, null); } @@ -166,8 +176,22 @@ public String trans() { public void batchTrans() { init(); - Stream stream = parallelThreadSize > 1 ? recordList.parallelStream() : recordList.stream(); - stream.forEach(record -> doTrans(record.getContent(), record.getFilename())); + if (parallelThreadSize > 1) { + List> futureList = new ArrayList<>(); + for (TransRecord record : recordList) { + Future future = executor.submit(() -> doTrans(record.getContent(), record.getFilename())); + futureList.add(future); + } + for (Future future : futureList) { + try { + future.get(); + } catch (InterruptedException | ExecutionException e) { + throw new IllegalStateException(e); + } + } + } else { + recordList.forEach(record -> doTrans(record.getContent(), record.getFilename())); + } } /** @@ -208,8 +232,11 @@ private void init() { if (!storageFolder.exists()) { storageFolder.mkdirs(); } - if (parallelThreadSize > 1) { - System.setProperty("java.util.concurrent.ForkJoinPool.common.parallelism", String.valueOf(parallelThreadSize)); + if (executor == null && parallelThreadSize > 1) { + executor = new ThreadPoolExecutor(0, parallelThreadSize, 1, TimeUnit.MINUTES, + new LinkedBlockingQueue<>(parallelThreadSize), + new BasicThreadFactory.Builder().daemon(true).namingPattern("trans-worker-%d").build(), + new CallerRunsPolicy()); } }