Beispiel #1
1
// SqlType returns the MySQL column type for c, including a "(length)" or
// "(length,length2)" suffix when the column carries size information.
//
// Dialect-neutral type names are translated to the names used here for MySQL
// (Bool -> TINYINT(1), Serial -> INT, BigSerial -> BIGINT, Bytea -> BLOB,
// TimeStampz -> CHAR(64)); any other name is passed through unchanged.
//
// The column is modified in place: Bool and TimeStampz overwrite c.Length, and
// the serial types mark the column as an auto-increment, non-nullable primary
// key.
func (db *mysql) SqlType(c *core.Column) string {
	var res string
	switch t := c.SQLType.Name; t {
	case core.Bool:
		res = core.TinyInt
		c.Length = 1
	case core.Serial:
		c.IsAutoIncrement = true
		c.IsPrimaryKey = true
		c.Nullable = false
		res = core.Int
	case core.BigSerial:
		c.IsAutoIncrement = true
		c.IsPrimaryKey = true
		c.Nullable = false
		res = core.BigInt
	case core.Bytea:
		res = core.Blob
	case core.TimeStampz:
		res = core.Char
		c.Length = 64
	default:
		res = t
	}

	hasLen1 := c.Length > 0
	hasLen2 := c.Length2 > 0
	// The two-part form must be tested first: testing Length first would
	// silently drop Length2 for every column that has both (e.g. DECIMAL(10,2)).
	if hasLen2 {
		res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
	} else if hasLen1 {
		res += "(" + strconv.Itoa(c.Length) + ")"
	}
	return res
}
Beispiel #2
0
// SqlType maps the column's declared SQL type onto the type name used for
// SQLite and returns it. No length suffix is ever appended.
//
// Two cases modify the column as a side effect: Bool rewrites a "true"/"false"
// default to "1"/"0", and the serial types flag the column as an
// auto-increment, non-nullable primary key. Unrecognised names are returned
// unchanged.
func (db *sqlite3) SqlType(c *core.Column) string {
	name := c.SQLType.Name
	switch name {
	case core.Bool:
		// Booleans are stored as integers, so textual defaults are translated.
		switch c.Default {
		case "true":
			c.Default = "1"
		case "false":
			c.Default = "0"
		}
		return core.Integer
	case core.Serial, core.BigSerial:
		c.IsPrimaryKey, c.IsAutoIncrement, c.Nullable = true, true, false
		return core.Integer
	case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt:
		return core.Integer
	case core.TimeStampz,
		core.Char, core.Varchar, core.NVarchar, core.TinyText,
		core.Text, core.MediumText, core.LongText, core.Json:
		return core.Text
	case core.Date, core.DateTime, core.TimeStamp, core.Time:
		return core.DateTime
	case core.Float, core.Double, core.Real:
		return core.Real
	case core.Decimal, core.Numeric:
		return core.Numeric
	case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary:
		return core.Blob
	}
	return name
}
Beispiel #3
0
// GetColumns reads the column metadata of tableName from sys.columns joined
// with sys.types.
//
// It returns the column names in the order the rows were delivered, a map from
// column name to its description, and an error if the query, a row scan or
// the row iteration fails, or if a column has a type name that is not present
// in core.SqlTypes.
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{}
	// The table name is embedded in a SQL string literal, so single quotes are
	// doubled to keep a hostile or unusual name from terminating the literal.
	quotedName := strings.Replace(tableName, "'", "''", -1)
	s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale
