diff --git a/.rubocop_todo.yml b/.rubocop_todo.yml index 3da4399a2..456fe2d28 100644 --- a/.rubocop_todo.yml +++ b/.rubocop_todo.yml @@ -22,7 +22,7 @@ Metrics/AbcSize: # Offense count: 31 # Configuration parameters: CountComments, ExcludedMethods. Metrics/BlockLength: - Max: 850 + Max: 855 # Offense count: 1 # Configuration parameters: CountBlocks. diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index d12f2c180..33ad9ace4 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -20,6 +20,15 @@ static VALUE sym_id, sym_version, sym_header_version, sym_async, sym_symbolize_k static VALUE sym_no_good_index_used, sym_no_index_used, sym_query_was_slow; static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args; +/* Rather than including violite.h to be able to look into the net.vio structure + * to call its has_data function pointer, we borrow the definitions of just two + * functions that understand this structure. + * + * Definitions are from vio_priv.h + */ +my_bool vio_ssl_has_data(void *); +my_bool vio_buff_has_data(void *); + #define REQUIRE_INITIALIZED(wrapper) \ if (!wrapper->initialized) { \ rb_raise(cMysql2Error, "MySQL client is not initialized"); \ @@ -27,8 +36,20 @@ static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args #if defined(HAVE_MYSQL_NET_VIO) || defined(HAVE_ST_NET_VIO) #define CONNECTED(wrapper) (wrapper->client->net.vio != NULL && wrapper->client->net.fd != -1) + + #define HAS_DATA(wrapper) \ + mysql_get_ssl_cipher(wrapper->client) \ + ? !vio_ssl_has_data(wrapper->client->net.vio) \ + : !vio_buff_has_data(wrapper->client->net.vio) + #elif defined(HAVE_MYSQL_NET_PVIO) || defined(HAVE_ST_NET_PVIO) #define CONNECTED(wrapper) (wrapper->client->net.pvio != NULL && wrapper->client->net.fd != -1) + + #define HAS_DATA(wrapper) \ + mysql_get_ssl_cipher(wrapper->client) \ + ? !pvio_ssl_has_data(wrapper->client->net.pvio) \ + : !pvio_buff_has_data(wrapper->client->net.pvio) + #endif #define REQUIRE_CONNECTED(wrapper) \ @@ -631,36 +652,14 @@ static VALUE disconnect_and_raise(VALUE self, VALUE error) { rb_exc_raise(error); } -static VALUE do_query(void *args) { - struct async_query_args *async_args = args; - struct timeval tv; - struct timeval *tvp; - long int sec; +static void wait_for_fd(int fd, struct timeval *tvp) { int retval; - VALUE read_timeout; - - read_timeout = rb_iv_get(async_args->self, "@read_timeout"); - - tvp = NULL; - if (!NIL_P(read_timeout)) { - Check_Type(read_timeout, T_FIXNUM); - tvp = &tv; - sec = FIX2INT(read_timeout); - /* TODO: support partial seconds? - also, this check is here for sanity, we also check up in Ruby */ - if (sec >= 0) { - tvp->tv_sec = sec; - } else { - rb_raise(cMysql2Error, "read_timeout must be a positive integer, you passed %ld", sec); - } - tvp->tv_usec = 0; - } for(;;) { - retval = rb_wait_for_single_fd(async_args->fd, RB_WAITFD_IN, tvp); + retval = rb_wait_for_single_fd(fd, RB_WAITFD_IN, tvp); if (retval == 0) { - rb_raise(cMysql2TimeoutError, "Timeout waiting for a response from the last query. (waited %d seconds)", FIX2INT(read_timeout)); + rb_raise(cMysql2TimeoutError, "Timeout waiting for a response from the last query. (waited %ld seconds)", tvp->tv_sec); } if (retval < 0) { @@ -671,6 +670,38 @@ static VALUE do_query(void *args) { break; } } +} + +static struct timeval *get_read_timeout(VALUE self, struct timeval *tvp) { + long int sec; + VALUE read_timeout; + + read_timeout = rb_iv_get(self, "@read_timeout"); + + if (NIL_P(read_timeout)) { + return NULL; + } + + Check_Type(read_timeout, T_FIXNUM); + sec = FIX2INT(read_timeout); + /* TODO: support partial seconds? + also, this check is here for sanity, we also check up in Ruby */ + if (sec < 0) { + rb_raise(cMysql2Error, "read_timeout must be a positive integer, you passed %ld", sec); + } + + tvp->tv_sec = sec; + tvp->tv_usec = 0; + return tvp; +} + +static VALUE do_query(void *args) { + struct async_query_args *async_args; + struct timeval tv, *tvp; + + async_args = (struct async_query_args *)args; + tvp = get_read_timeout(async_args->self, &tv); + wait_for_fd(async_args->fd, tvp); return Qnil; } @@ -1129,6 +1160,12 @@ static VALUE rb_mysql_client_more_results(VALUE self) return Qtrue; } +static void *nogvl_next_result(void *ptr) { + mysql_client_wrapper *wrapper = ptr; + + return (void *)INT2NUM(mysql_next_result(wrapper->client)); +} + /* call-seq: * client.next_result * @@ -1137,17 +1174,36 @@ static VALUE rb_mysql_client_more_results(VALUE self) */ static VALUE rb_mysql_client_next_result(VALUE self) { - int ret; - GET_CLIENT(self); - ret = mysql_next_result(wrapper->client); - if (ret > 0) { - rb_raise_mysql2_error(wrapper); - return Qfalse; - } else if (ret == 0) { - return Qtrue; - } else { - return Qfalse; - } + int ret; + struct timeval tv, *tvp; + GET_CLIENT(self); + + if (mysql_more_results(wrapper->client) == 0) + return Qfalse; + + /* The underlying client library may have pre-read the results from the next + * query, so wait_for_fd would not be triggered. Instead we will ask whether + * the net.vio structure has additional data that hasn't been parsed. + * + * Hack: use knowledge of whether the connecti is SSL or not to call the + * appropriate has_data function and pass net.vio as an opaque structure. + */ + if (HAS_DATA(wrapper)) { + tvp = get_read_timeout(self, &tv); + wait_for_fd(wrapper->client->net.fd, tvp); + } + + VALUE v = (VALUE)rb_thread_call_without_gvl(nogvl_next_result, wrapper, RUBY_UBF_IO, 0); + ret = NUM2INT(v); + + if (ret > 0) { + rb_raise_mysql2_error(wrapper); + return Qfalse; + } else if (ret == 0) { + return Qtrue; + } else { + return Qfalse; + } } /* call-seq: diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index 00d9c17db..f9fe24eb7 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -774,6 +774,17 @@ def run_gc expect(@multi_client.more_results?).to be false end + it "should allow for interruption" do + time_top = Time.now.to_f + expect do + Timeout.timeout(0.2, ArgumentError) do + @multi_client.query('SELECT 1; SELECT SLEEP(2)') + @multi_client.next_result + end + end.to raise_error(ArgumentError) + expect(Time.now.to_f - time_top).to be <= 0.5 + end + it "#more_results? should work with stored procedures" do @multi_client.query("DROP PROCEDURE IF EXISTS test_proc") @multi_client.query("CREATE PROCEDURE test_proc() BEGIN SELECT 1 AS 'set_1'; SELECT 2 AS 'set_2'; END")