diff --git a/generator/client/golang/templates/queryx/db.go b/generator/client/golang/templates/queryx/db.go index f37b7902..ffbbed04 100644 --- a/generator/client/golang/templates/queryx/db.go +++ b/generator/client/golang/templates/queryx/db.go @@ -21,14 +21,60 @@ func NewAdapter(db DB) *Adapter { } func (a *Adapter) Query(query string, args ...interface{}) *Rows { + matched1, err := regexp.MatchString(`.* IN (.*?)`, query) + if err != nil { + return &Rows{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } + } + matched2, err := regexp.MatchString(`.* in (.*?)`, query) + if err != nil { + return &Rows{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } + } + if matched1 || matched2 { + query, args, err = In(query, args...) + if err != nil { + return &Rows{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } + } + } + query, args = rebind(query, args) + rows, err := a.db.Query(query, args...) + if err != nil { + return &Rows{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } + } return &Rows{ + rows: rows, adapter: a, query: query, args: args, + err: nil, } } type Rows struct { + rows *sql.Rows adapter *Adapter query string args []interface{} @@ -39,72 +85,90 @@ func (r *Rows) Scan(v interface{}) error { if r.err != nil { return r.err } - var err error - query, args := r.query, r.args - matched1, err := regexp.MatchString(`.* IN (.*?)`, query) - if err != nil { - return err - } - matched2, err := regexp.MatchString(`.* in (.*?)`, query) - if err != nil { - return err - } - if matched1 || matched2 { - query, args, err = In(query, args...) - if err != nil { - return err - } - } - query, args = rebind(query, args) - rows, err := r.adapter.db.Query(query, args...) - if err != nil { - return err - } - err = ScanSlice(rows, v) + err := ScanSlice(r.rows, v) if err != nil { return err } return err } +func (r *Rows) Rows() (*sql.Rows, error) { + return r.rows, r.err +} + type Row struct { + rows *sql.Rows adapter *Adapter query string args []interface{} + err error } func (r *Row) Scan(v interface{}) error { - query, args := r.query, r.args - matched1, err := regexp.MatchString(`.* IN (.*?)`, query) + if r.err != nil { + return r.err + } + err := ScanOne(r.rows, v) if err != nil { return err } + return err +} + +func (r *Row) Row() (*sql.Rows, error) { + return r.rows, r.err +} + +func (a *Adapter) QueryOne(query string, args ...interface{}) *Row { + matched1, err := regexp.MatchString(`.* IN (.*?)`, query) + if err != nil { + return &Row{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } + } matched2, err := regexp.MatchString(`.* in (.*?)`, query) if err != nil { - return err + return &Row{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } } if matched1 || matched2 { query, args, err = In(query, args...) if err != nil { - return err + return &Row{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } } } query, args = rebind(query, args) - rows, err := r.adapter.db.Query(query, args...) + rows, err := a.db.Query(query, args...) if err != nil { - return err - } - err = ScanOne(rows, v) - if err != nil { - return err + return &Row{ + rows: nil, + adapter: a, + query: query, + args: args, + err: err, + } } - return err -} -func (a *Adapter) QueryOne(query string, args ...interface{}) *Row { return &Row{ + rows: rows, adapter: a, query: query, args: args, + err: err, } } diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 60aa777a..3817be18 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -91,6 +91,46 @@ func TestInsertAll(t *testing.T) { require.Equal(t, int64(2), inserted) } +func TestRow(t *testing.T) { + _, err := c.QueryUser().DeleteAll() + require.NoError(t, err) + + user, err := c.QueryUser().Create(c.ChangeUser().SetName("test_row").SetAge(12)) + require.NoError(t, err) + + rows, err := c.QueryOne("select name from users where id=?", user.ID).Row() + require.NoError(t, err) + var name string + if rows.Next() { + err := rows.Scan(&name) + require.NoError(t, err) + } + err = rows.Close() + require.NoError(t, err) + require.Equal(t, "test_row", name) +} + +func TestRows(t *testing.T) { + _, err := c.QueryUser().DeleteAll() + require.NoError(t, err) + + _, err = c.QueryUser().InsertAll([]*queryx.UserChange{c.ChangeUser().SetName("test_row1").SetAge(12), c.ChangeUser().SetName("test_row2").SetAge(12)}) + require.NoError(t, err) + + rows, err := c.Query("select name from users where age=?", 12).Rows() + require.NoError(t, err) + var name string + values := make([]string, 0) + for rows.Next() { + err := rows.Scan(&name) + require.NoError(t, err) + values = append(values, name) + } + err = rows.Close() + require.NoError(t, err) + require.Equal(t, []string{"test_row1", "test_row2"}, values) +} + func TestCreateEmpty(t *testing.T) { tag, err := c.QueryTag().Create(nil) require.NoError(t, err)