from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id
where a.object_id=object_id('` + quotedName + `')`
	db.LogSQL(s, args)

	rows, err := db.DB().Query(s, args...)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)
	for rows.Next() {
		var name, ctype, precision, scale string
		var maxLen int
		err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale)
		if err != nil {
			return nil, nil, err
		}

		col := new(core.Column)
		col.Indexes = make(map[string]bool)
		col.Length = maxLen
		col.Name = strings.Trim(name, "` ")

		ct := strings.ToUpper(ctype)
		switch ct {
		case "DATETIMEOFFSET":
			col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
		case "NVARCHAR":
			col.SQLType = core.SQLType{core.NVarchar, 0, 0}
		case "IMAGE":
			col.SQLType = core.SQLType{core.VarBinary, 0, 0}
		default:
			if _, ok := core.SqlTypes[ct]; ok {
				col.SQLType = core.SQLType{ct, 0, 0}
			} else {
				return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v",
					ct, tableName, col.Name))
			}
		}

		// Text and time defaults are stored wrapped in single quotes.
		if col.SQLType.IsText() || col.SQLType.IsTime() {
			if col.Default != "" {
				col.Default = "'" + col.Default + "'"
			} else {
				if col.DefaultIsEmpty {
					col.Default = "''"
				}
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}
	return colSeq, cols, nil
}
Beispiel #4
0
// SqlType translates the column's declared SQL type into the type string used
// for SQL Server, appending "(length)" or "(length,length2)" when the column
// has a size. A plain INT is always returned without a size suffix.
//
// Side effects on c: Bool rewrites "true"/"false" defaults to "1"/"0", the
// serial types mark the column as an auto-increment, non-nullable primary key,
// binary types without a length get length 50, and TimeStampz forces length 7.
func (db *mssql) SqlType(c *core.Column) string {
	res := c.SQLType.Name
	switch res {
	case core.Bool:
		res = core.TinyInt
		switch c.Default {
		case "true":
			c.Default = "1"
		case "false":
			c.Default = "0"
		}
	case core.Serial, core.BigSerial:
		c.IsAutoIncrement = true
		c.IsPrimaryKey = true
		c.Nullable = false
		if res == core.Serial {
			res = core.Int
		} else {
			res = core.BigInt
		}
	case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob:
		res = core.VarBinary
		if c.Length == 0 {
			c.Length = 50
		}
	case core.TimeStamp:
		res = core.DateTime
	case core.TimeStampz:
		res = "DATETIMEOFFSET"
		c.Length = 7
	case core.MediumInt:
		res = core.Int
	case core.MediumText, core.TinyText, core.LongText, core.Json:
		res = core.Text
	case core.Double:
		res = core.Real
	}

	// INT never carries a size suffix.
	if res == core.Int {
		return core.Int
	}

	switch {
	case c.Length2 > 0:
		res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
	case c.Length > 0:
		res += "(" + strconv.Itoa(c.Length) + ")"
	}
	return res
}
Beispiel #5
0
// GetColumns reconstructs the column list of tableName by parsing the CREATE
// TABLE statement stored in sqlite_master.
//
// It returns the column names in declaration order, a map from column name to
// its description, and an error if the query fails, the table does not exist,
// or the stored statement has no parenthesised column list.
//
// Only the column name, the type token, PRIMARY, AUTOINCREMENT and [NOT] NULL
// are recognised; every column starts out nullable.
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{tableName}
	s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"

	rows, err := db.DB().Query(s, args...)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	// Only the first matching row is of interest.
	var name string
	if rows.Next() {
		if err = rows.Scan(&name); err != nil {
			return nil, nil, err
		}
	}
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}

	if name == "" {
		return nil, nil, errors.New("no table named " + tableName)
	}

	// The column list runs from the first "(" to the LAST ")": stopping at the
	// first ")" would truncate the list at the first sized type such as
	// VARCHAR(255).
	nStart := strings.Index(name, "(")
	nEnd := strings.LastIndex(name, ")")
	if nStart < 0 || nEnd <= nStart {
		return nil, nil, errors.New("no column definitions found for table " + tableName)
	}
	body := name[nStart+1 : nEnd]

	// Split on commas that are not nested inside parentheses, so that a type
	// like DECIMAL(10,2) stays within a single column definition.
	var colCreates []string
	depth, last := 0, 0
	for i := 0; i < len(body); i++ {
		switch body[i] {
		case '(':
			depth++
		case ')':
			depth--
		case ',':
			if depth == 0 {
				colCreates = append(colCreates, body[last:i])
				last = i + 1
			}
		}
	}
	colCreates = append(colCreates, body[last:])

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0, len(colCreates))
	for _, colStr := range colCreates {
		fields := strings.Fields(strings.TrimSpace(colStr))
		if len(fields) == 0 {
			// Blank segment (e.g. a trailing comma): nothing to describe.
			continue
		}
		col := new(core.Column)
		col.Indexes = make(map[string]bool)
		col.Nullable = true
		for idx, field := range fields {
			if idx == 0 {
				col.Name = strings.Trim(field, "`[] ")
				continue
			} else if idx == 1 {
				col.SQLType = core.SQLType{field, 0, 0}
			}
			switch field {
			case "PRIMARY":
				col.IsPrimaryKey = true
			case "AUTOINCREMENT":
				col.IsAutoIncrement = true
			case "NULL":
				// idx is at least 1 here, so fields[idx-1] is always valid.
				if fields[idx-1] == "NOT" {
					col.Nullable = false
				} else {
					col.Nullable = true
				}
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	return colSeq, cols, nil
}
Beispiel #6
0
// SqlType returns the MySQL column type for c, including a "(length)" or
// "(length,length2)" suffix when the column carries size information.
//
// Dialect-neutral type names are translated to the names used here for MySQL
// (Bool -> TINYINT(1), Serial -> INT, BigSerial -> BIGINT, Bytea -> BLOB,
// TimeStampz -> CHAR(64)). ENUM and SET are rendered with their quoted option
// list. Any other name is passed through unchanged. A BIGINT without explicit
// size is given length 20.
//
// The column is modified in place: Bool, TimeStampz and the BIGINT default
// overwrite c.Length, and the serial types mark the column as an
// auto-increment, non-nullable primary key.
func (db *mysql) SqlType(c *core.Column) string {
	// optionList renders the keys of opts as a comma-separated list of quoted
	// values, ordered by the position stored as the map value (ties broken by
	// name). Ranging over the map directly would emit the options in a
	// different order on every call.
	optionList := func(opts map[string]int) string {
		names := make([]string, 0, len(opts))
		for name := range opts {
			names = append(names, name)
		}
		// Insertion sort: option lists are tiny and this avoids a new import.
		for i := 1; i < len(names); i++ {
			for j := i; j > 0; j-- {
				prev, cur := names[j-1], names[j]
				if opts[prev] < opts[cur] || (opts[prev] == opts[cur] && prev <= cur) {
					break
				}
				names[j-1], names[j] = cur, prev
			}
		}
		for i, name := range names {
			names[i] = fmt.Sprintf("'%v'", name)
		}
		return strings.Join(names, ",")
	}

	var res string
	switch t := c.SQLType.Name; t {
	case core.Bool:
		res = core.TinyInt
		c.Length = 1
	case core.Serial:
		c.IsAutoIncrement = true
		c.IsPrimaryKey = true
		c.Nullable = false
		res = core.Int
	case core.BigSerial:
		c.IsAutoIncrement = true
		c.IsPrimaryKey = true
		c.Nullable = false
		res = core.BigInt
	case core.Bytea:
		res = core.Blob
	case core.TimeStampz:
		res = core.Char
		c.Length = 64
	case core.Enum: // mysql enum
		res = core.Enum + "(" + optionList(c.EnumOptions) + ")"
	case core.Set: // mysql set
		res = core.Set + "(" + optionList(c.SetOptions) + ")"
	default:
		res = t
	}

	hasLen1 := c.Length > 0
	hasLen2 := c.Length2 > 0

	// An unsized BIGINT is given length 20.
	if res == core.BigInt && !hasLen1 && !hasLen2 {
		c.Length = 20
		hasLen1 = true
	}

	if hasLen2 {
		res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
	} else if hasLen1 {
		res += "(" + strconv.Itoa(c.Length) + ")"
	}
	return res
}
Beispiel #7
0
// setColumnInt stores t in the field of bean that col refers to.
//
// The write happens only when the field can be resolved, is settable, and is
// of kind int, int32, int64, uint, uint32 or uint64; in every other case the
// call is a silent no-op. For unsigned fields t is converted with uint64(t).
func setColumnInt(bean interface{}, col *core.Column, t int64) {
	v, err := col.ValueOf(bean)
	if err != nil || !v.CanSet() {
		return
	}

	switch v.Kind() {
	case reflect.Int, reflect.Int32, reflect.Int64:
		v.SetInt(t)
	case reflect.Uint, reflect.Uint32, reflect.Uint64:
		v.SetUint(uint64(t))
	}
}
Beispiel #8
0
// SqlType translates the column's declared SQL type into the type string used
// for PostgreSQL.
//
// Several mappings return immediately without a size suffix (TinyInt,
// the integer family, binary/blob types, TimeStampz, Double and any
// auto-increment column). The remaining types get "(length)" or
// "(length,length2)" appended when the column has a size.
//
// Auto-increment integer columns become SERIAL, and auto-increment BIGINT
// columns become BIGSERIAL so that the 64-bit range is preserved. Serial and
// BigSerial mark the column as auto-increment and non-nullable.
func (db *postgres) SqlType(c *core.Column) string {
	var res string
	switch t := c.SQLType.Name; t {
	case core.TinyInt:
		res = core.SmallInt
		return res
	case core.MediumInt, core.Int, core.Integer:
		if c.IsAutoIncrement {
			return core.Serial
		}
		return core.Integer
	case core.BigInt:
		// Without this case an auto-increment BIGINT fell through to the
		// default branch and was created as a 32-bit SERIAL.
		if c.IsAutoIncrement {
			return core.BigSerial
		}
		res = t
	case core.Serial, core.BigSerial:
		c.IsAutoIncrement = true
		c.Nullable = false
		res = t
	case core.Binary, core.VarBinary:
		return core.Bytea
	case core.DateTime:
		res = core.TimeStamp
	case core.TimeStampz:
		return "timestamp with time zone"
	case core.Float:
		res = core.Real
	case core.TinyText, core.MediumText, core.LongText:
		res = core.Text
	case core.NVarchar:
		res = core.Varchar
	case core.Uuid:
		res = core.Uuid
	case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
		return core.Bytea
	case core.Double:
		return "DOUBLE PRECISION"
	default:
		if c.IsAutoIncrement {
			return core.Serial
		}
		res = t
	}

	hasLen1 := c.Length > 0
	hasLen2 := c.Length2 > 0
	if hasLen2 {
		res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
	} else if hasLen1 {
		res += "(" + strconv.Itoa(c.Length) + ")"
	}
	return res
}
Beispiel #9
0
// setColumnTime stores the instant t in the field of bean that col refers to.
//
// Struct fields receive t converted to the field's own type; int, int32,
// int64, uint, uint32 and uint64 fields receive t.Unix(). Nothing happens when
// the field cannot be resolved, is not settable, or has any other kind.
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
	v, err := col.ValueOf(bean)
	if err != nil || !v.CanSet() {
		return
	}

	switch v.Kind() {
	case reflect.Struct:
		converted := reflect.ValueOf(t).Convert(v.Type())
		v.Set(converted)
	case reflect.Int, reflect.Int32, reflect.Int64:
		v.SetInt(t.Unix())
	case reflect.Uint, reflect.Uint32, reflect.Uint64:
		v.SetUint(uint64(t.Unix()))
	}
}
Beispiel #10
0
// GetColumns reconstructs the column list of tableName by parsing the CREATE
// TABLE statement stored in sqlite_master.
//
// It returns the column names in declaration order, a map from column name to
// its description, and an error if the query fails, the table does not exist,
// or the stored statement has no parenthesised column list.
//
// Recognised tokens are the column name, the type, PRIMARY, AUTOINCREMENT,
// [NOT] NULL and DEFAULT <value>. Every column starts out nullable with
// DefaultIsEmpty set; an explicit default on a non-numeric column is stored
// wrapped in single quotes.
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{tableName}
	s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
	db.LogSQL(s, args)
	rows, err := db.DB().Query(s, args...)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	// Only the first matching row is of interest.
	var name string
	if rows.Next() {
		if err = rows.Scan(&name); err != nil {
			return nil, nil, err
		}
	}
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}

	if name == "" {
		return nil, nil, errors.New("no table named " + tableName)
	}

	nStart := strings.Index(name, "(")
	nEnd := strings.LastIndex(name, ")")
	if nStart < 0 || nEnd <= nStart {
		// Slicing with these indexes would panic.
		return nil, nil, errors.New("no column definitions found for table " + tableName)
	}

	// Both patterns are loop-invariant, so they are compiled once up front.
	colDefRe := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`)
	commaSpaceRe := regexp.MustCompile(`,\s`)

	colCreates := colDefRe.FindAllString(name[nStart+1:nEnd], -1)
	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0, len(colCreates))
	for _, colStr := range colCreates {
		// Remove the whitespace after commas so that a sized type such as
		// "DECIMAL(10, 2)" stays a single token when split on whitespace.
		colStr = commaSpaceRe.ReplaceAllString(colStr, ",")
		fields := strings.Fields(strings.TrimSpace(colStr))
		if len(fields) == 0 {
			// Blank match: nothing to describe.
			continue
		}
		col := new(core.Column)
		col.Indexes = make(map[string]bool)
		col.Nullable = true
		col.DefaultIsEmpty = true
		for idx, field := range fields {
			if idx == 0 {
				col.Name = strings.Trim(field, "`[] ")
				continue
			} else if idx == 1 {
				col.SQLType = core.SQLType{field, 0, 0}
			}
			switch field {
			case "PRIMARY":
				col.IsPrimaryKey = true
			case "AUTOINCREMENT":
				col.IsAutoIncrement = true
			case "NULL":
				// idx is at least 1 here, so fields[idx-1] is always valid.
				if fields[idx-1] == "NOT" {
					col.Nullable = false
				} else {
					col.Nullable = true
				}
			case "DEFAULT":
				// A dangling DEFAULT with no value is ignored instead of
				// indexing past the end of fields.
				if idx+1 < len(fields) {
					col.Default = fields[idx+1]
					col.DefaultIsEmpty = false
				}
			}
		}
		if !col.SQLType.IsNumeric() && !col.DefaultIsEmpty {
			col.Default = "'" + col.Default + "'"
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	return colSeq, cols, nil
}
Beispiel #11
0
// GetColumns reads the column metadata of tableName from
// INFORMATION_SCHEMA.COLUMNS.
//
// It returns the column names in the order the rows were delivered, a map from
// column name to its description, and an error if the query, a row scan, the
// length conversion or the row iteration fails, or if a column's data type
// does not map to an entry of core.SqlTypes.
//
// A column whose default starts with "nextval" is flagged as primary key and
// gets no default; other defaults are copied verbatim and, for text types,
// wrapped in single quotes.
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{tableName}
	s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
		", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"

	rows, err := db.DB().Query(s, args...)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)

	for rows.Next() {
		col := new(core.Column)
		col.Indexes = make(map[string]bool)

		var colName, isNullable, dataType string
		// Pointers, because these columns may be NULL.
		var maxLenStr, colDefault, numPrecision, numRadix *string
		err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix)
		if err != nil {
			return nil, nil, err
		}

		var maxLen int
		if maxLenStr != nil {
			maxLen, err = strconv.Atoi(*maxLenStr)
			if err != nil {
				return nil, nil, err
			}
		}

		col.Name = strings.Trim(colName, `" `)

		if colDefault != nil {
			if strings.HasPrefix(*colDefault, "nextval") {
				col.IsPrimaryKey = true
			} else {
				col.Default = *colDefault
			}
		}

		col.Nullable = isNullable == "YES"

		switch dataType {
		case "character varying", "character":
			col.SQLType = core.SQLType{core.Varchar, 0, 0}
		case "timestamp without time zone":
			col.SQLType = core.SQLType{core.DateTime, 0, 0}
		case "timestamp with time zone":
			col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
		case "double precision":
			col.SQLType = core.SQLType{core.Double, 0, 0}
		case "boolean":
			col.SQLType = core.SQLType{core.Bool, 0, 0}
		case "time without time zone":
			col.SQLType = core.SQLType{core.Time, 0, 0}
		default:
			col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
		}
		if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
			return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
		}

		col.Length = maxLen

		if col.SQLType.IsText() {
			if col.Default != "" {
				col.Default = "'" + col.Default + "'"
			} else {
				if col.DefaultIsEmpty {
					col.Default = "''"
				}
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}

	return colSeq, cols, nil
}
Beispiel #12
0
// mapType builds the table description for the struct value v.
//
// The table name comes from v's TableName method (looked up on the value and,
// if addressable, on its address) and otherwise from engine.TableMapper. Each
// struct field then becomes a column:
//
//   - A field without an ORM tag (read under engine.TagIdentifier) gets a
//     column named by engine.ColumnMapper with the SQL type derived from the
//     Go type, or TEXT when the field implements core.Conversion.
//   - A tagged field is configured token by token: "-" skips the field,
//     "extends" merges the columns of an embedded struct (or pointer to
//     struct), and the remaining tokens set mapping direction, keys,
//     nullability, defaults, time zone, indexes, cache flags, SQL type and
//     column name.
//
// If no primary key was declared, an int64 field named "ID" (case-insensitive,
// also when reached through an extended struct) becomes an auto-increment
// primary key. "cache"/"nocache" tags select the table's cacher.
func (engine *Engine) mapType(v reflect.Value) *core.Table {
	t := v.Type()
	table := engine.newTable()
	if tb, ok := v.Interface().(TableName); ok {
		table.Name = tb.TableName()
	} else {
		if v.CanAddr() {
			if tb, ok = v.Addr().Interface().(TableName); ok {
				table.Name = tb.TableName()
			}
		}
		if table.Name == "" {
			table.Name = engine.TableMapper.Obj2Table(t.Name())
		}
	}

	table.Type = t

	var idFieldColName string
	var err error
	var hasCacheTag, hasNoCacheTag bool

	for i := 0; i < t.NumField(); i++ {
		tag := t.Field(i).Tag

		ormTagStr := tag.Get(engine.TagIdentifier)
		var col *core.Column
		fieldValue := v.Field(i)
		fieldType := fieldValue.Type()

		if ormTagStr != "" {
			col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
				IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)}
			tags := splitTag(ormTagStr)

			if len(tags) > 0 {
				if tags[0] == "-" {
					continue
				}
				if strings.ToUpper(tags[0]) == "EXTENDS" {
					switch fieldValue.Kind() {
					case reflect.Ptr:
						f := fieldValue.Type().Elem()
						if f.Kind() == reflect.Struct {
							fieldPtr := fieldValue
							fieldValue = fieldValue.Elem()
							// A nil pointer is mapped through a fresh zero value.
							if !fieldValue.IsValid() || fieldPtr.IsNil() {
								fieldValue = reflect.New(f).Elem()
							}
						}
						fallthrough
					case reflect.Struct:
						parentTable := engine.mapType(fieldValue)
						for _, col := range parentTable.Columns() {
							col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
							table.AddColumn(col)
							for indexName, indexType := range col.Indexes {
								addIndex(indexName, table, col, indexType)
							}
						}
						continue
					default:
						//TODO: warning
					}
				}

				indexNames := make(map[string]int)
				var isIndex, isUnique bool
				var preKey string
				for j, key := range tags {
					k := strings.ToUpper(key)
					switch {
					case k == "<-":
						col.MapType = core.ONLYFROMDB
					case k == "->":
						col.MapType = core.ONLYTODB
					case k == "PK":
						col.IsPrimaryKey = true
						col.Nullable = false
					case k == "NULL":
						if j == 0 {
							col.Nullable = true
						} else {
							col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT")
						}
					// TODO: for postgres how add autoincr?
					/*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"):
					col.IsAutoIncrement = true

					autoStart := k[len("AUTOINCR")+1 : len(k)-1]
					autoStartInt, err := strconv.Atoi(autoStart)
					if err != nil {
						engine.LogError(err)
					}
					col.AutoIncrStart = autoStartInt*/
					case k == "AUTOINCR":
						col.IsAutoIncrement = true
						//col.AutoIncrStart = 1
					case k == "DEFAULT":
						// The value is the next token; a trailing DEFAULT with
						// no value is ignored instead of indexing out of range.
						if j+1 < len(tags) {
							col.Default = tags[j+1]
						}
					case k == "CREATED":
						col.IsCreated = true
					case k == "VERSION":
						col.IsVersion = true
						col.Default = "1"
					case k == "UTC":
						col.TimeZone = time.UTC
					case k == "LOCAL":
						col.TimeZone = time.Local
					case strings.HasPrefix(k, "LOCALE(") && strings.HasSuffix(k, ")"):
						// Take the zone name from the original token: slicing
						// with the length of "INDEX" left a leading "(" in the
						// name, and the upper-cased copy does not match
						// case-sensitive zone names such as "Asia/Shanghai".
						location := key[len("LOCALE(") : len(key)-1]
						col.TimeZone, err = time.LoadLocation(location)
						if err != nil {
							engine.logger.Error(err)
						}
					case k == "UPDATED":
						col.IsUpdated = true
					case k == "DELETED":
						col.IsDeleted = true
					case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
						indexName := k[len("INDEX")+1 : len(k)-1]
						indexNames[indexName] = core.IndexType
					case k == "INDEX":
						isIndex = true
					case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
						indexName := k[len("UNIQUE")+1 : len(k)-1]
						indexNames[indexName] = core.UniqueType
					case k == "UNIQUE":
						isUnique = true
					case k == "NOTNULL":
						col.Nullable = false
					case k == "CACHE":
						if !hasCacheTag {
							hasCacheTag = true
						}
					case k == "NOCACHE":
						if !hasNoCacheTag {
							hasNoCacheTag = true
						}
					case k == "NOT":
					default:
						if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") {
							// A quoted token is the column name unless it is
							// the value that follows DEFAULT.
							if preKey != "DEFAULT" {
								col.Name = key[1 : len(key)-1]
							}
						} else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") {
							fs := strings.Split(k, "(")

							if _, ok := core.SqlTypes[fs[0]]; !ok {
								preKey = k
								continue
							}
							col.SQLType = core.SQLType{Name: fs[0]}
							if fs[0] == core.Enum && fs[1][0] == '\'' { //enum
								options := strings.Split(fs[1][0:len(fs[1])-1], ",")
								col.EnumOptions = make(map[string]int)
								for k, v := range options {
									v = strings.TrimSpace(v)
									v = strings.Trim(v, "'")
									col.EnumOptions[v] = k
								}
							} else if fs[0] == core.Set && fs[1][0] == '\'' { //set
								options := strings.Split(fs[1][0:len(fs[1])-1], ",")
								col.SetOptions = make(map[string]int)
								for k, v := range options {
									v = strings.TrimSpace(v)
									v = strings.Trim(v, "'")
									col.SetOptions[v] = k
								}
							} else {
								fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",")
								if len(fs2) == 2 {
									col.Length, err = strconv.Atoi(fs2[0])
									if err != nil {
										engine.logger.Error(err)
									}
									col.Length2, err = strconv.Atoi(fs2[1])
									if err != nil {
										engine.logger.Error(err)
									}
								} else if len(fs2) == 1 {
									col.Length, err = strconv.Atoi(fs2[0])
									if err != nil {
										engine.logger.Error(err)
									}
								}
							}
						} else {
							if _, ok := core.SqlTypes[k]; ok {
								col.SQLType = core.SQLType{Name: k}
							} else if key != col.Default {
								col.Name = key
							}
						}
						// Called for its side effects on col; the returned
						// type string is not needed here.
						engine.dialect.SqlType(col)
					}
					preKey = k
				}
				if col.SQLType.Name == "" {
					col.SQLType = core.Type2SQLType(fieldType)
				}
				if col.Length == 0 {
					col.Length = col.SQLType.DefaultLength
				}
				if col.Length2 == 0 {
					col.Length2 = col.SQLType.DefaultLength2
				}

				if col.Name == "" {
					col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name)
				}

				if isUnique {
					indexNames[col.Name] = core.UniqueType
				} else if isIndex {
					indexNames[col.Name] = core.IndexType
				}

				for indexName, indexType := range indexNames {
					addIndex(indexName, table, col, indexType)
				}
			}
		} else {
			// A field is stored as TEXT when it implements core.Conversion,
			// either through its address (pointer receiver) or directly. The
			// two checks are combined so that a hit on the address is not
			// overwritten by a miss on the value.
			isConversion := false
			if fieldValue.CanAddr() {
				_, isConversion = fieldValue.Addr().Interface().(core.Conversion)
			}
			if !isConversion {
				_, isConversion = fieldValue.Interface().(core.Conversion)
			}
			var sqlType core.SQLType
			if isConversion {
				sqlType = core.SQLType{Name: core.Text}
			} else {
				sqlType = core.Type2SQLType(fieldType)
			}
			col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
				t.Field(i).Name, sqlType, sqlType.DefaultLength,
				sqlType.DefaultLength2, true)
		}
		if col.IsAutoIncrement {
			col.Nullable = false
		}

		table.AddColumn(col)

		if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
			idFieldColName = col.Name
		}
	} // end for

	// Fall back to the conventional ID field when no primary key was declared.
	if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
		col := table.GetColumn(idFieldColName)
		col.IsPrimaryKey = true
		col.IsAutoIncrement = true
		col.Nullable = false
		table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
		table.AutoIncrement = col.Name
	}

	if hasCacheTag {
		if engine.Cacher != nil { // !nash! use engine's cacher if provided
			engine.logger.Info("enable cache on table:", table.Name)
			table.Cacher = engine.Cacher
		} else {
			engine.logger.Info("enable LRU cache on table:", table.Name)
			table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now
		}
	}
	// "nocache" wins over "cache" because it is applied last.
	if hasNoCacheTag {
		engine.logger.Info("no cache on table:", table.Name)
		table.Cacher = nil
	}

	return table
}
Beispiel #13
0
// GetColumns reads the column metadata of tableName in the "public" schema
// from the pg_* catalogs joined with INFORMATION_SCHEMA.COLUMNS.
//
// It returns the column names in the order the rows were delivered (the query
// orders by attribute number), a map from column name to its description, and
// an error if the query, a row scan, the length conversion or the row
// iteration fails, or if a column's data type does not map to an entry of
// core.SqlTypes.
//
// Primary-key rows set IsPrimaryKey and carry no default; a default starting
// with "nextval(" marks the column as auto-increment. Text and time defaults
// are wrapped in single quotes.
//
// NOTE(review): the LEFT JOIN on pg_constraint looks like it can yield one row
// per constraint touching a column, in which case a later row would replace
// the earlier entry and the name would appear twice in the returned slice.
// Confirm against a table with several constraints on one column.
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	pgSchema := "public"
	args := []interface{}{tableName, pgSchema}
	s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
    CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
    CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
