diff --git a/src/main/java/rx/observables/StringObservable.java b/src/main/java/rx/observables/StringObservable.java index a87806a..763b5c0 100644 --- a/src/main/java/rx/observables/StringObservable.java +++ b/src/main/java/rx/observables/StringObservable.java @@ -18,6 +18,7 @@ import rx.Observable; import rx.Observable.OnSubscribe; import rx.Observable.Operator; +import rx.Producer; import rx.Subscriber; import rx.functions.Action1; import rx.functions.Func0; @@ -38,8 +39,10 @@ import java.nio.charset.CodingErrorAction; import java.util.Arrays; import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicLong; import java.util.regex.Pattern; + public class StringObservable { /** * Reads from the bytes from a source {@link InputStream} and outputs {@link Observable} of @@ -114,28 +117,106 @@ public void call(S resource) { * internal buffer size * @return the Observable containing read byte arrays from the input */ - public static Observable from(final InputStream i, final int size) { + public static Observable from(final InputStream is, final int size) { return Observable.create(new OnSubscribe() { @Override - public void call(Subscriber o) { + public void call(Subscriber subscriber) { + subscriber.setProducer(new InputStreamProducer(is, subscriber, size)); + } + }); + } + + private static class InputStreamProducer implements Producer { + + private final AtomicLong requested = new AtomicLong(0); + private final InputStream is; + private final Subscriber subscriber; + private int size; + + InputStreamProducer(InputStream is, Subscriber subscriber, int size) { + this.is = is; + this.subscriber = subscriber; + this.size = size; + } + + @Override + public void request(long n) { + try { + if (requested.get() == Long.MAX_VALUE) + // already started with fast path + return; + else if (n == Long.MAX_VALUE) { + // fast path + requestAll(); + } else + requestSome(n); + } catch (RuntimeException e) { + subscriber.onError(e); + } catch (IOException e) { + subscriber.onError(e); + } + } + + private void requestAll() { + requested.set(Long.MAX_VALUE); + byte[] buffer = new byte[size]; + try { + if (subscriber.isUnsubscribed()) + return; + int n = is.read(buffer); + while (n != -1 && !subscriber.isUnsubscribed()) { + subscriber.onNext(Arrays.copyOf(buffer, n)); + if (!subscriber.isUnsubscribed()) + n = is.read(buffer); + } + } catch (IOException e) { + subscriber.onError(e); + } + if (subscriber.isUnsubscribed()) + return; + subscriber.onCompleted(); + } + + + + private void requestSome(long n) throws IOException { + // back pressure path + // this algorithm copied roughly from + // rxjava/OnSubscribeFromIterable.java + + long previousCount = requested.getAndAdd(n); + if (previousCount == 0) { byte[] buffer = new byte[size]; - try { - if (o.isUnsubscribed()) + while (true) { + long r = requested.get(); + long numToEmit = r; + + //emit numToEmit + + if (subscriber.isUnsubscribed()) return; - int n = i.read(buffer); - while (n != -1 && !o.isUnsubscribed()) { - o.onNext(Arrays.copyOf(buffer, n)); - if (!o.isUnsubscribed()) - n = i.read(buffer); + int numRead; + if (numToEmit>0) + numRead = is.read(buffer); + else + numRead = 0; + while (numRead != -1 && !subscriber.isUnsubscribed() && numToEmit>0) { + subscriber.onNext(Arrays.copyOf(buffer, numRead)); + numToEmit--; + if (numToEmit >0 && !subscriber.isUnsubscribed()) + numRead = is.read(buffer); } - } catch (IOException e) { - o.onError(e); + + //check if we have finished + if (numRead == -1 && !subscriber.isUnsubscribed()) + subscriber.onCompleted(); + else if (subscriber.isUnsubscribed()) + return; + else if (requested.addAndGet(-r) == 0) + return; } - if (o.isUnsubscribed()) - return; - o.onCompleted(); } - }); + } } /** @@ -176,7 +257,8 @@ public void call(Subscriber o) { n = i.read(buffer); while (n != -1 && !o.isUnsubscribed()) { o.onNext(new String(buffer, 0, n)); - n = i.read(buffer); + if (!o.isUnsubscribed()) + n = i.read(buffer); } } catch (IOException e) { o.onError(e); diff --git a/src/test/java/rx/observables/StringObservableTest.java b/src/test/java/rx/observables/StringObservableTest.java index 3a3cb82..affa902 100644 --- a/src/test/java/rx/observables/StringObservableTest.java +++ b/src/test/java/rx/observables/StringObservableTest.java @@ -36,8 +36,10 @@ import static rx.observables.StringObservable.using; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.FilterReader; import java.io.IOException; +import java.io.InputStreamReader; import java.io.Reader; import java.io.StringReader; import java.nio.charset.Charset; @@ -54,6 +56,7 @@ import rx.Observer; import rx.functions.Func1; import rx.observables.StringObservable.UnsafeFunc0; +import rx.Subscriber; import rx.observers.TestObserver; import rx.observers.TestSubscriber; @@ -296,6 +299,72 @@ public synchronized int read(byte[] b, int off, int len) { assertEquals(1, numReads.get()); } + @Test + public void testFromInputStreamWithBackpressureShouldTakeFourRequests() { + final byte[] inBytes = "tester".getBytes(); + final ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); + final AtomicInteger requestCount = new AtomicInteger(0); + subscribeToInputStream(inBytes, 2, outBytes, 1, requestCount); + assertArrayEquals(inBytes, outBytes.toByteArray()); + assertEquals(4, requestCount.get()); + } + + @Test + public void testFromInputStreamWithBackpressureRequestingMoreThanExist() { + final byte[] inBytes = "tester".getBytes(); + final ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); + final AtomicInteger requestCount = new AtomicInteger(0); + subscribeToInputStream(inBytes, 32, outBytes, 200, requestCount); + assertArrayEquals(inBytes, outBytes.toByteArray()); + assertEquals(2, requestCount.get()); + } + + @Test + public void testFromEmptyInputStreamWithBackpressure() { + final byte[] inBytes = "".getBytes(); + final ByteArrayOutputStream outBytes = new ByteArrayOutputStream(); + final AtomicInteger requestCount = new AtomicInteger(0); + subscribeToInputStream(inBytes, 32, outBytes, 1, requestCount); + assertArrayEquals(inBytes, outBytes.toByteArray()); + assertEquals(1, requestCount.get()); + } + + private static void subscribeToInputStream(byte[] inBytes, int bufferSize, + final ByteArrayOutputStream outBytes, final int requestSize, + final AtomicInteger requestCount) { + + from(new ByteArrayInputStream(inBytes), bufferSize).subscribe( + new Subscriber() { + + @Override + public void onStart() { + request(requestSize); + requestCount.incrementAndGet(); + } + + @Override + public void onCompleted() { + + } + + @Override + public void onError(Throwable e) { + throw new RuntimeException(e); + } + + @Override + public void onNext(byte[] t) { + try { + outBytes.write(t); + } catch (IOException e) { + throw new RuntimeException(e); + } + request(requestSize); + requestCount.incrementAndGet(); + } + }); + } + @Test public void testFromReader() { final String inStr = "test"; @@ -303,6 +372,23 @@ public void testFromReader() { assertNotSame(inStr, outStr); assertEquals(inStr, outStr); } + + @Test + public void testFromReaderWillUnsubscribeBeforeCallingNextRead() { + final byte[] inBytes = "test".getBytes(); + final AtomicInteger numReads = new AtomicInteger(0); + ByteArrayInputStream is = new ByteArrayInputStream(inBytes) { + + @Override + public synchronized int read(byte[] b, int off, int len) { + numReads.incrementAndGet(); + return super.read(b, off, len); + } + }; + StringObservable.from(new InputStreamReader(is)).first().toBlocking() + .single(); + assertEquals(1, numReads.get()); + } @Test public void testByLine() {