diff --git a/ext/mysql2/result.c b/ext/mysql2/result.c index 2a1967c1c..a054f9878 100644 --- a/ext/mysql2/result.c +++ b/ext/mysql2/result.c @@ -66,6 +66,8 @@ typedef struct { VALUE block_given; } result_each_args; +typedef VALUE (*fetch_row_func_t)(VALUE, MYSQL_FIELD *fields, const result_each_args *args); + VALUE cBigDecimal, cDateTime, cDate; static VALUE cMysql2Result; static VALUE opt_decimal_zero, opt_float_zero, opt_time_year, opt_time_month, opt_utc_offset; @@ -755,8 +757,132 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) { return wrapper->fields; } +static result_each_args rb_mysql_row_query_options(VALUE self, VALUE opts) { + ID dbTz, appTz; + VALUE defaults; + result_each_args args; + + defaults = rb_iv_get(self, "@query_options"); + Check_Type(defaults, T_HASH); + if (!NIL_P(opts)) { + opts = rb_funcall(defaults, intern_merge, 1, opts); + } else { + opts = defaults; + } + + args.symbolizeKeys = RTEST(rb_hash_aref(opts, sym_symbolize_keys)); + args.asArray = rb_hash_aref(opts, sym_as) == sym_array; + args.castBool = RTEST(rb_hash_aref(opts, sym_cast_booleans)); + args.cacheRows = RTEST(rb_hash_aref(opts, sym_cache_rows)); + args.cast = RTEST(rb_hash_aref(opts, sym_cast)); + args.block_given = Qnil; + + dbTz = rb_hash_aref(opts, sym_database_timezone); + if (dbTz == sym_local) { + args.db_timezone = intern_local; + } else if (dbTz == sym_utc) { + args.db_timezone = intern_utc; + } else { + if (!NIL_P(dbTz)) { + rb_warn(":database_timezone option must be :utc or :local - defaulting to :local"); + } + args.db_timezone = intern_local; + } + + appTz = rb_hash_aref(opts, sym_application_timezone); + if (appTz == sym_local) { + args.app_timezone = intern_local; + } else if (appTz == sym_utc) { + args.app_timezone = intern_utc; + } else { + args.app_timezone = Qnil; + } + + return args; +} + +static VALUE rb_mysql_result_element(int argc, VALUE * argv, VALUE self) { + result_each_args args; + MYSQL_FIELD *fields = NULL; + VALUE seek, count, row, rows; + long i, c_seek, c_count = 0; + VALUE block, opts; + fetch_row_func_t fetch_row_func; + + GET_RESULT(self); + + rb_scan_args(argc, argv, "12&", &seek, &count, &opts, &block); + + /* If the second arg is a hash, it's the opts and there's no count */ + if (TYPE(count) == T_HASH) { + opts = count; + count = Qnil; + } + + c_seek = NUM2LONG(seek); + if (!NIL_P(count)) { + c_count = NUM2LONG(count); + /* Special case: ary[x, 0] returns []*/ + if (!c_count) return rb_ary_new(); + } + + args = rb_mysql_row_query_options(self, opts); + args.block_given = block; + + if (wrapper->is_streaming) { + rb_raise(cMysql2Error, "Element reference operator #[] cannot be used in streaming mode."); + } + + if (!wrapper->numberOfRows) { + wrapper->numberOfRows = mysql_num_rows(wrapper->result); + } + + /* count back from the end if passed a negative number */ + if (c_seek < 0) { + c_seek = wrapper->numberOfRows + c_seek; + } + + /* negative offset was too big */ + if (c_seek < 0) { + return Qnil; + /* rb_raise(cMysql2Error, "Out of range: offset %ld is beyond %lu rows (offset begins at 0).", c_seek, wrapper->numberOfRows); */ + } + + if (wrapper->numberOfRows <= (unsigned long)c_seek) { + if (!c_count) return Qnil; + else return rb_ary_new(); + /* rb_raise(cMysql2Error, "Out of range: offset %ld is beyond %lu rows (offset begins at 0).", c_seek, wrapper->numberOfRows); */ + } + + mysql_data_seek(wrapper->result, c_seek); + fields = mysql_fetch_fields(wrapper->result); + + if (wrapper->stmt) { + fetch_row_func = rb_mysql_result_fetch_row_stmt; + } else { + fetch_row_func = rb_mysql_result_fetch_row; + } + + if (!c_count) { + return fetch_row_func(self, fields, &args); + } + + /* given ary = [1, 2, 3] then ary[1, 100] returns [2, 3] */ + if ((unsigned long)(c_seek + c_count) > wrapper->numberOfRows) { + c_count = wrapper->numberOfRows - c_seek; + } + + /* return an array! */ + rows = rb_ary_new2(c_count); + for (i = 0; i < c_count; i++) { + row = fetch_row_func(self, fields, &args); + rb_ary_store(rows, i, row); + } + return rows; +} + static VALUE rb_mysql_result_each_(VALUE self, - VALUE(*fetch_row_func)(VALUE, MYSQL_FIELD *fields, const result_each_args *args), + fetch_row_func_t fetch_row_func, const result_each_args *args) { unsigned long i; @@ -846,58 +972,14 @@ static VALUE rb_mysql_result_each_(VALUE self, static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { result_each_args args; - VALUE defaults, opts, block, (*fetch_row_func)(VALUE, MYSQL_FIELD *fields, const result_each_args *args); - ID db_timezone, app_timezone, dbTz, appTz; - int symbolizeKeys, asArray, castBool, cacheRows, cast; + VALUE opts, block; + fetch_row_func_t fetch_row_func; GET_RESULT(self); - defaults = rb_iv_get(self, "@query_options"); - Check_Type(defaults, T_HASH); - if (rb_scan_args(argc, argv, "01&", &opts, &block) == 1) { - opts = rb_funcall(defaults, intern_merge, 1, opts); - } else { - opts = defaults; - } - - symbolizeKeys = RTEST(rb_hash_aref(opts, sym_symbolize_keys)); - asArray = rb_hash_aref(opts, sym_as) == sym_array; - castBool = RTEST(rb_hash_aref(opts, sym_cast_booleans)); - cacheRows = RTEST(rb_hash_aref(opts, sym_cache_rows)); - cast = RTEST(rb_hash_aref(opts, sym_cast)); - - if (wrapper->is_streaming && cacheRows) { - rb_warn(":cache_rows is ignored if :stream is true"); - } - - if (wrapper->stmt && !cacheRows && !wrapper->is_streaming) { - rb_warn(":cache_rows is forced for prepared statements (if not streaming)"); - } - - if (wrapper->stmt && !cast) { - rb_warn(":cast is forced for prepared statements"); - } - - dbTz = rb_hash_aref(opts, sym_database_timezone); - if (dbTz == sym_local) { - db_timezone = intern_local; - } else if (dbTz == sym_utc) { - db_timezone = intern_utc; - } else { - if (!NIL_P(dbTz)) { - rb_warn(":database_timezone option must be :utc or :local - defaulting to :local"); - } - db_timezone = intern_local; - } - - appTz = rb_hash_aref(opts, sym_application_timezone); - if (appTz == sym_local) { - app_timezone = intern_local; - } else if (appTz == sym_utc) { - app_timezone = intern_utc; - } else { - app_timezone = Qnil; - } + rb_scan_args(argc, argv, "01&", &opts, &block); + args = rb_mysql_row_query_options(self, opts); + args.block_given = block; if (wrapper->lastRowProcessed == 0 && !wrapper->is_streaming) { wrapper->numberOfRows = wrapper->stmt ? mysql_stmt_num_rows(wrapper->stmt) : mysql_num_rows(wrapper->result); @@ -908,16 +990,6 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) { wrapper->rows = rb_ary_new2(wrapper->numberOfRows); } - // Backward compat - args.symbolizeKeys = symbolizeKeys; - args.asArray = asArray; - args.castBool = castBool; - args.cacheRows = cacheRows; - args.cast = cast; - args.db_timezone = db_timezone; - args.app_timezone = app_timezone; - args.block_given = block; - if (wrapper->stmt) { fetch_row_func = rb_mysql_result_fetch_row_stmt; } else { @@ -988,6 +1060,7 @@ void init_mysql2_result() { cDateTime = rb_const_get(rb_cObject, rb_intern("DateTime")); cMysql2Result = rb_define_class_under(mMysql2, "Result", rb_cObject); + rb_define_method(cMysql2Result, "[]", rb_mysql_result_element, -1); rb_define_method(cMysql2Result, "each", rb_mysql_result_each, -1); rb_define_method(cMysql2Result, "fields", rb_mysql_result_fetch_fields, 0); rb_define_method(cMysql2Result, "count", rb_mysql_result_count, 0); diff --git a/spec/mysql2/result_spec.rb b/spec/mysql2/result_spec.rb index b32939753..c195cd24c 100644 --- a/spec/mysql2/result_spec.rb +++ b/spec/mysql2/result_spec.rb @@ -98,6 +98,53 @@ end end + context "#[]" do + it "should return results when accessed by [offset]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[1]).to eql({"col" => 2}) + expect(result[0]).to eql({"col" => 1}) + end + + it "should return results when accessed by negative [offset]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[-1]).to eql({"col" => 2}) + expect(result[-2]).to eql({"col" => 1}) + end + + it "should return array of results when accessed by [offset, count]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[1, 1]).to eql([{"col" => 2}]) + expect(result[-2, 10]).to eql([{"col" => 1}, {"col" => 2}]) + end + + it "should return nil if we use too large [offset]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[2]).to be_nil + expect(result[200]).to be_nil + end + + it "should return nil if we use too negative [offset]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[-3]).to be_nil + expect(result[-300]).to be_nil + end + + it "should accept hash args in [offset, {:foo => bar}] and [offset, count, {:foo => bar}]" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col" + expect(result[1, {:symbolize_keys => true}]).to eql({:col => 2}) + expect(result[1, 1, {:symbolize_keys => true}]).to eql([{:col => 2}]) + + # This syntax can't be parsed by Ruby 1.8: + # expect(result[1, :symbolize_keys => true]).to eql({:col => 2}) + # expect(result[1, 1, :symbolize_keys => true]).to eql([{:col => 2}]) + end + + it "should throw an exception if we use an [offset] in streaming mode" do + result = @client.query "SELECT 1 AS col UNION SELECT 2 AS col", :stream => true + expect { result[0] }.to raise_exception(Mysql2::Error) + end + end + context "#fields" do before(:each) do @test_result = @client.query("SELECT * FROM mysql2_test ORDER BY id DESC LIMIT 1")