FROM pg_attribute f
    JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
    LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
    LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
    LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
    LEFT JOIN pg_class AS g ON p.confrelid = g.oid
    LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`

	rows, err := db.DB().Query(s, args...)
	if db.Logger != nil {
		db.Logger.Info("[sql]", s, args)
	}
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)

	for rows.Next() {
		col := new(core.Column)
		col.Indexes = make(map[string]bool)

		var colName, isNullable, dataType string
		// Pointers, because these columns may be NULL.
		var maxLenStr, colDefault, numPrecision, numRadix *string
		var isPK, isUnique bool
		err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix, &isPK, &isUnique)
		if err != nil {
			return nil, nil, err
		}

		//fmt.Println(args, colName, isNullable, dataType, maxLenStr, colDefault, numPrecision, numRadix, isPK, isUnique)
		var maxLen int
		if maxLenStr != nil {
			maxLen, err = strconv.Atoi(*maxLenStr)
			if err != nil {
				return nil, nil, err
			}
		}

		col.Name = strings.Trim(colName, `" `)

		// A primary-key row never records a default; otherwise colDefault is
		// known to be non-nil inside the else branch.
		if colDefault != nil || isPK {
			if isPK {
				col.IsPrimaryKey = true
			} else {
				col.Default = *colDefault
			}
		}

		if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") {
			col.IsAutoIncrement = true
		}

		col.Nullable = (isNullable == "YES")

		switch dataType {
		case "character varying", "character":
			col.SQLType = core.SQLType{core.Varchar, 0, 0}
		case "timestamp without time zone":
			col.SQLType = core.SQLType{core.DateTime, 0, 0}
		case "timestamp with time zone":
			col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
		case "double precision":
			col.SQLType = core.SQLType{core.Double, 0, 0}
		case "boolean":
			col.SQLType = core.SQLType{core.Bool, 0, 0}
		case "time without time zone":
			col.SQLType = core.SQLType{core.Time, 0, 0}
		default:
			col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
		}
		if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
			return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
		}

		col.Length = maxLen

		if col.SQLType.IsText() || col.SQLType.IsTime() {
			if col.Default != "" {
				col.Default = "'" + col.Default + "'"
			} else {
				if col.DefaultIsEmpty {
					col.Default = "''"
				}
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}

	return colSeq, cols, nil
}
Beispiel #14
0
// GetColumns reads the column metadata of tableName (upper-cased for the
// lookup) from USER_TAB_COLUMNS.
//
// It returns the column names in the order the rows were delivered, a map from
// column name to its description, and an error if the query, a row scan or
// the row iteration fails, or if a column's data type does not map to an
// entry of core.SqlTypes.
//
// Defaults of text columns are wrapped in single quotes.
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{strings.ToUpper(tableName)}
	s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
		"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"

	rows, err := db.DB().Query(s, args...)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)
	for rows.Next() {
		col := new(core.Column)
		col.Indexes = make(map[string]bool)

		var colName, nullable, dataType string
		// data_default, data_precision and data_scale are NULL for many
		// columns; scanning NULL into a plain string makes Scan fail, so these
		// are read through pointers.
		var colDefault, dataPrecision, dataScale *string
		var dataLen int

		err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
			&dataScale, &nullable)
		if err != nil {
			return nil, nil, err
		}

		col.Name = strings.Trim(colName, `" `)
		if colDefault != nil {
			col.Default = *colDefault
		}

		col.Nullable = nullable == "Y"

		switch dataType {
		case "VARCHAR2":
			col.SQLType = core.SQLType{core.Varchar, 0, 0}
		case "TIMESTAMP WITH TIME ZONE":
			col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
		default:
			col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
		}
		if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
			return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
		}

		col.Length = dataLen

		if col.SQLType.IsText() {
			if col.Default != "" {
				col.Default = "'" + col.Default + "'"
			} else {
				if col.DefaultIsEmpty {
					col.Default = "''"
				}
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, nil, err
	}

	return colSeq, cols, nil
}
Beispiel #15
0
// mapType builds the table description for the struct value v.
//
// The table name comes from a TableName method found by reflection on the
// value or, if addressable, on its address; the method is used only when it
// takes no arguments and returns a single string. Otherwise the name is
// derived with engine.TableMapper. Each struct field then becomes a column:
//
//   - A field without an ORM tag (read under engine.TagIdentifier) gets a
//     column named by engine.ColumnMapper with the SQL type derived from the
//     Go type, or TEXT when the field implements core.Conversion.
//   - A tagged field is configured token by token (tokens are separated by
//     single spaces): "-" skips the field, "extends" merges the columns of an
//     embedded struct or pointer to struct, and the remaining tokens set
//     mapping direction, keys, nullability, defaults, indexes, SQL type and
//     column name.
//
// If no primary key was declared, an int64 field named "Id" (also when reached
// through an extended struct) becomes an auto-increment primary key.
func (engine *Engine) mapType(v reflect.Value) *core.Table {
	t := v.Type()
	table := engine.newTable()
	method := v.MethodByName("TableName")
	if !method.IsValid() && v.CanAddr() {
		method = v.Addr().MethodByName("TableName")
	}
	// Calling a method that expects arguments, or asserting a non-string
	// result, would panic; such methods are ignored instead.
	if method.IsValid() && method.Type().NumIn() == 0 {
		results := method.Call(nil)
		if len(results) == 1 {
			if name, ok := results[0].Interface().(string); ok {
				table.Name = name
			}
		}
	}

	if table.Name == "" {
		table.Name = engine.TableMapper.Obj2Table(t.Name())
	}
	table.Type = t

	var idFieldColName string
	var err error

	for i := 0; i < t.NumField(); i++ {
		tag := t.Field(i).Tag

		ormTagStr := tag.Get(engine.TagIdentifier)
		var col *core.Column
		fieldValue := v.Field(i)
		fieldType := fieldValue.Type()

		if ormTagStr != "" {
			col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
				IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)}
			tags := strings.Split(ormTagStr, " ")

			if len(tags) > 0 {
				if tags[0] == "-" {
					continue
				}
				if strings.ToUpper(tags[0]) == "EXTENDS" {
					if fieldValue.Kind() == reflect.Struct {
						parentTable := engine.mapType(fieldValue)
						for _, col := range parentTable.Columns() {
							col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
							table.AddColumn(col)
						}

						continue
					} else if fieldValue.Kind() == reflect.Ptr {
						f := fieldValue.Type().Elem()
						if f.Kind() == reflect.Struct {
							// Test the pointer itself for nil: IsNil on the
							// dereferenced struct value panics.
							if fieldValue.IsNil() {
								fieldValue = reflect.New(f).Elem()
							} else {
								fieldValue = fieldValue.Elem()
							}
						}

						parentTable := engine.mapType(fieldValue)
						for _, col := range parentTable.Columns() {
							col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
							table.AddColumn(col)
						}

						continue
					}
					//TODO: warning
				}

				indexNames := make(map[string]int)
				var isIndex, isUnique bool
				var preKey string
				for j, key := range tags {
					k := strings.ToUpper(key)
					switch {
					case k == "<-":
						col.MapType = core.ONLYFROMDB
					case k == "->":
						col.MapType = core.ONLYTODB
					case k == "PK":
						col.IsPrimaryKey = true
						col.Nullable = false
					case k == "NULL":
						// A leading NULL has no predecessor to inspect.
						if j == 0 {
							col.Nullable = true
						} else {
							col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT")
						}
					/*case strings.HasPrefix(k, "AUTOINCR(") && strings.HasSuffix(k, ")"):
					col.IsAutoIncrement = true

					autoStart := k[len("AUTOINCR")+1 : len(k)-1]
					autoStartInt, err := strconv.Atoi(autoStart)
					if err != nil {
						engine.LogError(err)
					}
					col.AutoIncrStart = autoStartInt*/
					case k == "AUTOINCR":
						col.IsAutoIncrement = true
						//col.AutoIncrStart = 1
					case k == "DEFAULT":
						// The value is the next token; a trailing DEFAULT with
						// no value is ignored instead of indexing out of range.
						if j+1 < len(tags) {
							col.Default = tags[j+1]
						}
					case k == "CREATED":
						col.IsCreated = true
					case k == "VERSION":
						col.IsVersion = true
						col.Default = "1"
					case k == "UPDATED":
						col.IsUpdated = true
					case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
						indexName := k[len("INDEX")+1 : len(k)-1]
						indexNames[indexName] = core.IndexType
					case k == "INDEX":
						isIndex = true
					case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
						indexName := k[len("UNIQUE")+1 : len(k)-1]
						indexNames[indexName] = core.UniqueType
					case k == "UNIQUE":
						isUnique = true
					case k == "NOTNULL":
						col.Nullable = false
					case k == "NOT":
					default:
						if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") {
							// A quoted token is the column name unless it is
							// the value that follows DEFAULT.
							if preKey != "DEFAULT" {
								col.Name = key[1 : len(key)-1]
							}
						} else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") {
							fs := strings.Split(k, "(")

							if _, ok := core.SqlTypes[fs[0]]; !ok {
								preKey = k
								continue
							}
							col.SQLType = core.SQLType{fs[0], 0, 0}
							if fs[0] == core.Enum && fs[1][0] == '\'' { //enum
								options := strings.Split(fs[1][0:len(fs[1])-1], ",")
								col.EnumOptions = make(map[string]int)
								for k, v := range options {
									v = strings.TrimSpace(v)
									v = strings.Trim(v, "'")
									col.EnumOptions[v] = k
								}
							} else {
								fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",")
								if len(fs2) == 2 {
									col.Length, err = strconv.Atoi(fs2[0])
									if err != nil {
										engine.LogError(err)
									}
									col.Length2, err = strconv.Atoi(fs2[1])
									if err != nil {
										engine.LogError(err)
									}
								} else if len(fs2) == 1 {
									col.Length, err = strconv.Atoi(fs2[0])
									if err != nil {
										engine.LogError(err)
									}
								}
							}
						} else {
							if _, ok := core.SqlTypes[k]; ok {
								col.SQLType = core.SQLType{k, 0, 0}
							} else if key != col.Default {
								col.Name = key
							}
						}
						// Called for its side effects on col; the returned
						// type string is not needed here.
						engine.dialect.SqlType(col)
					}
					preKey = k
				}
				if col.SQLType.Name == "" {
					col.SQLType = core.Type2SQLType(fieldType)
				}
				if col.Length == 0 {
					col.Length = col.SQLType.DefaultLength
				}
				if col.Length2 == 0 {
					col.Length2 = col.SQLType.DefaultLength2
				}

				if col.Name == "" {
					col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name)
				}

				if isUnique {
					indexNames[col.Name] = core.UniqueType
				} else if isIndex {
					indexNames[col.Name] = core.IndexType
				}

				for indexName, indexType := range indexNames {
					addIndex(indexName, table, col, indexType)
				}
			}
		} else {
			// A field is stored as TEXT when it implements core.Conversion,
			// either through its address (pointer receiver) or directly. The
			// two checks are combined so that a hit on the address is not
			// overwritten by a miss on the value.
			isConversion := false
			if fieldValue.CanAddr() {
				_, isConversion = fieldValue.Addr().Interface().(core.Conversion)
			}
			if !isConversion {
				_, isConversion = fieldValue.Interface().(core.Conversion)
			}
			var sqlType core.SQLType
			if isConversion {
				sqlType = core.SQLType{core.Text, 0, 0}
			} else {
				sqlType = core.Type2SQLType(fieldType)
			}
			col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
				t.Field(i).Name, sqlType, sqlType.DefaultLength,
				sqlType.DefaultLength2, true)
		}
		if col.IsAutoIncrement {
			col.Nullable = false
		}

		table.AddColumn(col)

		if fieldType.Kind() == reflect.Int64 && (col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id")) {
			idFieldColName = col.Name
		}
	}

	// Fall back to the conventional Id field when no primary key was declared.
	if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
		col := table.GetColumn(idFieldColName)
		col.IsPrimaryKey = true
		col.IsAutoIncrement = true
		col.Nullable = false
		table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
		table.AutoIncrement = col.Name
	}

	return table
}
Beispiel #16
0
// GetColumns reads the column metadata of tableName from USER_TAB_COLUMNS.
//
// It returns the column names in the order the query produced them, a map
// from column name to its description, and any error raised while querying,
// scanning or iterating. Columns whose data type is AQ$_SUBSCRIBERS are
// skipped; any other data type that is not registered in core.SqlTypes aborts
// the call with an error.
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{tableName}
	s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
		"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"

	rows, err := db.DB().Query(s, args...)
	if db.Logger != nil {
		db.Logger.Info("[sql]", s, args)
	}
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)
	for rows.Next() {
		col := new(core.Column)
		col.Indexes = make(map[string]bool)

		// Pointer destinations let NULL values scan as nil instead of failing.
		var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
		var dataLen int

		err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
			&dataScale, &nullable)
		if err != nil {
			return nil, nil, err
		}

		// The name and data type are dereferenced below; report a NULL as an
		// error rather than panicking on a nil pointer.
		if colName == nil || dataType == nil {
			return nil, nil, fmt.Errorf("column metadata of table %v has a NULL name or data type", tableName)
		}

		col.Name = strings.Trim(*colName, `" `)
		if colDefault != nil {
			col.Default = *colDefault
			col.DefaultIsEmpty = false
		}

		// A NULL nullable flag is treated the same as any value other than "Y".
		col.Nullable = nullable != nil && *nullable == "Y"

		// Split "TYPE(len1,len2)..." into the bare type name and its lengths.
		var len1, len2 int
		dts := strings.Split(*dataType, "(")
		dt := dts[0]
		if len(dts) > 1 {
			params := dts[1]
			if end := strings.Index(params, ")"); end >= 0 {
				params = params[:end]
			}
			lens := strings.Split(params, ",")
			// Conversion errors are ignored on purpose: a non-numeric part
			// simply leaves the corresponding length at zero.
			len1, _ = strconv.Atoi(strings.TrimSpace(lens[0]))
			if len(lens) > 1 {
				len2, _ = strconv.Atoi(strings.TrimSpace(lens[1]))
			}
		}

		// NOTE(review): dt is only the text before the first "(", so a data
		// type reported as "TIMESTAMP(6) WITH TIME ZONE" would reach the
		// default branch as "TIMESTAMP" rather than the TimeStampz case —
		// confirm against real USER_TAB_COLUMNS output.
		var ignore bool
		switch dt {
		case "VARCHAR2":
			col.SQLType = core.SQLType{core.Varchar, len1, len2}
		case "NVARCHAR2":
			col.SQLType = core.SQLType{core.NVarchar, len1, len2}
		case "TIMESTAMP WITH TIME ZONE":
			col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
		case "NUMBER":
			col.SQLType = core.SQLType{core.Double, len1, len2}
		case "LONG", "LONG RAW":
			col.SQLType = core.SQLType{core.Text, 0, 0}
		case "RAW":
			col.SQLType = core.SQLType{core.Binary, 0, 0}
		case "ROWID":
			col.SQLType = core.SQLType{core.Varchar, 18, 0}
		case "AQ$_SUBSCRIBERS":
			ignore = true
		default:
			col.SQLType = core.SQLType{strings.ToUpper(dt), len1, len2}
		}

		if ignore {
			continue
		}

		if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
			return nil, nil, fmt.Errorf("unkonw colType %v %v", *dataType, col.SQLType)
		}

		col.Length = dataLen

		// NOTE(review): DefaultIsEmpty is never set to true in this function,
		// so text and time columns always get their default wrapped in
		// quotes, including a missing default (which becomes '') — confirm
		// this is intended.
		if col.SQLType.IsText() || col.SQLType.IsTime() {
			if !col.DefaultIsEmpty {
				col.Default = "'" + col.Default + "'"
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next also returns false when iteration fails; surface that error
	// instead of returning a silently truncated column list.
	if err := rows.Err(); err != nil {
		return nil, nil, err
	}

	return colSeq, cols, nil
}
Beispiel #17
0
// GetColumns reads the column metadata of tableName in database db.DbName
// from INFORMATION_SCHEMA.COLUMNS.
//
// It returns the column names in the order the query produced them, a map
// from column name to its description, and any error raised while querying,
// scanning, parsing a length, or iterating. A column type that is not
// registered in core.SqlTypes aborts the call with an error.
func (db *tidb) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
	args := []interface{}{db.DbName, tableName}
	s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
		" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"

	rows, err := db.DB().Query(s, args...)
	db.LogSQL(s, args)

	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	// parseOptions maps each member of a comma-separated, single-quoted
	// ENUM/SET member list to its position in that list.
	parseOptions := func(list string) map[string]int {
		members := strings.Split(list, ",")
		positions := make(map[string]int, len(members))
		for k, v := range members {
			positions[strings.Trim(strings.TrimSpace(v), "'")] = k
		}
		return positions
	}

	cols := make(map[string]*core.Column)
	colSeq := make([]string, 0)
	for rows.Next() {
		col := new(core.Column)
		col.Indexes = make(map[string]int)

		var columnName, isNullable, colType, colKey, extra string
		var colDefault *string
		err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra)
		if err != nil {
			return nil, nil, err
		}
		col.Name = strings.Trim(columnName, "` ")
		if isNullable == "YES" {
			col.Nullable = true
		}

		if colDefault != nil {
			col.Default = *colDefault
			if col.Default == "" {
				col.DefaultIsEmpty = true
			}
		}

		// Split "type(params) modifiers" once at the first "(" so that
		// parentheses inside ENUM/SET members do not produce extra parts.
		cts := strings.SplitN(colType, "(", 2)
		colType = strings.ToUpper(cts[0])
		var len1, len2 int
		if len(cts) == 2 {
			// Keep only what is inside the parentheses. A missing ")" leaves
			// the remainder untouched instead of slicing with a negative index.
			params := cts[1]
			if idx := strings.LastIndex(params, ")"); idx >= 0 {
				params = params[:idx]
			}
			switch {
			case colType == core.Enum && strings.HasPrefix(params, "'"):
				col.EnumOptions = parseOptions(params)
			case colType == core.Set && strings.HasPrefix(params, "'"):
				col.SetOptions = parseOptions(params)
			default:
				lens := strings.Split(params, ",")
				len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
				if err != nil {
					return nil, nil, err
				}
				if len(lens) == 2 {
					len2, err = strconv.Atoi(strings.TrimSpace(lens[1]))
					if err != nil {
						return nil, nil, err
					}
				}
			}
		}
		if colType == "FLOAT UNSIGNED" {
			colType = "FLOAT"
		}
		col.Length = len1
		col.Length2 = len2
		if _, ok := core.SqlTypes[colType]; !ok {
			return nil, nil, fmt.Errorf("unkonw colType %v", colType)
		}
		col.SQLType = core.SQLType{colType, len1, len2}

		// Only the "PRI" column key is recorded; a "UNI" key is not reflected
		// on the column here.
		if colKey == "PRI" {
			col.IsPrimaryKey = true
		}

		if extra == "auto_increment" {
			col.IsAutoIncrement = true
		}

		// Text and time defaults are stored quoted; an explicitly empty
		// default becomes the empty SQL string literal.
		if col.SQLType.IsText() || col.SQLType.IsTime() {
			if col.Default != "" {
				col.Default = "'" + col.Default + "'"
			} else if col.DefaultIsEmpty {
				col.Default = "''"
			}
		}
		cols[col.Name] = col
		colSeq = append(colSeq, col.Name)
	}
	// rows.Next also returns false when iteration fails; surface that error
	// instead of returning a silently truncated column list.
	if err := rows.Err(); err != nil {
		return nil, nil, err
	}
	return colSeq, cols, nil
}
// genAddColumnStr builds the ALTER TABLE ... ADD statement that adds col to
// the statement's table. The table name is quoted by the engine and the
// column definition is rendered for the engine's dialect. The returned
// argument list is always empty.
func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
	quotedTable := s.Engine.Quote(s.TableName())
	columnDef := col.String(s.Engine.dialect)
	return fmt.Sprintf("ALTER TABLE %v ADD %v;", quotedTable, columnDef), []interface{}{}
}
Beispiel #19
0
// UpdateBatch updates many rows with one statement of the form
//
//	UPDATE t SET col = CASE idx WHEN ? THEN ? ... ELSE col END, ... WHERE idx IN (?, ...)
//
// rowsSlicePtr must be a slice, or a pointer to a slice, of beans that all map
// to the same table; the mapping is taken from the first element.
// indexColName names the column whose value identifies each row.
//
// For every other column a row contributes a WHEN arm only when its field is
// non-zero according to isZero. Columns that are zero in every row are left
// out of the statement, and rows without an arm for a column keep their
// current value through the ELSE branch.
//
// It returns the affected-row count reported by the driver. An error is
// returned when the input is not a slice, the slice is empty, the index
// column is unknown, a field value cannot be read or converted, or no column
// has anything to update.
func (this *databaseImplement) UpdateBatch(rowsSlicePtr interface{}, indexColName string) (int64, error) {
	sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
	if sliceValue.Kind() != reflect.Slice {
		return 0, errors.New("needs a pointer to a slice")
	}
	size := sliceValue.Len()
	if size == 0 {
		return 0, errors.New("update rows is empty")
	}

	bean := sliceValue.Index(0).Interface()
	elementValue := this.rValue(bean)
	table := this.autoMapType(elementValue)

	// Split the mapped columns into the index column and the update candidates.
	var indexCol *core.Column
	cols := make([]*core.Column, 0)
	for _, col := range table.Columns() {
		if col.Name == indexColName {
			indexCol = col
		} else {
			cols = append(cols, col)
		}
	}
	if indexCol == nil {
		return 0, errors.New("counld not found index col " + indexColName)
	}
	// updateCols[i] reports whether cols[i] has a value to write in any row.
	updateCols := make([]bool, len(cols))

	// Collect the values to write and the index value of every row.
	rows := make([][]interface{}, 0, size)
	indexRow := make([]interface{}, 0, size)
	for i := 0; i < size; i++ {
		vv := reflect.Indirect(sliceValue.Index(i))

		singleRow := make([]interface{}, 0, len(cols))
		for colIndex, col := range cols {
			ptrFieldValue, err := col.ValueOfV(&vv)
			if err != nil {
				return 0, err
			}
			fieldValue := *ptrFieldValue
			// A nil entry marks "leave this column alone for this row".
			var arg interface{}
			if !this.isZero(fieldValue.Interface()) {
				arg, err = this.value2Interface(fieldValue)
				if err != nil {
					return 0, err
				}
				// Only count the column when a WHEN arm will really be
				// emitted; a CASE without any WHEN is not valid SQL.
				if arg != nil {
					updateCols[colIndex] = true
				}
			}
			singleRow = append(singleRow, arg)
		}
		rows = append(rows, singleRow)

		ptrFieldValue, err := indexCol.ValueOfV(&vv)
		if err != nil {
			return 0, err
		}
		arg, err := this.value2Interface(*ptrFieldValue)
		if err != nil {
			return 0, err
		}
		indexRow = append(indexRow, arg)
	}

	// Without at least one column to set, the statement would read
	// "UPDATE t SET  WHERE ...", which is not valid SQL.
	hasUpdate := false
	for _, marked := range updateCols {
		if marked {
			hasUpdate = true
			break
		}
	}
	if !hasUpdate {
		return 0, errors.New("update cols is empty! " + fmt.Sprintf("%v", rowsSlicePtr))
	}

	// Assemble the SQL in one growing buffer instead of repeated string
	// concatenation inside nested loops.
	quote := this.Engine.QuoteStr()
	quotedIndex := quote + indexCol.Name + quote
	sql := make([]byte, 0, 256)
	write := func(part string) { sql = append(sql, part...) }
	sqlArgs := make([]interface{}, 0)

	write("UPDATE " + table.Name + " SET ")
	isFirstUpdateCol := true
	for colIndex, col := range cols {
		if !updateCols[colIndex] {
			continue
		}
		if !isFirstUpdateCol {
			write(" , ")
		}
		isFirstUpdateCol = false

		quotedCol := quote + col.Name + quote
		write(quotedCol + " = CASE " + quotedIndex)
		for rowIndex, row := range rows {
			if row[colIndex] == nil {
				continue
			}
			write(" WHEN ? THEN ? ")
			sqlArgs = append(sqlArgs, indexRow[rowIndex], row[colIndex])
		}
		// A CASE without ELSE evaluates to NULL, which would wipe this column
		// for every selected row that has no WHEN arm. Fall back to the
		// column's current value instead.
		write(" ELSE " + quotedCol + " END ")
	}

	write(" WHERE " + quotedIndex + " IN ( ")
	for rowIndex, indexValue := range indexRow {
		if rowIndex != 0 {
			write(" , ")
		}
		write(" ? ")
		sqlArgs = append(sqlArgs, indexValue)
	}
	write(" ) ")

	// Execute the statement.
	res, err := this.Exec(string(sql), sqlArgs...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}