update gopkg file to add sql dep

This commit is contained in:
Gitea 2018-06-25 14:48:13 -04:00
parent f4f9fad970
commit 8d4a2e1fe2
23 changed files with 923 additions and 1285 deletions

5
Gopkg.lock generated
View File

@ -295,7 +295,6 @@
name = "github.com/go-sql-driver/mysql" name = "github.com/go-sql-driver/mysql"
packages = ["."] packages = ["."]
revision = "d523deb1b23d913de5bdada721a6071e71283618" revision = "d523deb1b23d913de5bdada721a6071e71283618"
version = "v1.4.0"
[[projects]] [[projects]]
name = "github.com/go-xorm/builder" name = "github.com/go-xorm/builder"
@ -315,7 +314,7 @@
[[projects]] [[projects]]
name = "github.com/go-xorm/xorm" name = "github.com/go-xorm/xorm"
packages = ["."] packages = ["."]
revision = "a8bd843a55a7fdccb311d4fc51f8261e5035cdbb" revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -874,6 +873,6 @@
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "afb86c21ceae4915758e1160617fab3d6845bc092307e67bee1656efabac1ab1" inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -38,7 +38,11 @@ ignored = ["google.golang.org/appengine*"]
[[override]] [[override]]
name = "github.com/go-xorm/xorm" name = "github.com/go-xorm/xorm"
#version = "0.6.5" #version = "0.6.5"
revision = "a8bd843a55a7fdccb311d4fc51f8261e5035cdbb" revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03"
[[override]]
name = "github.com/go-sql-driver/mysql"
revision = "d523deb1b23d913de5bdada721a6071e71283618"
[[override]] [[override]]
name = "github.com/gorilla/mux" name = "github.com/gorilla/mux"

View File

@ -172,33 +172,12 @@ type mysql struct {
allowAllFiles bool allowAllFiles bool
allowOldPasswords bool allowOldPasswords bool
clientFoundRows bool clientFoundRows bool
rowFormat string
} }
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
func (db *mysql) SetParams(params map[string]string) {
rowFormat, ok := params["rowFormat"]
if ok {
var t = strings.ToUpper(rowFormat)
switch t {
case "COMPACT":
fallthrough
case "REDUNDANT":
fallthrough
case "DYNAMIC":
fallthrough
case "COMPRESSED":
db.rowFormat = t
break
default:
break
}
}
}
func (db *mysql) SqlType(c *core.Column) string { func (db *mysql) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
@ -508,59 +487,6 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += db.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
} else {
sql += col.StringNoPk(db)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if len(charset) == 0 {
charset = db.URI().Charset
} else if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
if db.rowFormat != "" {
sql += " ROW_FORMAT=" + db.rowFormat
}
return sql
}
func (db *mysql) Filters() []core.Filter { func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}} return []core.Filter{&core.IdFilter{}}
} }

View File

