Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 77 additions & 44 deletions ext/zlib/zlib.c
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ static VALUE rb_gzreader_readlines(int, VALUE*, VALUE);
* - Zlib::MemError
* - Zlib::BufError
* - Zlib::VersionError
* - Zlib::InProgressError
*
* (if you have GZIP_SUPPORT)
* - Zlib::GzipReader
Expand All @@ -304,7 +305,7 @@ void Init_zlib(void);
/*--------- Exceptions --------*/

static VALUE cZError, cStreamEnd, cNeedDict;
static VALUE cStreamError, cDataError, cMemError, cBufError, cVersionError;
static VALUE cStreamError, cDataError, cMemError, cBufError, cVersionError, cInProgressError;

static void
raise_zlib_error(int err, const char *msg)
Expand Down Expand Up @@ -557,14 +558,15 @@ struct zstream {
} *func;
};

#define ZSTREAM_FLAG_READY 0x1
#define ZSTREAM_FLAG_IN_STREAM 0x2
#define ZSTREAM_FLAG_FINISHED 0x4
#define ZSTREAM_FLAG_CLOSING 0x8
#define ZSTREAM_FLAG_GZFILE 0x10 /* disallows yield from expand_buffer for
#define ZSTREAM_FLAG_READY (1 << 0)
#define ZSTREAM_FLAG_IN_STREAM (1 << 1)
#define ZSTREAM_FLAG_FINISHED (1 << 2)
#define ZSTREAM_FLAG_CLOSING (1 << 3)
#define ZSTREAM_FLAG_GZFILE (1 << 4) /* disallows yield from expand_buffer for
gzip*/
#define ZSTREAM_REUSE_BUFFER 0x20
#define ZSTREAM_FLAG_UNUSED 0x40
#define ZSTREAM_REUSE_BUFFER (1 << 5)
#define ZSTREAM_IN_PROGRESS (1 << 6)
#define ZSTREAM_FLAG_UNUSED (1 << 7)

#define ZSTREAM_READY(z) ((z)->flags |= ZSTREAM_FLAG_READY)
#define ZSTREAM_IS_READY(z) ((z)->flags & ZSTREAM_FLAG_READY)
Expand Down Expand Up @@ -593,7 +595,9 @@ static const struct zstream_funcs inflate_funcs = {
};

struct zstream_run_args {
struct zstream * z;
struct zstream *const z;
Bytef *src;
long len;
int flush; /* stream flush value for inflate() or deflate() */
int interrupt; /* stop processing the stream and return to ruby */
int jump_state; /* for buffer expansion block break or exception */
Expand Down Expand Up @@ -894,7 +898,6 @@ zstream_discard_input(struct zstream *z, long len)
}
rb_str_resize(z->input, newlen);
if (newlen == 0) {
rb_gc_force_recycle(z->input);
z->input = Qnil;
}
else {
Expand Down Expand Up @@ -1059,19 +1062,18 @@ zstream_unblock_func(void *ptr)
args->interrupt = 1;
}

