diff --git a/Gopkg.lock b/Gopkg.lock index 6551354a0..1fea221b7 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -299,12 +299,14 @@ [[projects]] name = "github.com/go-xorm/builder" packages = ["."] - revision = "488224409dd8aa2ce7a5baf8d10d55764a913738" + revision = "dc8bf48f58fab2b4da338ffd25191905fd741b8f" + version = "v0.3.0" [[projects]] name = "github.com/go-xorm/core" packages = ["."] - revision = "cb1d0ca71f42d3ee1bf4aba7daa16099bc31a7e9" + revision = "c10e21e7e1cec20e09398f2dfae385e58c8df555" + version = "v0.6.0" [[projects]] name = "github.com/go-xorm/tidb" @@ -314,7 +316,7 @@ [[projects]] name = "github.com/go-xorm/xorm" packages = ["."] - revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" + revision = "ad69f7d8f0861a29438154bb0a20b60501298480" [[projects]] branch = "master" @@ -873,6 +875,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" + inputs-digest = "3b587a036aaf09514228ead18e7fd93e9ee1d14d4e56715bb2f197d5f27259d1" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 1019888c0..97f736b5f 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -38,7 +38,7 @@ ignored = ["google.golang.org/appengine*"] [[override]] name = "github.com/go-xorm/xorm" #version = "0.6.5" - revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" + revision = "ad69f7d8f0861a29438154bb0a20b60501298480" [[override]] name = "github.com/gorilla/mux" diff --git a/models/issue.go b/models/issue.go index d97266b4e..99f379f58 100644 --- a/models/issue.go +++ b/models/issue.go @@ -1283,7 +1283,7 @@ func getParticipantsByIssueID(e Engine, issueID int64) ([]*User, error) { And("`comment`.type = ?", CommentTypeComment). And("`user`.is_active = ?", true). And("`user`.prohibit_login = ?", false). - Join("INNER", "user", "`user`.id = `comment`.poster_id"). + Join("INNER", "`user`", "`user`.id = `comment`.poster_id"). Distinct("poster_id"). Find(&userIDs); err != nil { return nil, fmt.Errorf("get poster IDs: %v", err) diff --git a/models/issue_list.go b/models/issue_list.go index 05130a6ee..fdfca71af 100644 --- a/models/issue_list.go +++ b/models/issue_list.go @@ -166,7 +166,7 @@ func (issues IssueList) loadAssignees(e Engine) error { var assignees = make(map[int64][]*User, len(issues)) rows, err := e.Table("issue_assignees"). - Join("INNER", "user", "`user`.id = `issue_assignees`.assignee_id"). + Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id"). In("`issue_assignees`.issue_id", issues.getIssueIDs()). Rows(new(AssigneeIssue)) if err != nil { diff --git a/models/issue_watch.go b/models/issue_watch.go index 3e7d24821..579c91547 100644 --- a/models/issue_watch.go +++ b/models/issue_watch.go @@ -67,7 +67,7 @@ func getIssueWatchers(e Engine, issueID int64) (watches []*IssueWatch, err error Where("`issue_watch`.issue_id = ?", issueID). And("`user`.is_active = ?", true). And("`user`.prohibit_login = ?", false). - Join("INNER", "user", "`user`.id = `issue_watch`.user_id"). + Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id"). Find(&watches) return } diff --git a/models/org.go b/models/org.go index 23f6c58bf..bd5fc825b 100644 --- a/models/org.go +++ b/models/org.go @@ -383,7 +383,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) { func GetOrgUsersByUserID(uid int64, all bool) ([]*OrgUser, error) { ous := make([]*OrgUser, 0, 10) sess := x. - Join("LEFT", "user", "`org_user`.org_id=`user`.id"). + Join("LEFT", "`user`", "`org_user`.org_id=`user`.id"). Where("`org_user`.uid=?", uid) if !all { // Only show public organizations @@ -575,7 +575,7 @@ func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team, return teams, e. Where("`team_user`.org_id = ?", org.ID). Join("INNER", "team_user", "`team_user`.team_id = team.id"). - Join("INNER", "user", "`user`.id=team_user.uid"). + Join("INNER", "`user`", "`user`.id=team_user.uid"). And("`team_user`.uid = ?", userID). Asc("`user`.name"). Cols(cols...). diff --git a/models/repo.go b/models/repo.go index b126c7e3b..12ded1b0a 100644 --- a/models/repo.go +++ b/models/repo.go @@ -1954,7 +1954,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error { func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) { var repo Repository has, err := x.Select("repository.*"). - Join("INNER", "user", "`user`.id = repository.owner_id"). + Join("INNER", "`user`", "`user`.id = repository.owner_id"). Where("repository.lower_name = ?", strings.ToLower(repoName)). And("`user`.lower_name = ?", strings.ToLower(ownerName)). Get(&repo) diff --git a/models/repo_watch.go b/models/repo_watch.go index 8019027c1..95c7e44e9 100644 --- a/models/repo_watch.go +++ b/models/repo_watch.go @@ -54,7 +54,7 @@ func getWatchers(e Engine, repoID int64) ([]*Watch, error) { return watches, e.Where("`watch`.repo_id=?", repoID). And("`user`.is_active=?", true). And("`user`.prohibit_login=?", false). - Join("INNER", "user", "`user`.id = `watch`.user_id"). + Join("INNER", "`user`", "`user`.id = `watch`.user_id"). Find(&watches) } diff --git a/models/user.go b/models/user.go index 0b7af8df6..32b9bfec9 100644 --- a/models/user.go +++ b/models/user.go @@ -374,9 +374,9 @@ func (u *User) GetFollowers(page int) ([]*User, error) { Limit(ItemsPerPage, (page-1)*ItemsPerPage). Where("follow.follow_id=?", u.ID) if setting.UsePostgreSQL { - sess = sess.Join("LEFT", "follow", `"user".id=follow.user_id`) + sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") } else { - sess = sess.Join("LEFT", "follow", "user.id=follow.user_id") + sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") } return users, sess.Find(&users) } @@ -393,9 +393,9 @@ func (u *User) GetFollowing(page int) ([]*User, error) { Limit(ItemsPerPage, (page-1)*ItemsPerPage). Where("follow.user_id=?", u.ID) if setting.UsePostgreSQL { - sess = sess.Join("LEFT", "follow", `"user".id=follow.follow_id`) + sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") } else { - sess = sess.Join("LEFT", "follow", "user.id=follow.follow_id") + sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") } return users, sess.Find(&users) } diff --git a/vendor/github.com/go-xorm/builder/builder.go b/vendor/github.com/go-xorm/builder/builder.go index 1253b9887..80586ce4e 100644 --- a/vendor/github.com/go-xorm/builder/builder.go +++ b/vendor/github.com/go-xorm/builder/builder.go @@ -4,6 +4,10 @@ package builder +import ( + "fmt" +) + type optype byte const ( @@ -29,6 +33,9 @@ type Builder struct { joins []join inserts Eq updates []Eq + orderBy string + groupBy string + having string } // Select creates a select Builder @@ -67,6 +74,11 @@ func (b *Builder) From(tableName string) *Builder { return b } +// TableName returns the table name +func (b *Builder) TableName() string { + return b.tableName +} + // Into sets insert table name func (b *Builder) Into(tableName string) *Builder { b.tableName = tableName @@ -178,6 +190,33 @@ func (b *Builder) ToSQL() (string, []interface{}, error) { return w.writer.String(), w.args, nil } +// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix +func ConvertPlaceholder(sql, prefix string) (string, error) { + buf := StringBuilder{} + var j, start = 0, 0 + for i := 0; i < len(sql); i++ { + if sql[i] == '?' { + _, err := buf.WriteString(sql[start:i]) + if err != nil { + return "", err + } + start = i + 1 + + _, err = buf.WriteString(prefix) + if err != nil { + return "", err + } + + j = j + 1 + _, err = buf.WriteString(fmt.Sprintf("%d", j)) + if err != nil { + return "", err + } + } + } + return buf.String(), nil +} + // ToSQL convert a builder or condtions to SQL and args func ToSQL(cond interface{}) (string, []interface{}, error) { switch cond.(type) { diff --git a/vendor/github.com/go-xorm/builder/builder_insert.go b/vendor/github.com/go-xorm/builder/builder_insert.go index decec9313..9b213ec73 100644 --- a/vendor/github.com/go-xorm/builder/builder_insert.go +++ b/vendor/github.com/go-xorm/builder/builder_insert.go @@ -15,7 +15,7 @@ func (b *Builder) insertWriteTo(w Writer) error { return errors.New("no table indicated") } if len(b.inserts) <= 0 { - return errors.New("no column to be update") + return errors.New("no column to be insert") } if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { @@ -26,7 +26,9 @@ func (b *Builder) insertWriteTo(w Writer) error { var bs []byte var valBuffer = bytes.NewBuffer(bs) var i = 0 - for col, value := range b.inserts { + + for _, col := range b.inserts.sortedKeys() { + value := b.inserts[col] fmt.Fprint(w, col) if e, ok := value.(expr); ok { fmt.Fprint(valBuffer, e.sql) diff --git a/vendor/github.com/go-xorm/builder/builder_select.go b/vendor/github.com/go-xorm/builder/builder_select.go index 3a3967ccc..deaacbd02 100644 --- a/vendor/github.com/go-xorm/builder/builder_select.go +++ b/vendor/github.com/go-xorm/builder/builder_select.go @@ -34,24 +34,65 @@ func (b *Builder) selectWriteTo(w Writer) error { } } - if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil { + if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil { return err } for _, v := range b.joins { - fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable) + if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil { + return err + } + if err := v.joinCond.WriteTo(w); err != nil { return err } } - if !b.cond.IsValid() { - return nil + if b.cond.IsValid() { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + + if err := b.cond.WriteTo(w); err != nil { + return err + } } - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err + if len(b.groupBy) > 0 { + if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil { + return err + } } - return b.cond.WriteTo(w) + if len(b.having) > 0 { + if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil { + return err + } + } + + if len(b.orderBy) > 0 { + if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil { + return err + } + } + + return nil +} + +// OrderBy orderBy SQL +func (b *Builder) OrderBy(orderBy string) *Builder { + b.orderBy = orderBy + return b +} + +// GroupBy groupby SQL +func (b *Builder) GroupBy(groupby string) *Builder { + b.groupBy = groupby + return b +} + +// Having having SQL +func (b *Builder) Having(having string) *Builder { + b.having = having + return b } diff --git a/vendor/github.com/go-xorm/builder/cond.go b/vendor/github.com/go-xorm/builder/cond.go index 77dd139bf..c0c2553de 100644 --- a/vendor/github.com/go-xorm/builder/cond.go +++ b/vendor/github.com/go-xorm/builder/cond.go @@ -5,7 +5,6 @@ package builder import ( - "bytes" "io" ) @@ -19,15 +18,15 @@ var _ Writer = NewWriter() // BytesWriter implments Writer and save SQL in bytes.Buffer type BytesWriter struct { - writer *bytes.Buffer - buffer []byte + writer *StringBuilder args []interface{} } // NewWriter creates a new string writer func NewWriter() *BytesWriter { - w := &BytesWriter{} - w.writer = bytes.NewBuffer(w.buffer) + w := &BytesWriter{ + writer: &StringBuilder{}, + } return w } diff --git a/vendor/github.com/go-xorm/builder/cond_compare.go b/vendor/github.com/go-xorm/builder/cond_compare.go index e10ef7447..1c293719c 100644 --- a/vendor/github.com/go-xorm/builder/cond_compare.go +++ b/vendor/github.com/go-xorm/builder/cond_compare.go @@ -10,7 +10,13 @@ import "fmt" func WriteMap(w Writer, data map[string]interface{}, op string) error { var args = make([]interface{}, 0, len(data)) var i = 0 - for k, v := range data { + keys := make([]string, 0, len(data)) + for k := range data { + keys = append(keys, k) + } + + for _, k := range keys { + v := data[k] switch v.(type) { case expr: if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { diff --git a/vendor/github.com/go-xorm/builder/cond_eq.go b/vendor/github.com/go-xorm/builder/cond_eq.go index 8777727ff..79d795e6d 100644 --- a/vendor/github.com/go-xorm/builder/cond_eq.go +++ b/vendor/github.com/go-xorm/builder/cond_eq.go @@ -4,7 +4,10 @@ package builder -import "fmt" +import ( + "fmt" + "sort" +) // Incr implements a type used by Eq type Incr int @@ -19,7 +22,8 @@ var _ Cond = Eq{} func (eq Eq) opWriteTo(op string, w Writer) error { var i = 0 - for k, v := range eq { + for _, k := range eq.sortedKeys() { + v := eq[k] switch v.(type) { case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: if err := In(k, v).WriteTo(w); err != nil { @@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond { func (eq Eq) IsValid() bool { return len(eq) > 0 } + +// sortedKeys returns all keys of this Eq sorted with sort.Strings. +// It is used internally for consistent ordering when generating +// SQL, see https://github.com/go-xorm/builder/issues/10 +func (eq Eq) sortedKeys() []string { + keys := make([]string, 0, len(eq)) + for key := range eq { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} diff --git a/vendor/github.com/go-xorm/builder/cond_like.go b/vendor/github.com/go-xorm/builder/cond_like.go index 9291f12c9..e34202f8b 100644 --- a/vendor/github.com/go-xorm/builder/cond_like.go +++ b/vendor/github.com/go-xorm/builder/cond_like.go @@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error { if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { return err } - // FIXME: if use other regular express, this will be failed. but for compitable, keep this + // FIXME: if use other regular express, this will be failed. but for compatible, keep this if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { w.Append(like[1]) } else { diff --git a/vendor/github.com/go-xorm/builder/cond_neq.go b/vendor/github.com/go-xorm/builder/cond_neq.go index d07b2b18e..3a8f3136d 100644 --- a/vendor/github.com/go-xorm/builder/cond_neq.go +++ b/vendor/github.com/go-xorm/builder/cond_neq.go @@ -4,7 +4,10 @@ package builder -import "fmt" +import ( + "fmt" + "sort" +) // Neq defines not equal conditions type Neq map[string]interface{} @@ -15,7 +18,8 @@ var _ Cond = Neq{} func (neq Neq) WriteTo(w Writer) error { var args = make([]interface{}, 0, len(neq)) var i = 0 - for k, v := range neq { + for _, k := range neq.sortedKeys() { + v := neq[k] switch v.(type) { case []int, []int64, []string, []int32, []int16, []int8: if err := NotIn(k, v).WriteTo(w); err != nil { @@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond { func (neq Neq) IsValid() bool { return len(neq) > 0 } + +// sortedKeys returns all keys of this Neq sorted with sort.Strings. +// It is used internally for consistent ordering when generating +// SQL, see https://github.com/go-xorm/builder/issues/10 +func (neq Neq) sortedKeys() []string { + keys := make([]string, 0, len(neq)) + for key := range neq { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} diff --git a/vendor/github.com/go-xorm/builder/cond_not.go b/vendor/github.com/go-xorm/builder/cond_not.go index 294a7e072..667dfe729 100644 --- a/vendor/github.com/go-xorm/builder/cond_not.go +++ b/vendor/github.com/go-xorm/builder/cond_not.go @@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error { if _, err := fmt.Fprint(w, "("); err != nil { return err } + case Eq: + if len(not[0].(Eq)) > 1 { + if _, err := fmt.Fprint(w, "("); err != nil { + return err + } + } + case Neq: + if len(not[0].(Neq)) > 1 { + if _, err := fmt.Fprint(w, "("); err != nil { + return err + } + } } if err := not[0].WriteTo(w); err != nil { @@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error { if _, err := fmt.Fprint(w, ")"); err != nil { return err } + case Eq: + if len(not[0].(Eq)) > 1 { + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + } + case Neq: + if len(not[0].(Neq)) > 1 { + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + } } return nil diff --git a/vendor/github.com/go-xorm/builder/strings_builder.go b/vendor/github.com/go-xorm/builder/strings_builder.go new file mode 100644 index 000000000..d4de8717e --- /dev/null +++ b/vendor/github.com/go-xorm/builder/strings_builder.go @@ -0,0 +1,119 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "unicode/utf8" + "unsafe" +) + +// A StringBuilder is used to efficiently build a string using Write methods. +// It minimizes memory copying. The zero value is ready to use. +// Do not copy a non-zero Builder. +type StringBuilder struct { + addr *StringBuilder // of receiver, to detect copies by value + buf []byte +} + +// noescape hides a pointer from escape analysis. noescape is +// the identity function but escape analysis doesn't think the +// output depends on the input. noescape is inlined and currently +// compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime; see issues 23382 and 7921. +//go:nosplit +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +func (b *StringBuilder) copyCheck() { + if b.addr == nil { + // This hack works around a failing of Go's escape analysis + // that was causing b to escape and be heap allocated. + // See issue 23382. + // TODO: once issue 7921 is fixed, this should be reverted to + // just "b.addr = b". + b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b))) + } else if b.addr != b { + panic("strings: illegal use of non-zero Builder copied by value") + } +} + +// String returns the accumulated string. +func (b *StringBuilder) String() string { + return *(*string)(unsafe.Pointer(&b.buf)) +} + +// Len returns the number of accumulated bytes; b.Len() == len(b.String()). +func (b *StringBuilder) Len() int { return len(b.buf) } + +// Reset resets the Builder to be empty. +func (b *StringBuilder) Reset() { + b.addr = nil + b.buf = nil +} + +// grow copies the buffer to a new, larger buffer so that there are at least n +// bytes of capacity beyond len(b.buf). +func (b *StringBuilder) grow(n int) { + buf := make([]byte, len(b.buf), 2*cap(b.buf)+n) + copy(buf, b.buf) + b.buf = buf +} + +// Grow grows b's capacity, if necessary, to guarantee space for +// another n bytes. After Grow(n), at least n bytes can be written to b +// without another allocation. If n is negative, Grow panics. +func (b *StringBuilder) Grow(n int) { + b.copyCheck() + if n < 0 { + panic("strings.Builder.Grow: negative count") + } + if cap(b.buf)-len(b.buf) < n { + b.grow(n) + } +} + +// Write appends the contents of p to b's buffer. +// Write always returns len(p), nil. +func (b *StringBuilder) Write(p []byte) (int, error) { + b.copyCheck() + b.buf = append(b.buf, p...) + return len(p), nil +} + +// WriteByte appends the byte c to b's buffer. +// The returned error is always nil. +func (b *StringBuilder) WriteByte(c byte) error { + b.copyCheck() + b.buf = append(b.buf, c) + return nil +} + +// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer. +// It returns the length of r and a nil error. +func (b *StringBuilder) WriteRune(r rune) (int, error) { + b.copyCheck() + if r < utf8.RuneSelf { + b.buf = append(b.buf, byte(r)) + return 1, nil + } + l := len(b.buf) + if cap(b.buf)-l < utf8.UTFMax { + b.grow(utf8.UTFMax) + } + n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r) + b.buf = b.buf[:l+n] + return n, nil +} + +// WriteString appends the contents of s to b's buffer. +// It returns the length of s and a nil error. +func (b *StringBuilder) WriteString(s string) (int, error) { + b.copyCheck() + b.buf = append(b.buf, s...) + return len(s), nil +} diff --git a/vendor/github.com/go-xorm/core/column.go b/vendor/github.com/go-xorm/core/column.go index 65370bb5b..20912b713 100644 --- a/vendor/github.com/go-xorm/core/column.go +++ b/vendor/github.com/go-xorm/core/column.go @@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { } fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + return nil, fmt.Errorf("field %v is not valid", col.FieldName) } } if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + return nil, fmt.Errorf("field %v is not valid", col.FieldName) } return &fieldValue, nil diff --git a/vendor/github.com/go-xorm/core/db.go b/vendor/github.com/go-xorm/core/db.go index 6111c4b33..9969fa431 100644 --- a/vendor/github.com/go-xorm/core/db.go +++ b/vendor/github.com/go-xorm/core/db.go @@ -7,6 +7,11 @@ import ( "fmt" "reflect" "regexp" + "sync" +) + +var ( + DefaultCacheSize = 200 ) func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { @@ -58,9 +63,16 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error) return query, args, nil } +type cacheStruct struct { + value reflect.Value + idx int +} + type DB struct { *sql.DB - Mapper IMapper + Mapper IMapper + reflectCache map[reflect.Type]*cacheStruct + reflectCacheMutex sync.RWMutex } func Open(driverName, dataSourceName string) (*DB, error) { @@ -68,11 +80,32 @@ func Open(driverName, dataSourceName string) (*DB, error) { if err != nil { return nil, err } - return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil + return &DB{ + DB: db, + Mapper: NewCacheMapper(&SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + }, nil } func FromDB(db *sql.DB) *DB { - return &DB{db, NewCacheMapper(&SnakeMapper{})} + return &DB{ + DB: db, + Mapper: NewCacheMapper(&SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + } +} + +func (db *DB) reflectNew(typ reflect.Type) reflect.Value { + db.reflectCacheMutex.Lock() + defer db.reflectCacheMutex.Unlock() + cs, ok := db.reflectCache[typ] + if !ok || cs.idx+1 > DefaultCacheSize-1 { + cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} + db.reflectCache[typ] = cs + } else { + cs.idx = cs.idx + 1 + } + return cs.value.Index(cs.idx).Addr() } func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { @@ -83,7 +116,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { } return nil, err } - return &Rows{rows, db.Mapper}, nil + return &Rows{rows, db}, nil } func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { @@ -128,8 +161,8 @@ func (db *DB) QueryRowStruct(query string, st interface{}) *Row { type Stmt struct { *sql.Stmt - Mapper IMapper - names map[string]int + db *DB + names map[string]int } func (db *DB) Prepare(query string) (*Stmt, error) { @@ -145,7 +178,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) { if err != nil { return nil, err } - return &Stmt{stmt, db.Mapper, names}, nil + return &Stmt{stmt, db, names}, nil } func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { @@ -179,7 +212,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { if err != nil { return nil, err } - return &Rows{rows, s.Mapper}, nil + return &Rows{rows, s.db}, nil } func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { @@ -274,7 +307,7 @@ func (EmptyScanner) Scan(src interface{}) error { type Tx struct { *sql.Tx - Mapper IMapper + db *DB } func (db *DB) Begin() (*Tx, error) { @@ -282,7 +315,7 @@ func (db *DB) Begin() (*Tx, error) { if err != nil { return nil, err } - return &Tx{tx, db.Mapper}, nil + return &Tx{tx, db}, nil } func (tx *Tx) Prepare(query string) (*Stmt, error) { @@ -298,7 +331,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { if err != nil { return nil, err } - return &Stmt{stmt, tx.Mapper, names}, nil + return &Stmt{stmt, tx.db, names}, nil } func (tx *Tx) Stmt(stmt *Stmt) *Stmt { @@ -327,7 +360,7 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { if err != nil { return nil, err } - return &Rows{rows, tx.Mapper}, nil + return &Rows{rows, tx.db}, nil } func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { diff --git a/vendor/github.com/go-xorm/core/dialect.go b/vendor/github.com/go-xorm/core/dialect.go index 6f2e81d01..c288a0847 100644 --- a/vendor/github.com/go-xorm/core/dialect.go +++ b/vendor/github.com/go-xorm/core/dialect.go @@ -74,6 +74,7 @@ type Dialect interface { GetIndexes(tableName string) (map[string]*Index, error) Filters() []Filter + SetParams(params map[string]string) } func OpenDialect(dialect Dialect) (*DB, error) { @@ -148,7 +149,8 @@ func (db *Base) SupportDropIfExists() bool { } func (db *Base) DropTableSql(tableName string) string { - return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) + quote := db.dialect.Quote + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) } func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { @@ -289,6 +291,9 @@ func (b *Base) LogSQL(sql string, args []interface{}) { } } +func (b *Base) SetParams(params map[string]string) { +} + var ( dialects = map[string]func() Dialect{} ) diff --git a/vendor/github.com/go-xorm/core/filter.go b/vendor/github.com/go-xorm/core/filter.go index 60caaf290..35b0ece67 100644 --- a/vendor/github.com/go-xorm/core/filter.go +++ b/vendor/github.com/go-xorm/core/filter.go @@ -37,9 +37,9 @@ func (q *Quoter) Quote(content string) string { func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { quoter := NewQuoter(dialect) if table != nil && len(table.PrimaryKeys) == 1 { - sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1) - sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1) - return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1) + sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) + sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) + return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) } return sql } diff --git a/vendor/github.com/go-xorm/core/index.go b/vendor/github.com/go-xorm/core/index.go index 73b95175a..9aa1b7ac9 100644 --- a/vendor/github.com/go-xorm/core/index.go +++ b/vendor/github.com/go-xorm/core/index.go @@ -22,6 +22,8 @@ type Index struct { func (index *Index) XName(tableName string) string { if !strings.HasPrefix(index.Name, "UQE_") && !strings.HasPrefix(index.Name, "IDX_") { + tableName = strings.Replace(tableName, `"`, "", -1) + tableName = strings.Replace(tableName, `.`, "_", -1) if index.Type == UniqueType { return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } diff --git a/vendor/github.com/go-xorm/core/rows.go b/vendor/github.com/go-xorm/core/rows.go index 4a4acaa4c..580de4f9c 100644 --- a/vendor/github.com/go-xorm/core/rows.go +++ b/vendor/github.com/go-xorm/core/rows.go @@ -9,7 +9,7 @@ import ( type Rows struct { *sql.Rows - Mapper IMapper + db *DB } func (rs *Rows) ToMapString() ([]map[string]string, error) { @@ -105,7 +105,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { newDest := make([]interface{}, len(cols)) var v EmptyScanner for j, name := range cols { - f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name)) + f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) if f.IsValid() { newDest[j] = f.Addr().Interface() } else { @@ -116,36 +116,6 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { return rs.Rows.Scan(newDest...) } -type cacheStruct struct { - value reflect.Value - idx int -} - -var ( - reflectCache = make(map[reflect.Type]*cacheStruct) - reflectCacheMutex sync.RWMutex -) - -func ReflectNew(typ reflect.Type) reflect.Value { - reflectCacheMutex.RLock() - cs, ok := reflectCache[typ] - reflectCacheMutex.RUnlock() - - const newSize = 200 - - if !ok || cs.idx+1 > newSize-1 { - cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0} - reflectCacheMutex.Lock() - reflectCache[typ] = cs - reflectCacheMutex.Unlock() - } else { - reflectCacheMutex.Lock() - cs.idx = cs.idx + 1 - reflectCacheMutex.Unlock() - } - return cs.value.Index(cs.idx).Addr() -} - // scan data to a slice's pointer, slice's length should equal to columns' number func (rs *Rows) ScanSlice(dest interface{}) error { vv := reflect.ValueOf(dest) @@ -197,9 +167,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { vvv := vv.Elem() for i, _ := range cols { - newDest[i] = ReflectNew(vvv.Type().Elem()).Interface() - //v := reflect.New(vvv.Type().Elem()) - //newDest[i] = v.Interface() + newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() } err = rs.Rows.Scan(newDest...) @@ -215,32 +183,6 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil } -/*func (rs *Rows) ScanMap(dest interface{}) error { - vv := reflect.ValueOf(dest) - if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return errors.New("dest should be a map's pointer") - } - - cols, err := rs.Columns() - if err != nil { - return err - } - - newDest := make([]interface{}, len(cols)) - err = rs.ScanSlice(newDest) - if err != nil { - return err - } - - vvv := vv.Elem() - - for i, name := range cols { - vname := reflect.ValueOf(name) - vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) - } - - return nil -}*/ type Row struct { rows *Rows // One of these two will be non-nil: diff --git a/vendor/github.com/go-xorm/core/table.go b/vendor/github.com/go-xorm/core/table.go index 88199bedd..b5d079404 100644 --- a/vendor/github.com/go-xorm/core/table.go +++ b/vendor/github.com/go-xorm/core/table.go @@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table { } func (table *Table) columnsByName(name string) []*Column { - n := len(name) for k := range table.columnsMap { @@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column { } func (table *Table) GetColumnIdx(name string, idx int) *Column { - cols := table.columnsByName(name) if cols != nil && idx < len(cols) { diff --git a/vendor/github.com/go-xorm/core/type.go b/vendor/github.com/go-xorm/core/type.go index 8010a2220..5cbf93057 100644 --- a/vendor/github.com/go-xorm/core/type.go +++ b/vendor/github.com/go-xorm/core/type.go @@ -69,15 +69,18 @@ var ( Enum = "ENUM" Set = "SET" - Char = "CHAR" - Varchar = "VARCHAR" - NVarchar = "NVARCHAR" - TinyText = "TINYTEXT" - Text = "TEXT" - Clob = "CLOB" - MediumText = "MEDIUMTEXT" - LongText = "LONGTEXT" - Uuid = "UUID" + Char = "CHAR" + Varchar = "VARCHAR" + NVarchar = "NVARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + NText = "NTEXT" + Clob = "CLOB" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" + Uuid = "UUID" + UniqueIdentifier = "UNIQUEIDENTIFIER" + SysName = "SYSNAME" Date = "DATE" DateTime = "DATETIME" @@ -128,10 +131,12 @@ var ( NVarchar: TEXT_TYPE, TinyText: TEXT_TYPE, Text: TEXT_TYPE, + NText: TEXT_TYPE, MediumText: TEXT_TYPE, LongText: TEXT_TYPE, Uuid: TEXT_TYPE, Clob: TEXT_TYPE, + SysName: TEXT_TYPE, Date: TIME_TYPE, DateTime: TIME_TYPE, @@ -148,11 +153,12 @@ var ( Binary: BLOB_TYPE, VarBinary: BLOB_TYPE, - TinyBlob: BLOB_TYPE, - Blob: BLOB_TYPE, - MediumBlob: BLOB_TYPE, - LongBlob: BLOB_TYPE, - Bytea: BLOB_TYPE, + TinyBlob: BLOB_TYPE, + Blob: BLOB_TYPE, + MediumBlob: BLOB_TYPE, + LongBlob: BLOB_TYPE, + Bytea: BLOB_TYPE, + UniqueIdentifier: BLOB_TYPE, Bool: NUMERIC_TYPE, @@ -289,9 +295,9 @@ func SQLType2Type(st SQLType) reflect.Type { return reflect.TypeOf(float32(1)) case Double: return reflect.TypeOf(float64(1)) - case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob: + case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: return reflect.TypeOf("") - case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: + case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: return reflect.TypeOf([]byte{}) case Bool: return reflect.TypeOf(true) diff --git a/vendor/github.com/go-xorm/xorm/dialect_mysql.go b/vendor/github.com/go-xorm/xorm/dialect_mysql.go index 99100b232..f2b4ff7a7 100644 --- a/vendor/github.com/go-xorm/xorm/dialect_mysql.go +++ b/vendor/github.com/go-xorm/xorm/dialect_mysql.go @@ -172,12 +172,33 @@ type mysql struct { allowAllFiles bool allowOldPasswords bool clientFoundRows bool + rowFormat string } func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(d, db, uri, drivername, dataSourceName) } +func (db *mysql) SetParams(params map[string]string) { + rowFormat, ok := params["rowFormat"] + if ok { + var t = strings.ToUpper(rowFormat) + switch t { + case "COMPACT": + fallthrough + case "REDUNDANT": + fallthrough + case "DYNAMIC": + fallthrough + case "COMPRESSED": + db.rowFormat = t + break + default: + break + } + } +} + func (db *mysql) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { @@ -487,6 +508,59 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { return indexes, nil } +func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + sql += db.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(db) + } else { + sql += col.StringNoPk(db) + } + sql = strings.TrimSpace(sql) + if len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += db.Quote(strings.Join(pkList, db.Quote(","))) + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + if storeEngine != "" { + sql += " ENGINE=" + storeEngine + } + + if len(charset) == 0 { + charset = db.URI().Charset + } else if len(charset) > 0 { + sql += " DEFAULT CHARSET " + charset + } + + if db.rowFormat != "" { + sql += " ROW_FORMAT=" + db.rowFormat + } + return sql +} + func (db *mysql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}} } diff --git a/vendor/github.com/go-xorm/xorm/dialect_postgres.go b/vendor/github.com/go-xorm/xorm/dialect_postgres.go index d6f71acc1..1f74bd312 100644 --- a/vendor/github.com/go-xorm/xorm/dialect_postgres.go +++ b/vendor/github.com/go-xorm/xorm/dialect_postgres.go @@ -769,14 +769,21 @@ var ( DefaultPostgresSchema = "public" ) +const postgresPublicSchema = "public" + type postgres struct { core.Base - schema string } func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - db.schema = DefaultPostgresSchema - return db.Base.Init(d, db, uri, drivername, dataSourceName) + err := db.Base.Init(d, db, uri, drivername, dataSourceName) + if err != nil { + return err + } + if db.Schema == "" { + db.Schema = DefaultPostgresSchema + } + return nil } func (db *postgres) SqlType(c *core.Column) string { @@ -873,32 +880,42 @@ func (db *postgres) IndexOnTable() bool { } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{tableName, idxName} + if len(db.Schema) == 0 { + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args + } + + args := []interface{}{db.Schema, tableName, idxName} return `SELECT indexname FROM pg_indexes ` + - `WHERE tablename = ? AND indexname = ?`, args + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + if len(db.Schema) == 0 { + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + } + + args := []interface{}{db.Schema, tableName} + return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } -/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { - return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + if len(db.Schema) == 0 { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) + } + return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", + db.Schema, tableName, col.Name, db.SqlType(col)) } func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - //var unique string quote := db.Quote idxName := index.Name + tableName = strings.Replace(tableName, `"`, "", -1) + tableName = strings.Replace(tableName, `.`, "_", -1) + if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { if index.Type == core.UniqueType { @@ -907,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } + if db.Uri.Schema != "" { + idxName = db.Uri.Schema + "." + idxName + } return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{tableName, colName} - query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + - " AND column_name = $2" + args := []interface{}{db.Schema, tableName, colName} + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + + " AND column_name = $3" + if len(db.Schema) == 0 { + args = []interface{}{tableName, colName} + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + + " AND column_name = $2" + } db.LogSQL(query, args) rows, err := db.DB().Query(query, args...) @@ -926,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - // FIXME: the schema should be replaced by user custom's - args := []interface{}{tableName, db.schema} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -938,7 +962,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + var f string + if len(db.Schema) != 0 { + args = append(args, db.Schema) + f = " AND s.table_schema = $2" + } + s = fmt.Sprintf(s, f) + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1028,8 +1060,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att } func (db *postgres) GetTables() ([]*core.Table, error) { - args := []interface{}{db.schema} - s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") + args := []interface{}{} + s := "SELECT tablename FROM pg_tables" + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " WHERE schemaname = $1" + } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1053,8 +1090,12 @@ func (db *postgres) GetTables() ([]*core.Table, error) { } func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { - args := []interface{}{db.schema, tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") + args := []interface{}{tableName} + s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " AND schemaname=$2" + } db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1182,3 +1223,15 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { return db, nil } + +type pqDriverPgx struct { + pqDriver +} + +func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { + // Remove the leading characters for driver to work + if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { + dataSourceName = dataSourceName[9:] + } + return pgx.pqDriver.Parse(driverName, dataSourceName) +} diff --git a/vendor/github.com/go-xorm/xorm/engine.go b/vendor/github.com/go-xorm/xorm/engine.go index 444611afb..846889c36 100644 --- a/vendor/github.com/go-xorm/xorm/engine.go +++ b/vendor/github.com/go-xorm/xorm/engine.go @@ -49,6 +49,35 @@ type Engine struct { tagHandlers map[string]tagHandler engineGroup *EngineGroup + + cachers map[string]core.Cacher + cacherLock sync.RWMutex +} + +func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { + engine.cacherLock.Lock() + engine.cachers[tableName] = cacher + engine.cacherLock.Unlock() +} + +func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { + engine.setCacher(tableName, cacher) +} + +func (engine *Engine) getCacher(tableName string) core.Cacher { + var cacher core.Cacher + var ok bool + engine.cacherLock.RLock() + cacher, ok = engine.cachers[tableName] + engine.cacherLock.RUnlock() + if !ok && !engine.disableGlobalCache { + cacher = engine.Cacher + } + return cacher +} + +func (engine *Engine) GetCacher(tableName string) core.Cacher { + return engine.getCacher(tableName) } // BufferSize sets buffer size for iterate @@ -165,7 +194,7 @@ func (engine *Engine) Quote(value string) string { } // QuoteTo quotes string and writes into the buffer -func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { +func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { if buf == nil { return } @@ -245,13 +274,7 @@ func (engine *Engine) NoCascade() *Session { // MapCacher Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { - v := rValue(bean) - tb, err := engine.autoMapType(v) - if err != nil { - return err - } - - tb.Cacher = cacher + engine.setCacher(engine.TableName(bean, true), cacher) return nil } @@ -536,33 +559,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return nil } -func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) { - v := rValue(beanOrTableName) - if v.Type().Kind() == reflect.String { - return beanOrTableName.(string), nil - } else if v.Type().Kind() == reflect.Struct { - return engine.tbName(v), nil - } - return "", errors.New("bean should be a struct or struct's point") -} - -func (engine *Engine) tbName(v reflect.Value) string { - if tb, ok := v.Interface().(TableName); ok { - return tb.TableName() - } - - if v.Type().Kind() == reflect.Ptr { - if tb, ok := reflect.Indirect(v).Interface().(TableName); ok { - return tb.TableName() - } - } else if v.CanAddr() { - if tb, ok := v.Addr().Interface().(TableName); ok { - return tb.TableName() - } - } - return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()) -} - // Cascade use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() @@ -846,7 +842,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table { if err != nil { engine.logger.Error(err) } - return &Table{tb, engine.tbName(v)} + return &Table{tb, engine.TableName(bean)} } func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { @@ -861,15 +857,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i } } -func (engine *Engine) newTable() *core.Table { - table := core.NewEmptyTable() - - if !engine.disableGlobalCache { - table.Cacher = engine.Cacher - } - return table -} - // TableName table name interface to define customerize table name type TableName interface { TableName() string @@ -881,21 +868,9 @@ var ( func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { t := v.Type() - table := engine.newTable() - if tb, ok := v.Interface().(TableName); ok { - table.Name = tb.TableName() - } else { - if v.CanAddr() { - if tb, ok = v.Addr().Interface().(TableName); ok { - table.Name = tb.TableName() - } - } - if table.Name == "" { - table.Name = engine.TableMapper.Obj2Table(t.Name()) - } - } - + table := core.NewEmptyTable() table.Type = t + table.Name = engine.tbNameForMap(v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -1049,15 +1024,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { if hasCacheTag { if engine.Cacher != nil { // !nash! use engine's cacher if provided engine.logger.Info("enable cache on table:", table.Name) - table.Cacher = engine.Cacher + engine.setCacher(table.Name, engine.Cacher) } else { engine.logger.Info("enable LRU cache on table:", table.Name) - table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now + engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) } } if hasNoCacheTag { - engine.logger.Info("no cache on table:", table.Name) - table.Cacher = nil + engine.logger.Info("disable cache on table:", table.Name) + engine.setCacher(table.Name, nil) } return table, nil @@ -1116,7 +1091,25 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { pk := make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { var err error - pkField := v.FieldByName(col.FieldName) + + fieldName := col.FieldName + for { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) == 1 { + break + } + + v = v.FieldByName(parts[0]) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil, ErrUnSupportedType + } + fieldName = parts[1] + } + + pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = engine.idTypeAssertion(col, pkField.String()) @@ -1162,26 +1155,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error { return session.CreateUniques(bean) } -func (engine *Engine) getCacher2(table *core.Table) core.Cacher { - return table.Cacher -} - // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - v := rValue(bean) - t := v.Type() - if t.Kind() != reflect.Struct { - return errors.New("error params") - } - tableName := engine.tbName(v) - table, err := engine.autoMapType(v) - if err != nil { - return err - } - cacher := table.Cacher - if cacher == nil { - cacher = engine.Cacher - } + tableName := engine.TableName(bean) + cacher := engine.getCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.DelBean(tableName, id) @@ -1192,21 +1169,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - v := rValue(bean) - t := v.Type() - if t.Kind() != reflect.Struct { - return errors.New("error params") - } - tableName := engine.tbName(v) - table, err := engine.autoMapType(v) - if err != nil { - return err - } - - cacher := table.Cacher - if cacher == nil { - cacher = engine.Cacher - } + tableName := engine.TableName(bean) + cacher := engine.getCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.ClearBeans(tableName) @@ -1224,13 +1188,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - tableName := engine.tbName(v) + tableNameNoSchema := engine.TableName(bean) table, err := engine.autoMapType(v) if err != nil { return err } - isExist, err := session.Table(bean).isTableExist(tableName) + isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) if err != nil { return err } @@ -1256,12 +1220,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) + isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -1272,35 +1236,35 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } if index.Type == core.UniqueType { - isExist, err := session.isIndexExist2(tableName, index.Cols, true) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } - err = session.addUnique(tableName, name) + err = session.addUnique(tableNameNoSchema, name) if err != nil { return err } } } else if index.Type == core.IndexType { - isExist, err := session.isIndexExist2(tableName, index.Cols, false) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) if err != nil { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } - err = session.addIndex(tableName, name) + err = session.addIndex(tableNameNoSchema, name) if err != nil { return err } @@ -1453,6 +1417,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { return session.Find(beans, condiBeans...) } +// FindAndCount find the results and also return the counts +func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.FindAndCount(rowsSlicePtr, condiBean...) +} + // Iterate record by record handle records from table, bean's non-empty fields // are conditions. func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { @@ -1629,6 +1600,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { engine.DatabaseTZ = tz } +// SetSchema sets the schema of database +func (engine *Engine) SetSchema(schema string) { + engine.dialect.URI().Schema = schema +} + // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() diff --git a/vendor/github.com/go-xorm/xorm/engine_cond.go b/vendor/github.com/go-xorm/xorm/engine_cond.go index 6c8e3879c..4dde8662e 100644 --- a/vendor/github.com/go-xorm/xorm/engine_cond.go +++ b/vendor/github.com/go-xorm/xorm/engine_cond.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "time" "github.com/go-xorm/builder" @@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, fieldValuePtr, err := col.ValueOf(bean) if err != nil { - engine.logger.Error(err) + if !strings.Contains(err.Error(), "is not valid") { + engine.logger.Warn(err) + } continue } diff --git a/vendor/github.com/go-xorm/xorm/engine_table.go b/vendor/github.com/go-xorm/xorm/engine_table.go new file mode 100644 index 000000000..94871a4bc --- /dev/null +++ b/vendor/github.com/go-xorm/xorm/engine_table.go @@ -0,0 +1,113 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/go-xorm/core" +) + +// TableNameWithSchema will automatically add schema prefix on table name +func (engine *Engine) tbNameWithSchema(v string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if engine.dialect.DBType() == core.POSTGRES && + engine.dialect.URI().Schema != "" && + engine.dialect.URI().Schema != postgresPublicSchema && + strings.Index(v, ".") == -1 { + return engine.dialect.URI().Schema + "." + v + } + return v +} + +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + tbName := engine.tbNameNoSchema(bean) + if len(includeSchema) > 0 && includeSchema[0] { + tbName = engine.tbNameWithSchema(tbName) + } + + return tbName +} + +// tbName get some table's table name +func (session *Session) tbNameNoSchema(table *core.Table) string { + if len(session.statement.AltTableName) > 0 { + return session.statement.AltTableName + } + + return table.Name +} + +func (engine *Engine) tbNameForMap(v reflect.Value) string { + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() + } + } + + return engine.TableMapper.Obj2Table(v.Type().Name()) +} + +func (engine *Engine) tbNameNoSchema(tablename interface{}) string { + switch tablename.(type) { + case []string: + t := tablename.([]string) + if len(t) > 1 { + return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + } else if len(t) == 1 { + return engine.Quote(t[0]) + } + case []interface{}: + t := tablename.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case TableName: + table = f.(TableName).TableName() + default: + v := rValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + table = engine.tbNameForMap(v) + } else { + table = engine.Quote(fmt.Sprintf("%v", f)) + } + } + } + if l > 1 { + return fmt.Sprintf("%v AS %v", engine.Quote(table), + engine.Quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + return engine.Quote(table) + } + case TableName: + return tablename.(TableName).TableName() + case string: + return tablename.(string) + case reflect.Value: + v := tablename.(reflect.Value) + return engine.tbNameForMap(v) + default: + v := rValue(tablename) + t := v.Type() + if t.Kind() == reflect.Struct { + return engine.tbNameForMap(v) + } + return engine.Quote(fmt.Sprintf("%v", tablename)) + } + return "" +} diff --git a/vendor/github.com/go-xorm/xorm/error.go b/vendor/github.com/go-xorm/xorm/error.go index cfeefc31e..a223fc4a8 100644 --- a/vendor/github.com/go-xorm/xorm/error.go +++ b/vendor/github.com/go-xorm/xorm/error.go @@ -6,23 +6,44 @@ package xorm import ( "errors" + "fmt" ) var ( // ErrParamsType params error ErrParamsType = errors.New("Params type error") // ErrTableNotFound table not found error - ErrTableNotFound = errors.New("Not found table") + ErrTableNotFound = errors.New("Table not found") // ErrUnSupportedType unsupported error ErrUnSupportedType = errors.New("Unsupported type error") - // ErrNotExist record is not exist error - ErrNotExist = errors.New("Not exist error") + // ErrNotExist record does not exist error + ErrNotExist = errors.New("Record does not exist") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete need at least one condition") + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported - ErrConditionType = errors.New("Unsupported conditon type") + ErrConditionType = errors.New("Unsupported condition type") ) + +// ErrFieldIsNotExist columns does not exist +type ErrFieldIsNotExist struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotExist) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} + +// ErrFieldIsNotValid is not valid +type ErrFieldIsNotValid struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotValid) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} diff --git a/vendor/github.com/go-xorm/xorm/helpers.go b/vendor/github.com/go-xorm/xorm/helpers.go index f39ed4725..f1705782e 100644 --- a/vendor/github.com/go-xorm/xorm/helpers.go +++ b/vendor/github.com/go-xorm/xorm/helpers.go @@ -11,7 +11,6 @@ import ( "sort" "strconv" "strings" - "time" "github.com/go-xorm/core" ) @@ -293,19 +292,6 @@ func structName(v reflect.Type) string { return v.Name() } -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - func sliceEq(left, right []string) bool { if len(left) != len(right) { return false @@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool { return true } -func setColumnInt(bean interface{}, col *core.Column, t int64) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t)) - } - } -} - -func setColumnTime(bean interface{}, col *core.Column, t time.Time) { - v, err := col.ValueOf(bean) - if err != nil { - return - } - if v.CanSet() { - switch v.Type().Kind() { - case reflect.Struct: - v.Set(reflect.ValueOf(t).Convert(v.Type())) - case reflect.Int, reflect.Int64, reflect.Int32: - v.SetInt(t.Unix()) - case reflect.Uint, reflect.Uint64, reflect.Uint32: - v.SetUint(uint64(t.Unix())) - } - } -} - -func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { - colNames := make([]string, 0, len(table.ColumnsSeq())) - args := make([]interface{}, 0, len(table.ColumnsSeq())) - - for _, col := range table.Columns() { - if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } - } - if col.MapType == core.ONLYFROMDB { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - return nil, nil, err - } - fieldValue := *fieldValuePtr - - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } - } - - if col.IsDeleted { - continue - } - - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } else if _, ok := session.statement.incrColumns[col.Name]; ok { - continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { - continue - } - } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } - } - - // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - } - } - - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { - // if time is non-empty, then set to auto time - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return colNames, args, err - } - args = append(args, arg) - } - - if includeQuote { - colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") - } else { - colNames = append(colNames, col.Name) - } - } - return colNames, args, nil -} - func indexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } - -func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { - if len(m) == 0 { - return false, false - } - - n := len(col.Name) - - for mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, col.Name) { - return m[mk], true - } - } - - return false, false -} diff --git a/vendor/github.com/go-xorm/xorm/interface.go b/vendor/github.com/go-xorm/xorm/interface.go index 9a3b6da0b..0bc12ba00 100644 --- a/vendor/github.com/go-xorm/xorm/interface.go +++ b/vendor/github.com/go-xorm/xorm/interface.go @@ -30,6 +30,7 @@ type Interface interface { Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error + FindAndCount(interface{}, ...interface{}) (int64, error) Get(interface{}) (bool, error) GroupBy(keys string) *Session ID(interface{}) *Session @@ -41,6 +42,7 @@ type Interface interface { IsTableExist(beanOrTableName interface{}) (bool, error) Iterate(interface{}, IterFunc) error Limit(int, ...int) *Session + MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session @@ -75,6 +77,7 @@ type EngineInterface interface { Dialect() core.Dialect DropTables(...interface{}) error DumpAllToFile(fp string, tp ...core.DbType) error + GetCacher(string) core.Cacher GetColumnMapper() core.IMapper GetDefaultCacher() core.Cacher GetTableMapper() core.IMapper @@ -83,9 +86,11 @@ type EngineInterface interface { NewSession() *Session NoAutoTime() *Session Quote(string) string + SetCacher(string, core.Cacher) SetDefaultCacher(core.Cacher) SetLogLevel(core.LogLevel) SetMapper(core.IMapper) + SetSchema(string) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) ShowSQL(show ...bool) @@ -93,6 +98,7 @@ type EngineInterface interface { Sync2(...interface{}) error StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) *Table + TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/vendor/github.com/go-xorm/xorm/rows.go b/vendor/github.com/go-xorm/xorm/rows.go index 31e29ae26..54ec7f37a 100644 --- a/vendor/github.com/go-xorm/xorm/rows.go +++ b/vendor/github.com/go-xorm/xorm/rows.go @@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { + if err = rows.session.statement.setRefBean(bean); err != nil { return nil, err } @@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - dataStruct := rValue(bean) - if err := rows.session.statement.setRefValue(dataStruct); err != nil { + if err := rows.session.statement.setRefBean(bean); err != nil { return err } @@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } + dataStruct := rValue(bean) _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err diff --git a/vendor/github.com/go-xorm/xorm/session.go b/vendor/github.com/go-xorm/xorm/session.go index 5c6cb5f9d..3775eb011 100644 --- a/vendor/github.com/go-xorm/xorm/session.go +++ b/vendor/github.com/go-xorm/xorm/session.go @@ -278,24 +278,22 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { +func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { var col *core.Column if col = table.GetColumnIdx(key, idx); col == nil { - //session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) - return nil + return nil, ErrFieldIsNotExist{key, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - session.engine.logger.Error(err) - return nil + return nil, err } if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) - return nil + return nil, ErrFieldIsNotValid{key, table.Name} } - return fieldValue + + return fieldValue, nil } // Cell cell is a result of one column field @@ -407,409 +405,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } tempMap[lKey] = idx - if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil { - rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - - // if row is null then ignore - if rawValue.Interface() == nil { - continue + fieldValue, err := session.getField(dataStruct, key, table, idx) + if err != nil { + if !strings.Contains(err.Error(), "is not valid") { + session.engine.logger.Warn(err) } + continue + } + if fieldValue == nil { + continue + } + rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if err := structConvert.FromDB(data); err != nil { - return nil, err - } - } else { + // if row is null then ignore + if rawValue.Interface() == nil { + continue + } + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + if err := structConvert.FromDB(data); err != nil { return nil, err } - continue - } - } - - if _, ok := fieldValue.Interface().(core.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue.Interface().(core.Conversion).FromDB(data) } else { return nil, err } continue } + } - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - col := table.GetColumnIdx(key, idx) - if col.IsPrimaryKey { - pk = append(pk, rawValue.Interface()) + if _, ok := fieldValue.Interface().(core.Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue.Interface().(core.Conversion).FromDB(data) + } else { + return nil, err } - fieldType := fieldValue.Type() - hasAssigned := false + continue + } - if col.SQLType.IsJson() { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { - bs = vv.Bytes() + rawValueType := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + col := table.GetColumnIdx(key, idx) + if col.IsPrimaryKey { + pk = append(pk, rawValue.Interface()) + } + fieldType := fieldValue.Type() + hasAssigned := false + + if col.SQLType.IsJson() { + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(core.BytesType) { + bs = vv.Bytes() + } else { + return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) + } + + hasAssigned = true + + if len(bs) > 0 { + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + continue + } + if fieldValue.CanAddr() { + err := json.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return nil, err + } } else { - return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) - } - - hasAssigned = true - - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - continue - } - if fieldValue.CanAddr() { - err := json.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) + x := reflect.New(fieldType) + err := json.Unmarshal(bs, x.Interface()) + if err != nil { + return nil, err } + fieldValue.Set(x.Elem()) } - - continue } - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { - bs = vv.Bytes() - } + continue + } - hasAssigned = true - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) + switch fieldType.Kind() { + case reflect.Complex64, reflect.Complex128: + // TODO: reimplement this + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(core.BytesType) { + bs = vv.Bytes() + } + + hasAssigned = true + if len(bs) > 0 { + if fieldValue.CanAddr() { + err := json.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return nil, err } + } else { + x := reflect.New(fieldType) + err := json.Unmarshal(bs, x.Interface()) + if err != nil { + return nil, err + } + fieldValue.Set(x.Elem()) } + } + case reflect.Slice, reflect.Array: + switch rawValueType.Kind() { case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } - } - } - } - } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - hasAssigned = true - fieldValue.SetString(vv.String()) - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - hasAssigned = true - fieldValue.SetBool(vv.Bool()) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetInt(vv.Int()) - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - hasAssigned = true - fieldValue.SetFloat(vv.Float()) - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - hasAssigned = true - fieldValue.SetUint(vv.Uint()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetUint(uint64(vv.Int())) - } - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - dbTZ := session.engine.DatabaseTZ - if col.TimeZone != nil { - dbTZ = col.TimeZone - } - - if rawValueType == core.TimeType { + switch rawValueType.Elem().Kind() { + case reflect.Uint8: + if fieldType.Elem().Kind() == reflect.Uint8 { hasAssigned = true - - t := vv.Convert(core.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == core.IntType || rawValueType == core.Int64Type || - rawValueType == core.Int32Type { - hasAssigned = true - - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else { - if d, ok := vv.Interface().([]uint8); ok { - hasAssigned = true - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else if d, ok := vv.Interface().(string); ok { - hasAssigned = true - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else { - return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } - } - } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - hasAssigned = true - if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) - hasAssigned = false - } - } else if col.SQLType.IsJson() { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) err := json.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } fieldValue.Set(x.Elem()) - } - } - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return nil, err - } - - hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(core.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) - if err != nil { - return nil, err - } - - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - fieldValue.Set(structInter.Elem()) } else { - return nil, errors.New("cascade obj is not exist") + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) + } + } + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } + } } } } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case core.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - } // switch fieldType - } // switch fieldType.Kind() + } + case reflect.String: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + fieldValue.SetString(vv.String()) + } + case reflect.Bool: + if rawValueType.Kind() == reflect.Bool { + hasAssigned = true + fieldValue.SetBool(vv.Bool()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rawValueType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetInt(vv.Int()) + } + case reflect.Float32, reflect.Float64: + switch rawValueType.Kind() { + case reflect.Float32, reflect.Float64: + hasAssigned = true + fieldValue.SetFloat(vv.Float()) + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch rawValueType.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + hasAssigned = true + fieldValue.SetUint(vv.Uint()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetUint(uint64(vv.Int())) + } + case reflect.Struct: + if fieldType.ConvertibleTo(core.TimeType) { + dbTZ := session.engine.DatabaseTZ + if col.TimeZone != nil { + dbTZ = col.TimeZone + } - // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value - if !hasAssigned { - data, err := value2Bytes(&rawValue) + if rawValueType == core.TimeType { + hasAssigned = true + + t := vv.Convert(core.TimeType).Interface().(time.Time) + + z, _ := t.Zone() + // set new location if database don't save timezone or give an incorrect timezone + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location + session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) + t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbTZ) + } + + t = t.In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } else if rawValueType == core.IntType || rawValueType == core.Int64Type || + rawValueType == core.Int32Type { + hasAssigned = true + + t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } else { + if d, ok := vv.Interface().([]uint8); ok { + hasAssigned = true + t, err := session.byte2Time(col, d) + if err != nil { + session.engine.logger.Error("byte2Time error:", err.Error()) + hasAssigned = false + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } + } else if d, ok := vv.Interface().(string); ok { + hasAssigned = true + t, err := session.str2Time(col, d) + if err != nil { + session.engine.logger.Error("byte2Time error:", err.Error()) + hasAssigned = false + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } + } else { + return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) + } + } + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + hasAssigned = true + if err := nulVal.Scan(vv.Interface()); err != nil { + session.engine.logger.Error("sql.Sanner error:", err.Error()) + hasAssigned = false + } + } else if col.SQLType.IsJson() { + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + return nil, err + } + fieldValue.Set(x.Elem()) + } + } else if rawValueType.Kind() == reflect.Slice { + hasAssigned = true + x := reflect.New(fieldType) + if len(vv.Bytes()) > 0 { + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + return nil, err + } + fieldValue.Set(x.Elem()) + } + } + } else if session.statement.UseCascade { + table, err := session.engine.autoMapType(*fieldValue) if err != nil { return nil, err } - if err = session.bytes2Value(col, fieldValue, data); err != nil { + hasAssigned = true + if len(table.PrimaryKeys) != 1 { + return nil, errors.New("unsupported non or composited primary key cascade") + } + var pk = make(core.PK, len(table.PrimaryKeys)) + pk[0], err = asKind(vv, rawValueType) + if err != nil { return nil, err } + + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return nil, err + } + if has { + fieldValue.Set(structInter.Elem()) + } else { + return nil, errors.New("cascade obj is not exist") + } + } + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case core.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrTimeType: + if rawValueType == core.PtrTimeType { + hasAssigned = true + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Uint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Uint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.Complex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return nil, err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + case core.Complex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return nil, err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + } // switch fieldType + } // switch fieldType.Kind() + + // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value + if !hasAssigned { + data, err := value2Bytes(&rawValue) + if err != nil { + return nil, err + } + + if err = session.bytes2Value(col, fieldValue, data); err != nil { + return nil, err } } } @@ -828,15 +834,6 @@ func (session *Session) LastSQL() (string, []interface{}) { return session.lastSQL, session.lastSQLArgs } -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { session.statement.Unscoped() diff --git a/vendor/github.com/go-xorm/xorm/session_cols.go b/vendor/github.com/go-xorm/xorm/session_cols.go index 9972cb0ae..47d109c6c 100644 --- a/vendor/github.com/go-xorm/xorm/session_cols.go +++ b/vendor/github.com/go-xorm/xorm/session_cols.go @@ -4,6 +4,121 @@ package xorm +import ( + "reflect" + "strings" + "time" + + "github.com/go-xorm/core" +) + +type incrParam struct { + colName string + arg interface{} +} + +type decrParam struct { + colName string + arg interface{} +} + +type exprParam struct { + colName string + expr string +} + +type columnMap []string + +func (m columnMap) contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func (m *columnMap) add(colName string) bool { + if m.contain(colName) { + return false + } + *m = append(*m, colName) + return true +} + +func setColumnInt(bean interface{}, col *core.Column, t int64) { + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t)) + } + } +} + +func setColumnTime(bean interface{}, col *core.Column, t time.Time) { + v, err := col.ValueOf(bean) + if err != nil { + return + } + if v.CanSet() { + switch v.Type().Kind() { + case reflect.Struct: + v.Set(reflect.ValueOf(t).Convert(v.Type())) + case reflect.Int, reflect.Int64, reflect.Int32: + v.SetInt(t.Unix()) + case reflect.Uint, reflect.Uint64, reflect.Uint32: + v.SetUint(uint64(t.Unix())) + } + } +} + +func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { + if len(m) == 0 { + return false, false + } + + n := len(col.Name) + + for mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, col.Name) { + return m[mk], true + } + } + + return false, false +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + // Incr provides a query string like "count = count + 1" func (session *Session) Incr(column string, arg ...interface{}) *Session { session.statement.Incr(column, arg...) diff --git a/vendor/github.com/go-xorm/xorm/session_delete.go b/vendor/github.com/go-xorm/xorm/session_delete.go index 688b122ca..d9cf3ea93 100644 --- a/vendor/github.com/go-xorm/xorm/session_delete.go +++ b/vendor/github.com/go-xorm/xorm/session_delete.go @@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, return ErrCacheFailed } - cacher := session.engine.getCacher2(table) + cacher := session.engine.getCacher(tableName) pkColumns := table.PKColumns() ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { @@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } @@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) } diff --git a/vendor/github.com/go-xorm/xorm/session_exist.go b/vendor/github.com/go-xorm/xorm/session_exist.go index 378a64837..74a660e85 100644 --- a/vendor/github.com/go-xorm/xorm/session_exist.go +++ b/vendor/github.com/go-xorm/xorm/session_exist.go @@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean[0]); err != nil { return false, err } } diff --git a/vendor/github.com/go-xorm/xorm/session_find.go b/vendor/github.com/go-xorm/xorm/session_find.go index f95dcfef2..46bbf26c9 100644 --- a/vendor/github.com/go-xorm/xorm/session_find.go +++ b/vendor/github.com/go-xorm/xorm/session_find.go @@ -29,6 +29,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return session.find(rowsSlicePtr, condiBean...) } +// FindAndCount find the results and also return the counts +func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + if session.isAutoClose { + defer session.Close() + } + + session.autoResetStatement = false + err := session.find(rowsSlicePtr, condiBean...) + if err != nil { + return 0, err + } + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + return 0, errors.New("needs a pointer to a slice or a map") + } + + sliceElementType := sliceValue.Type().Elem() + if sliceElementType.Kind() == reflect.Ptr { + sliceElementType = sliceElementType.Elem() + } + session.autoResetStatement = true + + if session.statement.selectStr != "" { + session.statement.selectStr = "" + } + if session.statement.OrderStr != "" { + session.statement.OrderStr = "" + } + + return session.Count(reflect.New(sliceElementType).Interface()) +} + func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { @@ -42,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv); err != nil { return err } } else { @@ -50,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv); err != nil { return err } } else { @@ -128,7 +161,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) if err != nil { return err } @@ -143,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } if session.canCache() { - if cacher := session.engine.getCacher2(table); cacher != nil && + if cacher := session.engine.getCacher(table.Name); cacher != nil && !session.statement.IsDistinct && !session.statement.unscoped { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) @@ -288,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return ErrCacheFailed } + tableName := session.statement.TableName() + cacher := session.engine.getCacher(tableName) + if cacher == nil { + return nil + } + for _, filter := range session.engine.dialect.Filters() { sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) } @@ -297,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return ErrCacheFailed } - tableName := session.statement.TableName() table := session.statement.RefTable - cacher := session.engine.getCacher2(table) ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { rows, err := session.queryRows(newsql, args...) diff --git a/vendor/github.com/go-xorm/xorm/session_get.go b/vendor/github.com/go-xorm/xorm/session_get.go index 68b37af7f..3b2c9493c 100644 --- a/vendor/github.com/go-xorm/xorm/session_get.go +++ b/vendor/github.com/go-xorm/xorm/session_get.go @@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return false, err } } @@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.engine.getCacher2(table); cacher != nil && + if cacher := session.engine.getCacher(table.Name); cacher != nil && !session.statement.unscoped { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { @@ -134,8 +134,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, ErrCacheFailed } - cacher := session.engine.getCacher2(session.statement.RefTable) tableName := session.statement.TableName() + cacher := session.engine.getCacher(tableName) + session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) table := session.statement.RefTable ids, err := core.GetCacheSql(cacher, tableName, newsql, args) diff --git a/vendor/github.com/go-xorm/xorm/session_insert.go b/vendor/github.com/go-xorm/xorm/session_insert.go index 129ee2309..2ea58fdaf 100644 --- a/vendor/github.com/go-xorm/xorm/session_insert.go +++ b/vendor/github.com/go-xorm/xorm/session_insert.go @@ -66,11 +66,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { + if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + tableName := session.statement.TableName() + if len(tableName) <= 0 { return 0, ErrTableNotFound } @@ -115,15 +116,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } + if session.statement.omitColumnMap.contain(col.Name) { + continue } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { val, t := session.engine.nowTime(col) @@ -170,15 +167,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { - continue - } + if session.statement.omitColumnMap.contain(col.Name) { + continue } - if session.statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { - continue - } + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { val, t := session.engine.nowTime(col) @@ -211,38 +204,33 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } cleanupProcessorsClosures(&session.beforeClosures) - var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" - var statement string - var tableName = session.statement.TableName() + var sql string if session.engine.dialect.DBType() == core.ORACLE { - sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr()) - statement = fmt.Sprintf(sql, + sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr(), strings.Join(colMultiPlaces, temp)) } else { - statement = fmt.Sprintf(sql, + sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr(), strings.Join(colMultiPlaces, "),(")) } - res, err := session.exec(statement, args...) + res, err := session.exec(sql, args...) if err != nil { return 0, err } - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) lenAfterClosures := len(session.afterClosures) for i := 0; i < size; i++ { @@ -298,7 +286,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { @@ -316,8 +304,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { processor.BeforeInsert() } - // -- - colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) + + colNames, args, err := session.genInsertColumns(bean) if err != nil { return 0, err } @@ -402,9 +390,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -447,9 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -490,9 +474,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - session.cacheInsert(table, tableName) - } + session.cacheInsert(tableName) if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -539,16 +521,104 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (session *Session) cacheInsert(table *core.Table, tables ...string) error { - if table == nil { - return ErrCacheFailed +func (session *Session) cacheInsert(table string) error { + if !session.statement.UseCache { + return nil } - - cacher := session.engine.getCacher2(table) - for _, t := range tables { - session.engine.logger.Debug("[cache] clear sql:", t) - cacher.ClearIds(t) + cacher := session.engine.getCacher(table) + if cacher == nil { + return nil } - + session.engine.logger.Debug("[cache] clear sql:", table) + cacher.ClearIds(table) return nil } + +// genInsertColumns generates insert needed columns +func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { + table := session.statement.RefTable + colNames := make([]string, 0, len(table.ColumnsSeq())) + args := make([]interface{}, 0, len(table.ColumnsSeq())) + + for _, col := range table.Columns() { + if col.MapType == core.ONLYFROMDB { + continue + } + + if col.IsDeleted { + continue + } + + if session.statement.omitColumnMap.contain(col.Name) { + continue + } + + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue + } + + if _, ok := session.statement.incrColumns[col.Name]; ok { + continue + } else if _, ok := session.statement.decrColumns[col.Name]; ok { + continue + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + return nil, nil, err + } + fieldValue := *fieldValuePtr + + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + case reflect.Ptr: + if fieldValue.Pointer() == 0 { + continue + } + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { + // if time is non-empty, then set to auto time + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.checkVersion { + args = append(args, 1) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } + + colNames = append(colNames, col.Name) + } + return colNames, args, nil +} diff --git a/vendor/github.com/go-xorm/xorm/session_query.go b/vendor/github.com/go-xorm/xorm/session_query.go index f8098f849..5c9aeb391 100644 --- a/vendor/github.com/go-xorm/xorm/session_query.go +++ b/vendor/github.com/go-xorm/xorm/session_query.go @@ -64,13 +64,17 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa } } + if err := session.statement.processIDParam(); err != nil { + return "", nil, err + } + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) if err != nil { return "", nil, err } args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) if err != nil { return "", nil, err } diff --git a/vendor/github.com/go-xorm/xorm/session_schema.go b/vendor/github.com/go-xorm/xorm/session_schema.go index a2708b736..369ec72a4 100644 --- a/vendor/github.com/go-xorm/xorm/session_schema.go +++ b/vendor/github.com/go-xorm/xorm/session_schema.go @@ -6,9 +6,7 @@ package xorm import ( "database/sql" - "errors" "fmt" - "reflect" "strings" "github.com/go-xorm/core" @@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return err - } - + tableName := session.engine.TableName(beanOrTableName) var needDrop = true if !session.engine.dialect.SupportDropIfExists() { sqlStr, args := session.engine.dialect.TableCheckSql(tableName) @@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(tableName) - _, err = session.exec(sqlStr) + sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) + _, err := session.exec(sqlStr) return err } return nil @@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) defer session.Close() } - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return false, err - } + tableName := session.engine.TableName(beanOrTableName) return session.isTableExist(tableName) } @@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) { // IsTableEmpty if table have any records func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { - v := rValue(bean) - t := v.Type() - - if t.Kind() == reflect.String { - if session.isAutoClose { - defer session.Close() - } - return session.isTableEmpty(bean.(string)) - } else if t.Kind() == reflect.Struct { - rows, err := session.Count(bean) - return rows == 0, err + if session.isAutoClose { + defer session.Close() } - return false, errors.New("bean should be a struct or struct's point") + return session.isTableEmpty(session.engine.TableName(bean)) } func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { @@ -255,6 +233,12 @@ func (session *Session) Sync2(beans ...interface{}) error { return err } + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + var structTables []*core.Table for _, bean := range beans { @@ -264,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error { return err } structTables = append(structTables, table) - var tbName = session.tbNameNoSchema(table) + tbName := engine.TableName(bean) + tbNameWithSchema := engine.TableName(tbName, true) var oriTable *core.Table for _, tb := range tables { @@ -309,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error { if engine.dialect.DBType() == core.MYSQL || engine.dialect.DBType() == core.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", - tbName, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, curType, expectedType) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } } else if expectedType == core.Varchar { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } @@ -348,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } } else { session.statement.RefTable = table - session.statement.tableName = tbName + session.statement.tableName = tbNameWithSchema err = session.addColumn(col.Name) } if err != nil { @@ -371,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSql(tbName, oriIndex) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err @@ -387,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSql(tbName, index2) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err @@ -398,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addUnique(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addUnique(tbNameWithSchema, name) } else if index.Type == core.IndexType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addIndex(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addIndex(tbNameWithSchema, name) } if err != nil { return err @@ -428,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, colName := range table.ColumnsSeq() { if oriTable.GetColumn(colName) == nil { - engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) } } } diff --git a/vendor/github.com/go-xorm/xorm/session_tx.go b/vendor/github.com/go-xorm/xorm/session_tx.go index 84d2f7f9d..c8d759a31 100644 --- a/vendor/github.com/go-xorm/xorm/session_tx.go +++ b/vendor/github.com/go-xorm/xorm/session_tx.go @@ -24,6 +24,7 @@ func (session *Session) Rollback() error { if !session.isAutoCommit && !session.isCommitedOrRollbacked { session.saveLastSQL(session.engine.dialect.RollBackStr()) session.isCommitedOrRollbacked = true + session.isAutoCommit = true return session.tx.Rollback() } return nil @@ -34,6 +35,7 @@ func (session *Session) Commit() error { if !session.isAutoCommit && !session.isCommitedOrRollbacked { session.saveLastSQL("COMMIT") session.isCommitedOrRollbacked = true + session.isAutoCommit = true var err error if err = session.tx.Commit(); err == nil { // handle processors after tx committed diff --git a/vendor/github.com/go-xorm/xorm/session_update.go b/vendor/github.com/go-xorm/xorm/session_update.go index f55874566..84c7e7fec 100644 --- a/vendor/github.com/go-xorm/xorm/session_update.go +++ b/vendor/github.com/go-xorm/xorm/session_update.go @@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } } - cacher := session.engine.getCacher2(table) + cacher := session.engine.getCacher(tableName) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { @@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } @@ -176,12 +176,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if session.statement.ColumnStr == "" { - colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, - false, false, session.statement.allUseBool, session.statement.useAllCols, - session.statement.mustColumnMap, session.statement.nullableMap, - session.statement.columnMap, true, session.statement.unscoped) + colNames, args = session.statement.buildUpdates(bean, false, false, + false, false, true) } else { - colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) + colNames, args, err = session.genUpdateColumns(bean) if err != nil { return 0, err } @@ -202,7 +200,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { + if !session.statement.columnMap.contain(table.Updated) && + !session.statement.omitColumnMap.contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -362,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if table != nil { - if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { - //session.cacheUpdate(table, tableName, sqlStr, args...) - cacher.ClearIds(tableName) - cacher.ClearBeans(tableName) - } + if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { + //session.cacheUpdate(table, tableName, sqlStr, args...) + session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + cacher.ClearIds(tableName) + cacher.ClearBeans(tableName) } // handle after update processors @@ -402,3 +400,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } + +func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) { + table := session.statement.RefTable + colNames := make([]string, 0, len(table.ColumnsSeq())) + args := make([]interface{}, 0, len(table.ColumnsSeq())) + + for _, col := range table.Columns() { + if !col.IsVersion && !col.IsCreated && !col.IsUpdated { + if session.statement.omitColumnMap.contain(col.Name) { + continue + } + } + if col.MapType == core.ONLYFROMDB { + continue + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + return nil, nil, err + } + fieldValue := *fieldValuePtr + + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + case reflect.Ptr: + if fieldValue.Pointer() == 0 { + continue + } + } + } + + if col.IsDeleted || col.IsCreated { + continue + } + + if len(session.statement.columnMap) > 0 { + if !session.statement.columnMap.contain(col.Name) { + continue + } else if _, ok := session.statement.incrColumns[col.Name]; ok { + continue + } else if _, ok := session.statement.decrColumns[col.Name]; ok { + continue + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } + + if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { + // if time is non-empty, then set to auto time + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.checkVersion { + args = append(args, 1) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } + + colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") + } + return colNames, args, nil +} diff --git a/vendor/github.com/go-xorm/xorm/statement.go b/vendor/github.com/go-xorm/xorm/statement.go index 6400425b2..7856936f5 100644 --- a/vendor/github.com/go-xorm/xorm/statement.go +++ b/vendor/github.com/go-xorm/xorm/statement.go @@ -5,7 +5,6 @@ package xorm import ( - "bytes" "database/sql/driver" "encoding/json" "errors" @@ -18,21 +17,6 @@ import ( "github.com/go-xorm/core" ) -type incrParam struct { - colName string - arg interface{} -} - -type decrParam struct { - colName string - arg interface{} -} - -type exprParam struct { - colName string - expr string -} - // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -47,7 +31,6 @@ type Statement struct { HavingStr string ColumnStr string selectStr string - columnMap map[string]bool useAllCols bool OmitStr string AltTableName string @@ -67,6 +50,8 @@ type Statement struct { allUseBool bool checkVersion bool unscoped bool + columnMap columnMap + omitColumnMap columnMap mustColumnMap map[string]bool nullableMap map[string]bool incrColumns map[string]incrParam @@ -89,7 +74,8 @@ func (statement *Statement) Init() { statement.HavingStr = "" statement.ColumnStr = "" statement.OmitStr = "" - statement.columnMap = make(map[string]bool) + statement.columnMap = columnMap{} + statement.omitColumnMap = columnMap{} statement.AltTableName = "" statement.tableName = "" statement.idParam = nil @@ -221,34 +207,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error { if err != nil { return err } - statement.tableName = statement.Engine.tbName(v) + statement.tableName = statement.Engine.TableName(v, true) return nil } -// Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - v := rValue(tableNameOrBean) - t := v.Type() - if t.Kind() == reflect.String { - statement.AltTableName = tableNameOrBean.(string) - } else if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } - statement.AltTableName = statement.Engine.tbName(v) +func (statement *Statement) setRefBean(bean interface{}) error { + var err error + statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) + if err != nil { + return err } - return statement + statement.tableName = statement.Engine.TableName(bean, true) + return nil } // Auto generating update columnes and values according a struct -func buildUpdates(engine *Engine, table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, - mustColumnMap map[string]bool, nullableMap map[string]bool, - columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { +func (statement *Statement) buildUpdates(bean interface{}, + includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) ([]string, []interface{}) { + engine := statement.Engine + table := statement.RefTable + allUseBool := statement.allUseBool + useAllCols := statement.useAllCols + mustColumnMap := statement.mustColumnMap + nullableMap := statement.nullableMap + columnMap := statement.columnMap + omitColumnMap := statement.omitColumnMap + unscoped := statement.unscoped var colNames = make([]string, 0) var args = make([]interface{}, 0) @@ -268,7 +253,14 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if col.IsDeleted && !unscoped { continue } - if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { + if omitColumnMap.contain(col.Name) { + continue + } + if len(columnMap) > 0 && !columnMap.contain(col.Name) { + continue + } + + if col.MapType == core.ONLYFROMDB { continue } @@ -604,17 +596,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { } func (statement *Statement) colmap2NewColsWithQuote() []string { - newColumns := make([]string, 0, len(statement.columnMap)) - for col := range statement.columnMap { - fields := strings.Split(strings.TrimSpace(col), ".") - if len(fields) == 1 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])) - } else if len(fields) == 2 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ - statement.Engine.quote(fields[1])) - } else { - panic(errors.New("unwanted colnames")) - } + newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) + copy(newColumns, statement.columnMap) + for i := 0; i < len(statement.columnMap); i++ { + newColumns[i] = statement.Engine.Quote(newColumns[i]) } return newColumns } @@ -642,10 +627,11 @@ func (statement *Statement) Select(str string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement { cols := col2NewCols(columns...) for _, nc := range cols { - statement.columnMap[strings.ToLower(nc)] = true + statement.columnMap.add(nc) } newColumns := statement.colmap2NewColsWithQuote() + statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) return statement @@ -680,7 +666,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { func (statement *Statement) Omit(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.columnMap[strings.ToLower(nc)] = false + statement.omitColumnMap = append(statement.omitColumnMap, nc) } statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } @@ -719,10 +705,9 @@ func (statement *Statement) OrderBy(order string) *Statement { // Desc generate `ORDER BY xx DESC` func (statement *Statement) Desc(colNames ...string) *Statement { - var buf bytes.Buffer - fmt.Fprintf(&buf, statement.OrderStr) + var buf builder.StringBuilder if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, ", ") + fmt.Fprint(&buf, statement.OrderStr, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) @@ -732,10 +717,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement { // Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { - var buf bytes.Buffer - fmt.Fprintf(&buf, statement.OrderStr) + var buf builder.StringBuilder if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, ", ") + fmt.Fprint(&buf, statement.OrderStr, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) @@ -743,48 +727,35 @@ func (statement *Statement) Asc(colNames ...string) *Statement { return statement } +// Table tempororily set table name, the parameter could be a string or a pointer of struct +func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { + v := rValue(tableNameOrBean) + t := v.Type() + if t.Kind() == reflect.Struct { + var err error + statement.RefTable, err = statement.Engine.autoMapType(v) + if err != nil { + statement.Engine.logger.Error(err) + return statement + } + } + + statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) + return statement +} + // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf bytes.Buffer + var buf builder.StringBuilder if len(statement.JoinStr) > 0 { fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) } else { fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) - } else if len(t) == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.String { - table = f.(string) - } else if t.Kind() == reflect.Struct { - table = statement.Engine.tbName(v) - } - } - if l > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), - statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(table)) - } - default: - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) - } + tbName := statement.Engine.TableName(tablename, true) - fmt.Fprintf(&buf, " ON %v", condition) + fmt.Fprintf(&buf, "%s ON %v", tbName, condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) return statement @@ -809,18 +780,20 @@ func (statement *Statement) Unscoped() *Statement { } func (statement *Statement) genColumnStr() string { - var buf bytes.Buffer if statement.RefTable == nil { return "" } + var buf builder.StringBuilder columns := statement.RefTable.Columns() for _, col := range columns { - if statement.OmitStr != "" { - if _, ok := getFlagForColumn(statement.columnMap, col); ok { - continue - } + if statement.omitColumnMap.contain(col.Name) { + continue + } + + if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { + continue } if col.MapType == core.ONLYTODB { @@ -831,10 +804,6 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - buf.WriteString("id() AS ") - } - if statement.JoinStr != "" { if statement.TableAlias != "" { buf.WriteString(statement.TableAlias) @@ -859,11 +828,13 @@ func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genIndexSQL() []string { var sqls []string tbName := statement.TableName() - quote := statement.Engine.Quote - for idxName, index := range statement.RefTable.Indexes { + for _, index := range statement.RefTable.Indexes { if index.Type == core.IndexType { - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + sql := statement.Engine.dialect.CreateIndexSql(tbName, index) + /*idxTBName := strings.Replace(tbName, ".", "_", -1) + idxTBName = strings.Replace(idxTBName, `"`, "", -1) + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), + quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ sqls = append(sqls, sql) } } @@ -889,16 +860,18 @@ func (statement *Statement) genUniqueSQL() []string { func (statement *Statement) genDelIndexSQL() []string { var sqls []string tbName := statement.TableName() + idxPrefixName := strings.Replace(tbName, `"`, "", -1) + idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) for idxName, index := range statement.RefTable.Indexes { var rIdxName string if index.Type == core.UniqueType { - rIdxName = uniqueName(tbName, idxName) + rIdxName = uniqueName(idxPrefixName, idxName) } else if index.Type == core.IndexType { - rIdxName = indexName(tbName, idxName) + rIdxName = indexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) + sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) + sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) } sqls = append(sqls, sql) } @@ -949,7 +922,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { - statement.setRefValue(v) + statement.setRefBean(bean) } var columnStr = statement.ColumnStr @@ -982,13 +955,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, if err := statement.mergeConds(bean); err != nil { return "", nil, err } + } else { + if err := statement.processIDParam(); err != nil { + return "", nil, err + } } condSQL, condArgs, err := builder.ToSQL(statement.cond) if err != nil { return "", nil, err } - sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) if err != nil { return "", nil, err } @@ -1001,7 +978,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa var condArgs []interface{} var err error if len(beans) > 0 { - statement.setRefValue(rValue(beans[0])) + statement.setRefBean(beans[0]) condSQL, condArgs, err = statement.genConds(beans[0]) } else { condSQL, condArgs, err = builder.ToSQL(statement.cond) @@ -1018,7 +995,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) if err != nil { return "", nil, err } @@ -1027,7 +1004,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa } func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefValue(rValue(bean)) + statement.setRefBean(bean) var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns { @@ -1043,7 +1020,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return "", nil, err } - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) if err != nil { return "", nil, err } @@ -1051,27 +1028,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { - var distinct string +func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { + var ( + distinct string + dialect = statement.Engine.Dialect() + quote = statement.Engine.Quote + fromStr = " FROM " + top, mssqlCondi, whereStr string + ) if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } - - var dialect = statement.Engine.Dialect() - var quote = statement.Engine.Quote - var top string - var mssqlCondi string - - if err := statement.processIDParam(); err != nil { - return "", err - } - - var buf bytes.Buffer if len(condSQL) > 0 { - fmt.Fprintf(&buf, " WHERE %v", condSQL) + whereStr = " WHERE " + condSQL } - var whereStr = buf.String() - var fromStr = " FROM " if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { fromStr += statement.TableName() @@ -1118,9 +1088,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e } var orderStr string - if len(statement.OrderStr) > 0 { + if needOrderBy && len(statement.OrderStr) > 0 { orderStr = " ORDER BY " + statement.OrderStr } + var groupStr string if len(statement.GroupByStr) > 0 { groupStr = " GROUP BY " + statement.GroupByStr @@ -1130,45 +1101,50 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e } } - // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) + var buf builder.StringBuilder + fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if len(mssqlCondi) > 0 { if len(whereStr) > 0 { - a += " AND " + mssqlCondi + fmt.Fprint(&buf, " AND ", mssqlCondi) } else { - a += " WHERE " + mssqlCondi + fmt.Fprint(&buf, " WHERE ", mssqlCondi) } } if statement.GroupByStr != "" { - a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) + fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) } if statement.HavingStr != "" { - a = fmt.Sprintf("%v %v", a, statement.HavingStr) + fmt.Fprint(&buf, " ", statement.HavingStr) } - if statement.OrderStr != "" { - a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + if needOrderBy && statement.OrderStr != "" { + fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) } - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { - if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) - } - } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || statement.LimitN != 0 { - a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + if needLimit { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { + if statement.Start > 0 { + fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) + } else if statement.LimitN > 0 { + fmt.Fprint(&buf, " LIMIT ", statement.LimitN) + } + } else if dialect.DBType() == core.ORACLE { + if statement.Start != 0 || statement.LimitN != 0 { + oldString := buf.String() + buf.Reset() + fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start) + } } } if statement.IsForUpdate { - a = dialect.ForUpdateSql(a) + return dialect.ForUpdateSql(buf.String()), nil } - return + return buf.String(), nil } func (statement *Statement) processIDParam() error { - if statement.idParam == nil { + if statement.idParam == nil || statement.RefTable == nil { return nil } diff --git a/vendor/github.com/go-xorm/xorm/xorm.go b/vendor/github.com/go-xorm/xorm/xorm.go index 4fdadf2fa..141c4897d 100644 --- a/vendor/github.com/go-xorm/xorm/xorm.go +++ b/vendor/github.com/go-xorm/xorm/xorm.go @@ -17,7 +17,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.6.4.0910" + Version string = "0.7.0.0504" ) func regDrvsNDialects() bool { @@ -31,7 +31,7 @@ func regDrvsNDialects() bool { "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, - "pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, + "pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, @@ -90,6 +90,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { TagIdentifier: "xorm", TZLocation: time.Local, tagHandlers: defaultTagHandlers, + cachers: make(map[string]core.Cacher), } if uri.DbType == core.SQLITE { @@ -108,6 +109,13 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return engine, nil } +// NewEngineWithParams new a db manager with params. The params will be passed to dialect. +func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { + engine, err := NewEngine(driverName, dataSourceName) + engine.dialect.SetParams(params) + return engine, err +} + // Clone clone an engine func (engine *Engine) Clone() (*Engine, error) { return NewEngine(engine.DriverName(), engine.DataSourceName())