@ -769,21 +769,14 @@ var (
DefaultPostgresSchema = "public" DefaultPostgresSchema = "public"
) )
const postgresPublicSchema = "public"
type postgres struct { type postgres struct {
core.Base core.Base
schema string
} }
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName) db.schema = DefaultPostgresSchema
if err != nil { return db.Base.Init(d, db, uri, drivername, dataSourceName)
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
}
return nil
} }
func (db *postgres) SqlType(c *core.Column) string { func (db *postgres) SqlType(c *core.Column) string {
@ -880,42 +873,32 @@ func (db *postgres) IndexOnTable() bool {
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args `WHERE tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
if len(db.Schema) == 0 {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
args := []interface{}{db.Schema, tableName} /*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args args := []interface{}{tableName, colName}
} return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}*/
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
if len(db.Schema) == 0 {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col)) tableName, col.Name, db.SqlType(col))
} }
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)
if !strings.HasPrefix(idxName, "UQE_") && if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") { !strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
@ -924,21 +907,13 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{db.Schema, tableName, colName} args := []interface{}{tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $3"
if len(db.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2" " AND column_name = $2"
}
db.LogSQL(query, args) db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...) rows, err := db.DB().Query(query, args...)
@ -951,7 +926,8 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName} // FIXME: the schema should be replaced by user custom's
args := []interface{}{tableName, db.schema}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -962,15 +938,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1060,13 +1028,8 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
} }
func (db *postgres) GetTables() ([]*core.Table, error) { func (db *postgres) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{db.schema}
s := "SELECT tablename FROM pg_tables" s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " WHERE schemaname = $1"
}
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1090,12 +1053,8 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{db.schema, tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args) db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
@ -1223,15 +1182,3 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
return db, nil return db, nil
} }
type pqDriverPgx struct {
pqDriver
}
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) {
// Remove the leading characters for driver to work
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
dataSourceName = dataSourceName[9:]
}
return pgx.pqDriver.Parse(driverName, dataSourceName)
}

View File

@ -49,35 +49,6 @@ type Engine struct {
tagHandlers map[string]tagHandler tagHandlers map[string]tagHandler
engineGroup *EngineGroup engineGroup *EngineGroup
cachers map[string]core.Cacher
cacherLock sync.RWMutex
}
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
engine.cacherLock.Lock()
engine.cachers[tableName] = cacher
engine.cacherLock.Unlock()
}
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
engine.setCacher(tableName, cacher)
}
func (engine *Engine) getCacher(tableName string) core.Cacher {
var cacher core.Cacher
var ok bool
engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName]
engine.cacherLock.RUnlock()
if !ok && !engine.disableGlobalCache {
cacher = engine.Cacher
}
return cacher
}
func (engine *Engine) GetCacher(tableName string) core.Cacher {
return engine.getCacher(tableName)
} }
// BufferSize sets buffer size for iterate // BufferSize sets buffer size for iterate
@ -274,7 +245,13 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher // MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
engine.setCacher(engine.TableName(bean, true), cacher) v := rValue(bean)
tb, err := engine.autoMapType(v)
if err != nil {
return err
}
tb.Cacher = cacher
return nil return nil
} }
@ -559,6 +536,33 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil return nil
} }
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
}
if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return tb.TableName()
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name())
}
// Cascade use cascade or not // Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
@ -842,7 +846,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
} }
return &Table{tb, engine.TableName(bean)} return &Table{tb, engine.tbName(v)}
} }
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -857,6 +861,15 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i
} }
} }
func (engine *Engine) newTable() *core.Table {
table := core.NewEmptyTable()
if !engine.disableGlobalCache {
table.Cacher = engine.Cacher
}
return table
}
// TableName table name interface to define customerize table name // TableName table name interface to define customerize table name
type TableName interface { type TableName interface {
TableName() string TableName() string
@ -868,9 +881,21 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type() t := v.Type()
table := core.NewEmptyTable() table := engine.newTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}
table.Type = t table.Type = t
table.Name = engine.tbNameForMap(v)
var idFieldColName string var idFieldColName string
var hasCacheTag, hasNoCacheTag bool var hasCacheTag, hasNoCacheTag bool
@ -1024,15 +1049,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
if hasCacheTag { if hasCacheTag {
if engine.Cacher != nil { // !nash! use engine's cacher if provided if engine.Cacher != nil { // !nash! use engine's cacher if provided
engine.logger.Info("enable cache on table:", table.Name) engine.logger.Info("enable cache on table:", table.Name)
engine.setCacher(table.Name, engine.Cacher) table.Cacher = engine.Cacher
} else { } else {
engine.logger.Info("enable LRU cache on table:", table.Name) engine.logger.Info("enable LRU cache on table:", table.Name)
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now
} }
} }
if hasNoCacheTag { if hasNoCacheTag {
engine.logger.Info("disable cache on table:", table.Name) engine.logger.Info("no cache on table:", table.Name)
engine.setCacher(table.Name, nil) table.Cacher = nil
} }
return table, nil return table, nil
@ -1091,25 +1116,7 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
pk := make([]interface{}, len(table.PrimaryKeys)) pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
var err error var err error
pkField := v.FieldByName(col.FieldName)
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, ErrUnSupportedType
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() { switch pkField.Kind() {
case reflect.String: case reflect.String:
pk[i], err = engine.idTypeAssertion(col, pkField.String()) pk[i], err = engine.idTypeAssertion(col, pkField.String())
@ -1155,10 +1162,26 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean) return session.CreateUniques(bean)
} }
func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher
}
// ClearCacheBean if enabled cache, clear the cache bean // ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
tableName := engine.TableName(bean) v := rValue(bean)
cacher := engine.getCacher(tableName) t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.DelBean(tableName, id) cacher.DelBean(tableName, id)
@ -1169,8 +1192,21 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache // ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error { func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
tableName := engine.TableName(bean) v := rValue(bean)
cacher := engine.getCacher(tableName) t := v.Type()
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher
if cacher == nil {
cacher = engine.Cacher
}
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.ClearBeans(tableName) cacher.ClearBeans(tableName)
@ -1188,13 +1224,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableNameNoSchema := engine.TableName(bean) tableName := engine.tbName(v)
table, err := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil { if err != nil {
return err return err
} }
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) isExist, err := session.Table(bean).isTableExist(tableName)
if err != nil { if err != nil {
return err return err
} }
@ -1220,12 +1256,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
} else { } else {
for _, col := range table.Columns() { for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
@ -1236,35 +1272,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
for name, index := range table.Indexes { for name, index := range table.Indexes {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) isExist, err := session.isIndexExist2(tableName, index.Cols, true)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
err = session.addUnique(tableNameNoSchema, name) err = session.addUnique(tableName, name)
if err != nil { if err != nil {
return err return err
} }
} }
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) isExist, err := session.isIndexExist2(tableName, index.Cols, false)
if err != nil { if err != nil {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
err = session.addIndex(tableNameNoSchema, name) err = session.addIndex(tableName, name)
if err != nil { if err != nil {
return err return err
} }
@ -1417,13 +1453,6 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
return session.Find(beans, condiBeans...) return session.Find(beans, condiBeans...)
} }
// FindAndCount find the results and also return the counts
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.FindAndCount(rowsSlicePtr, condiBean...)
}
// Iterate record by record handle records from table, bean's non-empty fields // Iterate record by record handle records from table, bean's non-empty fields
// are conditions. // are conditions.
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {
@ -1600,11 +1629,6 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) {
engine.DatabaseTZ = tz engine.DatabaseTZ = tz
} }
// SetSchema sets the schema of database
func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().Schema = schema
}
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session { func (engine *Engine) Unscoped() *Session {
session := engine.NewSession() session := engine.NewSession()

View File

@ -9,7 +9,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
@ -52,9 +51,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
fieldValuePtr, err := col.ValueOf(bean) fieldValuePtr, err := col.ValueOf(bean)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "is not valid") { engine.logger.Error(err)
engine.logger.Warn(err)
}
continue continue
} }

View File

@ -1,113 +0,0 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"reflect"
"strings"
"github.com/go-xorm/core"
)
// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
}
return engine.TableMapper.Obj2Table(v.Type().Name())
}
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = engine.tbNameForMap(v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case TableName:
return tablename.(TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return engine.tbNameForMap(v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return engine.tbNameForMap(v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

View File

@ -6,44 +6,23 @@ package xorm
import ( import (
"errors" "errors"
"fmt"
) )
var ( var (
// ErrParamsType params error // ErrParamsType params error
ErrParamsType = errors.New("Params type error") ErrParamsType = errors.New("Params type error")
// ErrTableNotFound table not found error // ErrTableNotFound table not found error
ErrTableNotFound = errors.New("Table not found") ErrTableNotFound = errors.New("Not found table")
// ErrUnSupportedType unsupported error // ErrUnSupportedType unsupported error
ErrUnSupportedType = errors.New("Unsupported type error") ErrUnSupportedType = errors.New("Unsupported type error")
// ErrNotExist record does not exist error // ErrNotExist record is not exist error
ErrNotExist = errors.New("Record does not exist") ErrNotExist = errors.New("Not exist error")
// ErrCacheFailed cache failed error // ErrCacheFailed cache failed error
ErrCacheFailed = errors.New("Cache failed") ErrCacheFailed = errors.New("Cache failed")
// ErrNeedDeletedCond delete needs less one condition error // ErrNeedDeletedCond delete needs less one condition error
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") ErrNeedDeletedCond = errors.New("Delete need at least one condition")
// ErrNotImplemented not implemented // ErrNotImplemented not implemented
ErrNotImplemented = errors.New("Not implemented") ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported // ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type") ErrConditionType = errors.New("Unsupported conditon type")
) )
// ErrFieldIsNotExist columns does not exist
type ErrFieldIsNotExist struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotExist) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}
// ErrFieldIsNotValid is not valid
type ErrFieldIsNotValid struct {
FieldName string
TableName string
}
func (e ErrFieldIsNotValid) Error() string {
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName)
}

View File

@ -11,6 +11,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
@ -292,6 +293,19 @@ func structName(v reflect.Type) string {
return v.Name() return v.Name()
} }
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
func sliceEq(left, right []string) bool { func sliceEq(left, right []string) bool {
if len(left) != len(right) { if len(left) != len(right) {
return false return false
@ -306,6 +320,154 @@ func sliceEq(left, right []string) bool {
return true return true
} }
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if col.IsDeleted {
continue
}
if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
if includeQuote {
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
} else {
colNames = append(colNames, col.Name)
}
}
return colNames, args, nil
}
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
} }
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}

View File

@ -30,7 +30,6 @@ type Interface interface {
Exec(string, ...interface{}) (sql.Result, error) Exec(string, ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error) Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error)
Get(interface{}) (bool, error) Get(interface{}) (bool, error)
GroupBy(keys string) *Session GroupBy(keys string) *Session
ID(interface{}) *Session ID(interface{}) *Session
@ -42,7 +41,6 @@ type Interface interface {
IsTableExist(beanOrTableName interface{}) (bool, error) IsTableExist(beanOrTableName interface{}) (bool, error)
Iterate(interface{}, IterFunc) error Iterate(interface{}, IterFunc) error
Limit(int, ...int) *Session Limit(int, ...int) *Session
MustCols(columns ...string) *Session
NoAutoCondition(...bool) *Session NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session NotIn(string, ...interface{}) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
@ -77,7 +75,6 @@ type EngineInterface interface {
Dialect() core.Dialect Dialect() core.Dialect
DropTables(...interface{}) error DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error DumpAllToFile(fp string, tp ...core.DbType) error
GetCacher(string) core.Cacher
GetColumnMapper() core.IMapper GetColumnMapper() core.IMapper
GetDefaultCacher() core.Cacher GetDefaultCacher() core.Cacher
GetTableMapper() core.IMapper GetTableMapper() core.IMapper
@ -86,11 +83,9 @@ type EngineInterface interface {
NewSession() *Session NewSession() *Session
NoAutoTime() *Session NoAutoTime() *Session
Quote(string) string Quote(string) string
SetCacher(string, core.Cacher)
SetDefaultCacher(core.Cacher) SetDefaultCacher(core.Cacher)
SetLogLevel(core.LogLevel) SetLogLevel(core.LogLevel)
SetMapper(core.IMapper) SetMapper(core.IMapper)
SetSchema(string)
SetTZDatabase(tz *time.Location) SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location) SetTZLocation(tz *time.Location)
ShowSQL(show ...bool) ShowSQL(show ...bool)
@ -98,7 +93,6 @@ type EngineInterface interface {
Sync2(...interface{}) error Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table TableInfo(bean interface{}) *Table
TableName(interface{}, ...bool) string
UnMapType(reflect.Type) UnMapType(reflect.Type)
} }

View File

@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var args []interface{} var args []interface{}
var err error var err error
if err = rows.session.statement.setRefBean(bean); err != nil { if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
return nil, err return nil, err
} }
@ -94,7 +94,8 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
} }
if err := rows.session.statement.setRefBean(bean); err != nil { dataStruct := rValue(bean)
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
return err return err
} }
@ -103,7 +104,6 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
dataStruct := rValue(bean)
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil { if err != nil {
return err return err

View File

@ -278,22 +278,24 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
return return
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
var col *core.Column var col *core.Column
if col = table.GetColumnIdx(key, idx); col == nil { if col = table.GetColumnIdx(key, idx); col == nil {
return nil, ErrFieldIsNotExist{key, table.Name} //session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq())
return nil
} }
fieldValue, err := col.ValueOfV(dataStruct) fieldValue, err := col.ValueOfV(dataStruct)
if err != nil { if err != nil {
return nil, err session.engine.logger.Error(err)
return nil
} }
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
return nil, ErrFieldIsNotValid{key, table.Name} session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key)
return nil
} }
return fieldValue
return fieldValue, nil
} }
// Cell cell is a result of one column field // Cell cell is a result of one column field
@ -405,16 +407,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
tempMap[lKey] = idx tempMap[lKey] = idx
fieldValue, err := session.getField(dataStruct, key, table, idx) if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil {
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
session.engine.logger.Warn(err)
}
continue
}
if fieldValue == nil {
continue
}
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
// if row is null then ignore // if row is null then ignore
@ -819,6 +812,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
} }
} }
}
return pk, nil return pk, nil
} }
@ -834,6 +828,15 @@ func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs return session.lastSQL, session.lastSQLArgs
} }
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session { func (session *Session) Unscoped() *Session {
session.statement.Unscoped() session.statement.Unscoped()

View File

@ -4,121 +4,6 @@
package xorm package xorm
import (
"reflect"
"strings"
"time"
"github.com/go-xorm/core"
)
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
type columnMap []string
func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}
n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}
return false
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}
func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t)
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t))
}
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
}
if v.CanSet() {
switch v.Type().Kind() {
case reflect.Struct:
v.Set(reflect.ValueOf(t).Convert(v.Type()))
case reflect.Int, reflect.Int64, reflect.Int32:
v.SetInt(t.Unix())
case reflect.Uint, reflect.Uint64, reflect.Uint32:
v.SetUint(uint64(t.Unix()))
}
}
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Incr provides a query string like "count = count + 1" // Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session { func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.statement.Incr(column, arg...) session.statement.Incr(column, arg...)

View File

@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed return ErrCacheFailed
} }
cacher := session.engine.getCacher(tableName) cacher := session.engine.getCacher2(table)
pkColumns := table.PKColumns() pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err return 0, err
} }
@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}) })
} }
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
} }

View File

@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefBean(bean[0]); err != nil { if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err return false, err
} }
} }

View File

@ -29,39 +29,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.find(rowsSlicePtr, condiBean...) return session.find(rowsSlicePtr, condiBean...)
} }
// FindAndCount find the results and also return the counts
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
err := session.find(rowsSlicePtr, condiBean...)
if err != nil {
return 0, err
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return 0, errors.New("needs a pointer to a slice or a map")
}
sliceElementType := sliceValue.Type().Elem()
if sliceElementType.Kind() == reflect.Ptr {
sliceElementType = sliceElementType.Elem()
}
session.autoResetStatement = true
if session.statement.selectStr != "" {
session.statement.selectStr = ""
}
if session.statement.OrderStr != "" {
session.statement.OrderStr = ""
}
return session.Count(reflect.New(sliceElementType).Interface())
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@ -75,7 +42,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv); err != nil { if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err return err
} }
} else { } else {
@ -83,7 +50,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv); err != nil { if err := session.statement.setRefValue(pv.Elem()); err != nil {
return err return err
} }
} else { } else {
@ -161,7 +128,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
args = append(session.statement.joinArgs, condArgs...) args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
if err != nil { if err != nil {
return err return err
} }
@ -176,7 +143,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
if session.canCache() { if session.canCache() {
if cacher := session.engine.getCacher(table.Name); cacher != nil && if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.statement.IsDistinct && !session.statement.IsDistinct &&
!session.statement.unscoped { !session.statement.unscoped {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
@ -321,12 +288,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed return ErrCacheFailed
} }
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
if cacher == nil {
return nil
}
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
} }
@ -336,7 +297,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
return ErrCacheFailed return ErrCacheFailed
} }
tableName := session.statement.TableName()
table := session.statement.RefTable table := session.statement.RefTable
cacher := session.engine.getCacher2(table)
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
rows, err := session.queryRows(newsql, args...) rows, err := session.queryRows(newsql, args...)

View File

@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
return false, err return false, err
} }
} }
@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.getCacher(table.Name); cacher != nil && if cacher := session.engine.getCacher2(table); cacher != nil &&
!session.statement.unscoped { !session.statement.unscoped {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
@ -134,9 +134,8 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed return false, ErrCacheFailed
} }
cacher := session.engine.getCacher2(session.statement.RefTable)
tableName := session.statement.TableName() tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := core.GetCacheSql(cacher, tableName, newsql, args)

View File

@ -66,12 +66,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice") return 0, errors.New("could not insert a empty slice")
} }
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil {
return 0, err return 0, err
} }
tableName := session.statement.TableName() if len(session.statement.TableName()) <= 0 {
if len(tableName) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
@ -116,12 +115,16 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} }
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { }
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue continue
} }
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
args = append(args, val) args = append(args, val)
@ -167,12 +170,16 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.ColumnStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok {
continue continue
} }
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { }
if session.statement.OmitStr != "" {
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok {
continue continue
} }
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
args = append(args, val) args = append(args, val)
@ -206,6 +213,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
var statement string var statement string
var tableName = session.statement.TableName()
if session.engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == core.ORACLE {
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
@ -232,7 +240,9 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err return 0, err
} }
session.cacheInsert(tableName) if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
lenAfterClosures := len(session.afterClosures) lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@ -288,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} }
func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(rValue(bean)); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
@ -306,8 +316,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert() processor.BeforeInsert()
} }
// --
colNames, args, err := session.genInsertColumns(bean) colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -392,7 +402,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName) if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -435,7 +447,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName) if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -476,7 +490,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
defer handleAfterInsertProcessorFunc(bean) defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName) if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
session.cacheInsert(table, tableName)
}
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
@ -523,104 +539,16 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return session.innerInsert(bean) return session.innerInsert(bean)
} }
func (session *Session) cacheInsert(table string) error { func (session *Session) cacheInsert(table *core.Table, tables ...string) error {
if !session.statement.UseCache { if table == nil {
return ErrCacheFailed
}
cacher := session.engine.getCacher2(table)
for _, t := range tables {
session.engine.logger.Debug("[cache] clear sql:", t)
cacher.ClearIds(t)
}
return nil return nil
} }
cacher := session.engine.getCacher(table)
if cacher == nil {
return nil
}
session.engine.logger.Debug("[cache] clear sql:", table)
cacher.ClearIds(table)
return nil
}
// genInsertColumns generates insert needed columns
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if col.MapType == core.ONLYFROMDB {
continue
}
if col.IsDeleted {
continue
}
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}
if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, col.Name)
}
return colNames, args, nil
}

View File

@ -64,17 +64,13 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
} }
} }
if err := session.statement.processIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
args := append(session.statement.joinArgs, condArgs...) args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -6,7 +6,9 @@ package xorm
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"reflect"
"strings" "strings"
"github.com/go-xorm/core" "github.com/go-xorm/core"
@ -32,7 +34,8 @@ func (session *Session) CreateTable(bean interface{}) error {
} }
func (session *Session) createTable(bean interface{}) error { func (session *Session) createTable(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -51,7 +54,8 @@ func (session *Session) CreateIndexes(bean interface{}) error {
} }
func (session *Session) createIndexes(bean interface{}) error { func (session *Session) createIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -74,7 +78,8 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
func (session *Session) createUniques(bean interface{}) error { func (session *Session) createUniques(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -98,7 +103,8 @@ func (session *Session) DropIndexes(bean interface{}) error {
} }
func (session *Session) dropIndexes(bean interface{}) error { func (session *Session) dropIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
return err return err
} }
@ -122,7 +128,11 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName) tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@ -134,8 +144,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
} }
if needDrop { if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err := session.exec(sqlStr) _, err = session.exec(sqlStr)
return err return err
} }
return nil return nil
@ -147,7 +157,10 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close() defer session.Close()
} }
tableName := session.engine.TableName(beanOrTableName) tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
return session.isTableExist(tableName) return session.isTableExist(tableName)
} }
@ -160,15 +173,24 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
// IsTableEmpty if table have any records // IsTableEmpty if table have any records
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean)
t := v.Type()
if t.Kind() == reflect.String {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
return session.isTableEmpty(session.engine.TableName(bean)) return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
return rows == 0, err
}
return false, errors.New("bean should be a struct or struct's point")
} }
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
err := session.queryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -233,12 +255,6 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err return err
} }
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var structTables []*core.Table var structTables []*core.Table
for _, bean := range beans { for _, bean := range beans {
@ -248,8 +264,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err return err
} }
structTables = append(structTables, table) structTables = append(structTables, table)
tbName := engine.TableName(bean) var tbName = session.tbNameNoSchema(table)
tbNameWithSchema := engine.TableName(tbName, true)
var oriTable *core.Table var oriTable *core.Table
for _, tb := range tables { for _, tb := range tables {
@ -294,32 +309,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL || if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES { engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbName, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbName, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} }
} }
} else { } else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType) tbName, col.Name, curType, expectedType)
} }
} }
} else if expectedType == core.Varchar { } else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
} }
} }
} }
@ -333,7 +348,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
} else { } else {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.tableName = tbName
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
} }
if err != nil { if err != nil {
@ -356,7 +371,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) sql := engine.dialect.DropIndexSql(tbName, oriIndex)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -372,7 +387,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) sql := engine.dialect.DropIndexSql(tbName, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -383,12 +398,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.tableName = tbName
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbName, name)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.tableName = tbName
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbName, name)
} }
if err != nil { if err != nil {
return err return err
@ -413,7 +428,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil { if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
} }
} }
} }

View File

@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
} }
} }
cacher := session.engine.getCacher(tableName) cacher := session.engine.getCacher2(table)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil { if err != nil {
@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct var isStruct = t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefValue(v); err != nil {
return 0, err return 0, err
} }
@ -176,10 +176,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if session.statement.ColumnStr == "" { if session.statement.ColumnStr == "" {
colNames, args = session.statement.buildUpdates(bean, false, false, colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false,
false, false, true) false, false, session.statement.allUseBool, session.statement.useAllCols,
session.statement.mustColumnMap, session.statement.nullableMap,
session.statement.columnMap, true, session.statement.unscoped)
} else { } else {
colNames, args, err = session.genUpdateColumns(bean) colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -200,8 +202,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table := session.statement.RefTable table := session.statement.RefTable
if session.statement.UseAutoTime && table != nil && table.Updated != "" { if session.statement.UseAutoTime && table != nil && table.Updated != "" {
if !session.statement.columnMap.contain(table.Updated) && if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok {
!session.statement.omitColumnMap.contain(table.Updated) {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn() col := table.UpdatedColumn()
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
@ -361,12 +362,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { if table != nil {
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...) //session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.ClearBeans(tableName) cacher.ClearBeans(tableName)
} }
}
// handle after update processors // handle after update processors
if session.isAutoCommit { if session.isAutoCommit {
@ -400,92 +402,3 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return res.RowsAffected() return res.RowsAffected()
} }
func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if session.statement.omitColumnMap.contain(col.Name) {
continue
}
}
if col.MapType == core.ONLYFROMDB {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
}
if col.IsDeleted || col.IsCreated {
continue
}
if len(session.statement.columnMap) > 0 {
if !session.statement.columnMap.contain(col.Name) {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
if col.Nullable && isZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t := session.engine.nowTime(col)
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.checkVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
}
return colNames, args, nil
}

View File

@ -18,6 +18,21 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *core.Table RefTable *core.Table
@ -32,6 +47,7 @@ type Statement struct {
HavingStr string HavingStr string
ColumnStr string ColumnStr string
selectStr string selectStr string
columnMap map[string]bool
useAllCols bool useAllCols bool
OmitStr string OmitStr string
AltTableName string AltTableName string
@ -51,8 +67,6 @@ type Statement struct {
allUseBool bool allUseBool bool
checkVersion bool checkVersion bool
unscoped bool unscoped bool
columnMap columnMap
omitColumnMap columnMap
mustColumnMap map[string]bool mustColumnMap map[string]bool
nullableMap map[string]bool nullableMap map[string]bool
incrColumns map[string]incrParam incrColumns map[string]incrParam
@ -75,8 +89,7 @@ func (statement *Statement) Init() {
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.OmitStr = "" statement.OmitStr = ""
statement.columnMap = columnMap{} statement.columnMap = make(map[string]bool)
statement.omitColumnMap = columnMap{}
statement.AltTableName = "" statement.AltTableName = ""
statement.tableName = "" statement.tableName = ""
statement.idParam = nil statement.idParam = nil
@ -208,33 +221,34 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil { if err != nil {
return err return err
} }
statement.tableName = statement.Engine.TableName(v, true) statement.tableName = statement.Engine.tbName(v)
return nil return nil
} }
func (statement *Statement) setRefBean(bean interface{}) error { // Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil { if err != nil {
return err statement.Engine.logger.Error(err)
return statement
} }
statement.tableName = statement.Engine.TableName(bean, true) statement.AltTableName = statement.Engine.tbName(v)
return nil }
return statement
} }
// Auto generating update columnes and values according a struct // Auto generating update columnes and values according a struct
func (statement *Statement) buildUpdates(bean interface{}, func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion, includeUpdated, includeNil, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr, update bool) ([]string, []interface{}) { includeAutoIncr bool, allUseBool bool, useAllCols bool,
engine := statement.Engine mustColumnMap map[string]bool, nullableMap map[string]bool,
table := statement.RefTable columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.mustColumnMap
nullableMap := statement.nullableMap
columnMap := statement.columnMap
omitColumnMap := statement.omitColumnMap
unscoped := statement.unscoped
var colNames = make([]string, 0) var colNames = make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
@ -254,10 +268,7 @@ func (statement *Statement) buildUpdates(bean interface{},
if col.IsDeleted && !unscoped { if col.IsDeleted && !unscoped {
continue continue
} }
if omitColumnMap.contain(col.Name) { if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use {
continue
}
if len(columnMap) > 0 && !columnMap.contain(col.Name) {
continue continue
} }
@ -593,10 +604,17 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
} }
func (statement *Statement) colmap2NewColsWithQuote() []string { func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) newColumns := make([]string, 0, len(statement.columnMap))
copy(newColumns, statement.columnMap) for col := range statement.columnMap {
for i := 0; i < len(statement.columnMap); i++ { fields := strings.Split(strings.TrimSpace(col), ".")
newColumns[i] = statement.Engine.Quote(newColumns[i]) if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
} }
return newColumns return newColumns
} }
@ -624,11 +642,10 @@ func (statement *Statement) Select(str string) *Statement {
func (statement *Statement) Cols(columns ...string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...) cols := col2NewCols(columns...)
for _, nc := range cols { for _, nc := range cols {
statement.columnMap.add(nc) statement.columnMap[strings.ToLower(nc)] = true
} }
newColumns := statement.colmap2NewColsWithQuote() newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
return statement return statement
@ -663,7 +680,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
func (statement *Statement) Omit(columns ...string) { func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
statement.omitColumnMap = append(statement.omitColumnMap, nc) statement.columnMap[strings.ToLower(nc)] = false
} }
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
@ -726,23 +743,6 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement return statement
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer var buf bytes.Buffer
@ -752,9 +752,39 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
tbName := statement.Engine.TableName(tablename, true) switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.String {
table = f.(string)
} else if t.Kind() == reflect.Struct {
table = statement.Engine.tbName(v)
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
default:
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) fmt.Fprintf(&buf, " ON %v", condition)
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...) statement.joinArgs = append(statement.joinArgs, args...)
return statement return statement
@ -787,12 +817,10 @@ func (statement *Statement) genColumnStr() string {
columns := statement.RefTable.Columns() columns := statement.RefTable.Columns()
for _, col := range columns { for _, col := range columns {
if statement.omitColumnMap.contain(col.Name) { if statement.OmitStr != "" {
if _, ok := getFlagForColumn(statement.columnMap, col); ok {
continue continue
} }
if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
continue
} }
if col.MapType == core.ONLYTODB { if col.MapType == core.ONLYTODB {
@ -803,6 +831,10 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ") buf.WriteString(", ")
} }
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
buf.WriteString("id() AS ")
}
if statement.JoinStr != "" { if statement.JoinStr != "" {
if statement.TableAlias != "" { if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias) buf.WriteString(statement.TableAlias)
@ -827,13 +859,11 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string { func (statement *Statement) genIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes { quote := statement.Engine.Quote
for idxName, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType { if index.Type == core.IndexType {
sql := statement.Engine.dialect.CreateIndexSql(tbName, index) sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
/*idxTBName := strings.Replace(tbName, ".", "_", -1) quote(tbName), quote(strings.Join(index.Cols, quote(","))))
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -859,18 +889,16 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string { func (statement *Statement) genDelIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes { for idxName, index := range statement.RefTable.Indexes {
var rIdxName string var rIdxName string
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
rIdxName = uniqueName(idxPrefixName, idxName) rIdxName = uniqueName(tbName, idxName)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
rIdxName = indexName(idxPrefixName, idxName) rIdxName = indexName(tbName, idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
if statement.Engine.dialect.IndexOnTable() { if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
} }
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
@ -921,7 +949,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
v := rValue(bean) v := rValue(bean)
isStruct := v.Kind() == reflect.Struct isStruct := v.Kind() == reflect.Struct
if isStruct { if isStruct {
statement.setRefBean(bean) statement.setRefValue(v)
} }
var columnStr = statement.ColumnStr var columnStr = statement.ColumnStr
@ -954,17 +982,13 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if err := statement.mergeConds(bean); err != nil { if err := statement.mergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
} else {
if err := statement.processIDParam(); err != nil {
return "", nil, err
}
} }
condSQL, condArgs, err := builder.ToSQL(statement.cond) condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -977,7 +1001,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var condArgs []interface{} var condArgs []interface{}
var err error var err error
if len(beans) > 0 { if len(beans) > 0 {
statement.setRefBean(beans[0]) statement.setRefValue(rValue(beans[0]))
condSQL, condArgs, err = statement.genConds(beans[0]) condSQL, condArgs, err = statement.genConds(beans[0])
} else { } else {
condSQL, condArgs, err = builder.ToSQL(statement.cond) condSQL, condArgs, err = builder.ToSQL(statement.cond)
@ -994,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1003,7 +1027,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
} }
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefBean(bean) statement.setRefValue(rValue(bean))
var sumStrs = make([]string, 0, len(columns)) var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
@ -1019,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -1027,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) { func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
var distinct string var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
@ -1038,6 +1062,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
var top string var top string
var mssqlCondi string var mssqlCondi string
if err := statement.processIDParam(); err != nil {
return "", err
}
var buf bytes.Buffer var buf bytes.Buffer
if len(condSQL) > 0 { if len(condSQL) > 0 {
fmt.Fprintf(&buf, " WHERE %v", condSQL) fmt.Fprintf(&buf, " WHERE %v", condSQL)
@ -1090,10 +1118,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} }
var orderStr string var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr orderStr = " ORDER BY " + statement.OrderStr
} }
var groupStr string var groupStr string
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr groupStr = " GROUP BY " + statement.GroupByStr
@ -1119,10 +1146,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
if statement.HavingStr != "" { if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr) a = fmt.Sprintf("%v %v", a, statement.HavingStr)
} }
if needOrderBy && statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
@ -1134,7 +1160,6 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
} }
} }
}
if statement.IsForUpdate { if statement.IsForUpdate {
a = dialect.ForUpdateSql(a) a = dialect.ForUpdateSql(a)
} }
@ -1143,7 +1168,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} }
func (statement *Statement) processIDParam() error { func (statement *Statement) processIDParam() error {
if statement.idParam == nil || statement.RefTable == nil { if statement.idParam == nil {
return nil return nil
} }

View File

@ -17,7 +17,7 @@ import (
const ( const (
// Version show the xorm's version // Version show the xorm's version
Version string = "0.7.0.0504" Version string = "0.6.4.0910"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {
@ -31,7 +31,7 @@ func regDrvsNDialects() bool {
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, "pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
@ -90,7 +90,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
TagIdentifier: "xorm", TagIdentifier: "xorm",
TZLocation: time.Local, TZLocation: time.Local,
tagHandlers: defaultTagHandlers, tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
} }
if uri.DbType == core.SQLITE { if uri.DbType == core.SQLITE {
@ -109,13 +108,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return engine, nil return engine, nil
} }
// NewEngineWithParams new a db manager with params. The params will be passed to dialect.
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
engine, err := NewEngine(driverName, dataSourceName)
engine.dialect.SetParams(params)
return engine, err
}
// Clone clone an engine // Clone clone an engine
func (engine *Engine) Clone() (*Engine, error) { func (engine *Engine) Clone() (*Engine, error) {
return NewEngine(engine.DriverName(), engine.DataSourceName()) return NewEngine(engine.DriverName(), engine.DataSourceName())