static void
zstream_run0(struct zstream *z, Bytef *src, long len, int flush)
static VALUE
zstream_run_try(VALUE value_arg)
{
struct zstream_run_args args;
struct zstream_run_args *args = (struct zstream_run_args *)value_arg;
struct zstream *z = args->z;
Bytef *src = args->src;
long len = args->len;
int flush = args->flush;

int err;
VALUE old_input = Qnil;

args.z = z;
args.flush = flush;
args.interrupt = 0;
args.jump_state = 0;
args.stream_output = !ZSTREAM_IS_GZFILE(z) && rb_block_given_p();

if (NIL_P(z->input) && len == 0) {
z->stream.next_in = (Bytef*)"";
z->stream.avail_in = 0;
Expand All @@ -1093,17 +1095,17 @@ zstream_run0(struct zstream *z, Bytef *src, long len, int flush)

loop:
#ifndef RB_NOGVL_UBF_ASYNC_SAFE
err = (int)(VALUE)rb_thread_call_without_gvl(zstream_run_func, (void *)&args,
zstream_unblock_func, (void *)&args);
err = (int)(VALUE)rb_thread_call_without_gvl(zstream_run_func, (void *)args,
zstream_unblock_func, (void *)args);
#else
err = (int)(VALUE)rb_nogvl(zstream_run_func, (void *)&args,
zstream_unblock_func, (void *)&args,
err = (int)(VALUE)rb_nogvl(zstream_run_func, (void *)args,
zstream_unblock_func, (void *)args,
RB_NOGVL_UBF_ASYNC_SAFE);
#endif

/* retry if no exception is thrown */
if (err == Z_OK && args.interrupt) {
args.interrupt = 0;
if (err == Z_OK && args->interrupt) {
args->interrupt = 0;
goto loop;
}

Expand Down Expand Up @@ -1137,37 +1139,54 @@ zstream_run0(struct zstream *z, Bytef *src, long len, int flush)
}
if (!NIL_P(old_input)) {
rb_str_resize(old_input, 0);
rb_gc_force_recycle(old_input);
}

if (args.jump_state)
rb_jump_tag(args.jump_state);
if (args->jump_state)
rb_jump_tag(args->jump_state);

return Qnil;
}

struct zstream_run_synchronized_args {
struct zstream *z;
Bytef *src;
long len;
int flush;
};
static VALUE
zstream_run_ensure(VALUE value_arg)
{
struct zstream_run_args *args = (struct zstream_run_args *)value_arg;

/* Remove ZSTREAM_IN_PROGRESS flag to signal that this zstream is not in use. */
args->z->flags &= ~ZSTREAM_IN_PROGRESS;

return Qnil;
}

static VALUE
zstream_run_synchronized(VALUE value_arg)
{
struct zstream_run_synchronized_args *run_args = (struct zstream_run_synchronized_args *)value_arg;
zstream_run0(run_args->z, run_args->src, run_args->len, run_args->flush);
struct zstream_run_args *args = (struct zstream_run_args *)value_arg;

/* Cannot start zstream while it is in progress. */
if (args->z->flags & ZSTREAM_IN_PROGRESS) {
rb_raise(cInProgressError, "zlib stream is in progress");
}
args->z->flags |= ZSTREAM_IN_PROGRESS;

rb_ensure(zstream_run_try, value_arg, zstream_run_ensure, value_arg);

return Qnil;
}

static void
zstream_run(struct zstream *z, Bytef *src, long len, int flush)
{
struct zstream_run_synchronized_args run_args;
run_args.z = z;
run_args.src = src;
run_args.len = len;
run_args.flush = flush;
rb_mutex_synchronize(z->mutex, zstream_run_synchronized, (VALUE)&run_args);
struct zstream_run_args args = {
.z = z,
.src = src,
.len = len,
.flush = flush,
.interrupt = 0,
.jump_state = 0,
.stream_output = !ZSTREAM_IS_GZFILE(z) && rb_block_given_p(),
};
rb_mutex_synchronize(z->mutex, zstream_run_synchronized, (VALUE)&args);
}

static VALUE
Expand Down Expand Up @@ -2906,8 +2925,6 @@ gzfile_readpartial(struct gzfile *gz, long len, VALUE outbuf)
if (!NIL_P(outbuf)) {
rb_str_resize(outbuf, RSTRING_LEN(dst));
memcpy(RSTRING_PTR(outbuf), RSTRING_PTR(dst), RSTRING_LEN(dst));
rb_str_resize(dst, 0);
rb_gc_force_recycle(dst);
dst = outbuf;
}
return dst;
Expand Down Expand Up @@ -4619,6 +4636,7 @@ Init_zlib(void)
cMemError = rb_define_class_under(mZlib, "MemError", cZError);
cBufError = rb_define_class_under(mZlib, "BufError", cZError);
cVersionError = rb_define_class_under(mZlib, "VersionError", cZError);
cInProgressError = rb_define_class_under(mZlib, "InProgressError", cZError);

rb_define_module_function(mZlib, "zlib_version", rb_zlib_version, 0);
rb_define_module_function(mZlib, "adler32", rb_zlib_adler32, -1);
Expand Down Expand Up @@ -4926,6 +4944,7 @@ Init_zlib(void)
* - Zlib::MemError
* - Zlib::BufError
* - Zlib::VersionError
* - Zlib::InProgressError
*
*/

Expand Down Expand Up @@ -5000,6 +5019,20 @@ Init_zlib(void)
*
*/

/*
* Document-class: Zlib::InProgressError
*
* Subclass of Zlib::Error. This error is raised when the zlib
* stream is currently in progress.
*
* For example:
*
* inflater = Zlib::Inflate.new
* inflater.inflate(compressed) do
* inflater.inflate(compressed) # Raises Zlib::InProgressError
* end
*/

/*
* Document-class: Zlib::GzipFile::Error
*
Expand Down
10 changes: 8 additions & 2 deletions test/zlib/test_zlib.rb
Original file line number Diff line number Diff line change
Expand Up @@ -538,30 +538,36 @@ def test_multithread_inflate
end

def test_recursive_deflate
original_gc_stress = GC.stress
GC.stress = true
zd = Zlib::Deflate.new

s = SecureRandom.random_bytes(1024**2)
assert_raise(Zlib::BufError) do
assert_raise(Zlib::InProgressError) do
zd.deflate(s) do
zd.deflate(s)
end
end
ensure
GC.stress = original_gc_stress
zd&.finish
zd&.close
end

def test_recursive_inflate
original_gc_stress = GC.stress
GC.stress = true
zi = Zlib::Inflate.new

s = Zlib.deflate(SecureRandom.random_bytes(1024**2))

assert_raise(Zlib::DataError) do
assert_raise(Zlib::InProgressError) do
zi.inflate(s) do
zi.inflate(s)
end
end
ensure
GC.stress = original_gc_stress
zi&.close
end
end
Expand Down