예제 #1
0
// handleAdmin executes an "admin <region>(columns) values(...)" statement.
//
// The region name is lowercased once, and that normalized form is used both
// for the column-order check and for dispatch. The dispatch already compared
// the lowercased name against the region constants, and the order check must
// agree with it. Node commands reply with an OK packet. Server commands may
// return a result set, which is written to the client when present. Unknown
// regions and command failures are returned as errors; failures are also
// logged together with the SQL text.
func (c *ClientConn) handleAdmin(admin *sqlparser.Admin) error {
	var result *mysql.Resultset

	region := sqlparser.String(admin.Region)
	// Normalize once so checkCmdOrder and the switch below see the same name.
	lowerRegion := strings.ToLower(region)

	if err := c.checkCmdOrder(lowerRegion, admin.Columns); err != nil {
		return err
	}

	var err error
	switch lowerRegion {
	case NodeRegion:
		err = c.handleNodeCmd(admin.Rows)
	case ServerRegion:
		result, err = c.handleServerCmd(admin.Rows)
	default:
		return fmt.Errorf("admin %s not supported now", region)
	}

	if err != nil {
		golog.Error("ClientConn", "handleAdmin", err.Error(),
			c.connectionId, "sql", sqlparser.String(admin))
		return err
	}

	if result != nil {
		return c.writeResultset(c.status, result)
	}
	return c.writeOK(nil)
}
예제 #2
0
// handleSetNames handles "SET NAMES <charset> [COLLATE <collation>]".
//
// ch is the charset expression and ci the optional collation expression.
// Surrounding quotes and backticks are stripped from both, and both are
// lowercased. A charset of "null" is accepted as a no-op. Without COLLATE, the
// collation id is taken from mysql.CharsetIds[charset]. With COLLATE, it is
// taken from mysql.CollationNames[collation]. On success the connection's
// charset and collation are updated and an OK packet is written. An unknown
// name is returned as an error.
func (c *ClientConn) handleSetNames(ch, ci sqlparser.ValExpr) error {
	value := strings.Trim(sqlparser.String(ch), "'`\"")

	charset := strings.ToLower(value)
	if charset == "null" {
		return c.writeOK(nil)
	}

	var cid mysql.CollationId
	var ok bool
	if ci == nil {
		cid, ok = mysql.CharsetIds[charset]
		if !ok {
			return fmt.Errorf("invalid charset %s", charset)
		}
	} else {
		// Look the collation up by its own unquoted name. MySQL collation names
		// are case-insensitive; this assumes CollationNames keys are lowercase.
		collate := strings.ToLower(strings.Trim(sqlparser.String(ci), "'`\""))
		cid, ok = mysql.CollationNames[collate]
		if !ok {
			return fmt.Errorf("invalid collation %s", collate)
		}
	}
	c.charset = charset
	c.collation = cid

	return c.writeOK(nil)
}
예제 #3
0
// handleNodeCmd executes an admin node command.
//
// rows must be a VALUES list. Its first tuple holds, in cmdNodeOrder order,
// the operation, node name, role and address, each stripped of single quotes.
// The operation is lowercased and dispatched to AddDatabase, DeleteDatabase,
// UpDatabase or DownDatabase. Malformed input (not a VALUES list, no rows, a
// row that is not a value tuple, or the wrong arity) returns
// errors.ErrCmdUnsupport instead of panicking. An unknown operation is logged
// and also returns errors.ErrCmdUnsupport.
func (c *ClientConn) handleNodeCmd(rows sqlparser.InsertRows) error {
	vals, ok := rows.(sqlparser.Values)
	if !ok || len(vals) == 0 {
		return errors.ErrCmdUnsupport
	}

	tuple, ok := vals[0].(sqlparser.ValTuple)
	if !ok || len(tuple) != len(cmdNodeOrder) {
		return errors.ErrCmdUnsupport
	}

	// unquote renders the i-th tuple value and strips SQL single quotes.
	unquote := func(i int) string {
		return strings.Trim(sqlparser.String(tuple[i]), "'")
	}
	opt := unquote(0)
	nodeName := unquote(1)
	role := unquote(2)
	addr := unquote(3)

	var err error
	switch strings.ToLower(opt) {
	case ADMIN_OPT_ADD:
		err = c.AddDatabase(nodeName, role, addr)
	case ADMIN_OPT_DEL:
		err = c.DeleteDatabase(nodeName, role, addr)
	case ADMIN_OPT_UP:
		err = c.UpDatabase(nodeName, role, addr)
	case ADMIN_OPT_DOWN:
		err = c.DownDatabase(nodeName, role, addr)
	default:
		err = errors.ErrCmdUnsupport
		golog.Error("ClientConn", "handleNodeCmd", err.Error(),
			c.connectionId, "opt", opt)
	}
	return err
}
예제 #4
0
파일: router.go 프로젝트: jin06/kingshard
// GetPlan builds a routing plan for statement.
//
// INSERT and REPLACE are routed on the VALUES they carry. INSERT ... SELECT and
// REPLACE ... SELECT panic with a sqlparser ParserError. SELECT, UPDATE and
// DELETE are routed on their WHERE expression. When there is no WHERE clause,
// or the statement is of any other type, the router's DefaultRule is used
// instead. In every case fullList covers all node indexes of the chosen rule.
func (r *Router) GetPlan(statement sqlparser.Statement) (plan *Plan) {
	plan = &Plan{}

	// finish fills in the node list for whichever rule has been chosen.
	finish := func() *Plan {
		plan.fullList = makeList(0, len(plan.rule.Nodes))
		return plan
	}

	var cond *sqlparser.Where
	// Every Statement implementation uses pointer receivers, so the cases
	// below match pointer types.
	switch s := statement.(type) {
	case *sqlparser.Insert:
		if _, isSelect := s.Rows.(sqlparser.SelectStatement); isSelect {
			panic(sqlparser.NewParserError("select in insert not allowed"))
		}
		// The sharding rule is looked up by the statement's table name.
		plan.rule = r.GetRule(sqlparser.String(s.Table))
		if s.OnDup != nil {
			plan.rule.checkUpdateExprs(sqlparser.UpdateExprs(s.OnDup))
		}
		plan.criteria = plan.routingAnalyzeValues(s.Rows.(sqlparser.Values))
		return finish()

	case *sqlparser.Replace:
		if _, isSelect := s.Rows.(sqlparser.SelectStatement); isSelect {
			panic(sqlparser.NewParserError("select in replace not allowed"))
		}
		plan.rule = r.GetRule(sqlparser.String(s.Table))
		plan.criteria = plan.routingAnalyzeValues(s.Rows.(sqlparser.Values))
		return finish()

	case *sqlparser.Select:
		// The sharding rule is looked up by the first FROM entry's text.
		plan.rule = r.GetRule(sqlparser.String(s.From[0]))
		cond = s.Where

	case *sqlparser.Update:
		plan.rule = r.GetRule(sqlparser.String(s.Table))
		plan.rule.checkUpdateExprs(s.Exprs)
		cond = s.Where

	case *sqlparser.Delete:
		plan.rule = r.GetRule(sqlparser.String(s.Table))
		cond = s.Where
	}

	if cond == nil {
		// Nothing to route on: fall back to the default rule.
		plan.rule = r.DefaultRule
	} else {
		// The WHERE expression is the routing condition.
		plan.criteria = cond.Expr
	}
	return finish()
}
예제 #5
0
// buildSelectPlan routes a SELECT statement.
//
// The sharding rule is chosen from the first FROM entry, using:
//   - the table of an aliased table expression;
//   - the left table of a join, when that is an aliased table expression;
//   - otherwise, the entry's full text.
//
// With a WHERE clause the route is computed from it. Without one, the query is
// sent to every node and every sub-table. A sharded rule that ends up with no
// target table is rejected with errors.ErrNoCriteria. Finally the per-shard SQL
// is generated. An empty RouteTableIndexs means the table is not split, and
// the query goes to the default node.
func (r *Router) buildSelectPlan(statement sqlparser.Statement) (*Plan, error) {
	sel := statement.(*sqlparser.Select)

	var tableName string
	switch from := sel.From[0].(type) {
	case *sqlparser.AliasedTableExpr:
		tableName = sqlparser.String(from.Expr)
	case *sqlparser.JoinTableExpr:
		if left, ok := from.LeftExpr.(*sqlparser.AliasedTableExpr); ok {
			tableName = sqlparser.String(left.Expr)
		} else {
			tableName = sqlparser.String(from)
		}
	default:
		tableName = sqlparser.String(from)
	}

	plan := &Plan{Rule: r.GetRule(tableName)}
	plan.TableIndexs = makeList(0, len(plan.Rule.TableToNode))

	if where := sel.Where; where == nil {
		// No WHERE: fan out to every node and every sub-table.
		plan.RouteTableIndexs = plan.TableIndexs
		plan.RouteNodeIndexs = makeList(0, len(plan.Rule.Nodes))
	} else {
		plan.Criteria = where.Expr // routing condition
		if err := plan.calRouteIndexs(); err != nil {
			golog.Error("Route", "BuildSelectPlan", err.Error(), 0)
			return nil, err
		}
	}

	if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 {
		golog.Error("Route", "BuildSelectPlan", errors.ErrNoCriteria.Error(), 0)
		return nil, errors.ErrNoCriteria
	}

	if err := r.generateSelectSql(plan, sel); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #6
0
파일: router.go 프로젝트: velsai/kingshard
// buildInsertPlan routes an INSERT ... VALUES statement.
//
// INSERT ... SELECT is rejected with errors.ErrSelectInInsert. The sharding
// rule comes from the target table. An ON DUPLICATE KEY UPDATE clause is
// checked with Rule.checkUpdateExprs. Routing is computed from the inserted
// values, after which the per-shard INSERT statements are generated.
func (r *Router) buildInsertPlan(statement sqlparser.Statement) (*Plan, error) {
	plan := &Plan{}
	ins := statement.(*sqlparser.Insert)
	if _, isSelect := ins.Rows.(sqlparser.SelectStatement); isSelect {
		return nil, errors.ErrSelectInInsert
	}

	// Sharding rule of the target table.
	plan.Rule = r.GetRule(sqlparser.String(ins.Table))

	var err error
	if ins.OnDup != nil {
		if err = plan.Rule.checkUpdateExprs(sqlparser.UpdateExprs(ins.OnDup)); err != nil {
			return nil, err
		}
	}

	plan.Criteria = plan.checkValuesType(ins.Rows.(sqlparser.Values))
	plan.TableIndexs = makeList(0, len(plan.Rule.TableToNode))

	if err = plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildInsertPlan", err.Error(), 0)
		return nil, err
	}

	if err = r.generateInsertSql(plan, ins); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #7
0
파일: router.go 프로젝트: flike/kingshard
// buildReplacePlan routes a REPLACE ... VALUES statement for database db.
//
// REPLACE ... SELECT is not supported. It is reported as an ordinary error (a
// sqlparser ParserError) instead of a panic, the same way buildInsertPlan
// returns an error for INSERT ... SELECT. An explicit column list is required
// (errors.ErrIRNoColumns otherwise) because GetIRKeyIndex needs it.
// plan.Rows starts as an empty map. Routing is then computed from the values
// being replaced, and the per-shard SQL is generated.
func (r *Router) buildReplacePlan(db string, statement sqlparser.Statement) (*Plan, error) {
	plan := &Plan{}
	plan.Rows = make(map[int]sqlparser.Values)

	stmt := statement.(*sqlparser.Replace)
	if _, ok := stmt.Rows.(sqlparser.SelectStatement); ok {
		// Return rather than panic so the caller can reply with an error.
		return nil, sqlparser.NewParserError("select in replace not allowed")
	}

	if stmt.Columns == nil {
		return nil, errors.ErrIRNoColumns
	}

	plan.Rule = r.GetRule(db, sqlparser.String(stmt.Table))

	if err := plan.GetIRKeyIndex(stmt.Columns); err != nil {
		return nil, err
	}

	plan.Criteria = plan.checkValuesType(stmt.Rows.(sqlparser.Values))

	if err := plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildReplacePlan", err.Error(), 0)
		return nil, err
	}

	if err := r.generateReplaceSql(plan, stmt); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #8
0
파일: router.go 프로젝트: flike/kingshard
// buildDeletePlan routes a DELETE statement for database db.
//
// With a WHERE clause, the target sub-tables and nodes are computed from it.
// Without one, the statement is sent to every node and to every sub-table of
// the rule. A sharded rule that resolves to no target table is rejected with
// errors.ErrNoCriteria. The per-shard DELETE statements are then generated.
// An empty RouteTableIndexs means the table is not split, and the statement
// goes to the default node.
func (r *Router) buildDeletePlan(db string, statement sqlparser.Statement) (*Plan, error) {
	plan := &Plan{}

	stmt := statement.(*sqlparser.Delete)
	plan.Rule = r.GetRule(db, sqlparser.String(stmt.Table))

	if where := stmt.Where; where != nil {
		plan.Criteria = where.Expr // routing condition
		if err := plan.calRouteIndexs(); err != nil {
			golog.Error("Route", "BuildDeletePlan", err.Error(), 0)
			return nil, err
		}
	} else {
		// A DELETE without WHERE targets every node and every sub-table.
		plan.RouteTableIndexs = plan.Rule.SubTableIndexs
		plan.RouteNodeIndexs = makeList(0, len(plan.Rule.Nodes))
	}

	if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 {
		golog.Error("Route", "BuildDeletePlan", errors.ErrNoCriteria.Error(), 0)
		return nil, errors.ErrNoCriteria
	}

	if err := r.generateDeleteSql(plan, stmt); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #9
0
파일: conn_set.go 프로젝트: flike/kingshard
// handleSetAutoCommit handles "SET autocommit = <val>".
//
// Accepted values are 0, 1, ON and OFF (ON/OFF case-insensitive and optionally
// quoted). The quoted numerals '0' and '1' are rejected: 0 and 1 must be
// numeric literals. Turning autocommit on sets the autocommit status flag,
// clears the in-transaction flag, and calls SetAutoCommit(1) on and then
// closes every connection in c.txConns. If any of those calls fails, every
// connection is still closed, txConns is reset, and the first failure is
// returned. Turning autocommit off only clears the autocommit status flag.
func (c *ClientConn) handleSetAutoCommit(val sqlparser.ValExpr) error {
	flag := strings.Trim(sqlparser.String(val), "'`\"")
	// autocommit may be 0, 1, ON, OFF, "ON" or "OFF", but not "0" or "1".
	if flag == `0` || flag == `1` {
		if _, ok := val.(sqlparser.NumVal); !ok {
			return fmt.Errorf("set autocommit error")
		}
	}

	switch strings.ToUpper(flag) {
	case `1`, `ON`:
		c.status |= mysql.SERVER_STATUS_AUTOCOMMIT
		c.status &^= mysql.SERVER_STATUS_IN_TRANS

		// Keep going after a failure. Returning early would drop the remaining
		// connections from txConns without ever closing them.
		var firstErr error
		for _, co := range c.txConns {
			if e := co.SetAutoCommit(1); e != nil && firstErr == nil {
				firstErr = e
			}
			co.Close()
		}
		c.txConns = make(map[*backend.Node]*backend.BackendConn)
		if firstErr != nil {
			return fmt.Errorf("set autocommit error, %v", firstErr)
		}
	case `0`, `OFF`:
		c.status &^= mysql.SERVER_STATUS_AUTOCOMMIT
	default:
		return fmt.Errorf("invalid autocommit flag %s", flag)
	}

	return c.writeOK(nil)
}
예제 #10
0
파일: router.go 프로젝트: velsai/kingshard
// buildDeletePlan routes a DELETE statement.
//
// The sharding rule comes from the target table. A DELETE without a WHERE
// clause falls back to the router's DefaultRule instead. Routing indexes are
// then computed. A sharded rule that yields no target table is rejected with
// errors.ErrNoCriteria. Finally the per-shard SQL is generated. An empty
// RouteTableIndexs means the table is not split, and the statement goes to the
// default node.
func (r *Router) buildDeletePlan(statement sqlparser.Statement) (*Plan, error) {
	del := statement.(*sqlparser.Delete)

	plan := &Plan{Rule: r.GetRule(sqlparser.String(del.Table))}
	if del.Where == nil {
		plan.Rule = r.DefaultRule
	} else {
		plan.Criteria = del.Where.Expr // routing condition
	}
	plan.TableIndexs = makeList(0, len(plan.Rule.TableToNode))

	if err := plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildDeletePlan", err.Error(), 0)
		return nil, err
	}

	if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 {
		golog.Error("Route", "BuildDeletePlan", errors.ErrNoCriteria.Error(), 0)
		return nil, errors.ErrNoCriteria
	}

	if err := r.generateDeleteSql(plan, del); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #11
0
파일: router.go 프로젝트: velsai/kingshard
// buildReplacePlan routes a REPLACE ... VALUES statement.
//
// REPLACE ... SELECT is not supported. It is reported as an error (a sqlparser
// ParserError) rather than a panic, in line with buildInsertPlan returning
// errors.ErrSelectInInsert for INSERT ... SELECT. The sharding rule comes from
// the target table. Routing is computed from the values being replaced before
// the per-shard SQL is generated.
func (r *Router) buildReplacePlan(statement sqlparser.Statement) (*Plan, error) {
	plan := &Plan{}

	stmt := statement.(*sqlparser.Replace)
	if _, ok := stmt.Rows.(sqlparser.SelectStatement); ok {
		// Return rather than panic so the caller can reply with an error.
		return nil, sqlparser.NewParserError("select in replace not allowed")
	}

	plan.Rule = r.GetRule(sqlparser.String(stmt.Table))
	plan.Criteria = plan.checkValuesType(stmt.Rows.(sqlparser.Values))
	plan.TableIndexs = makeList(0, len(plan.Rule.TableToNode))

	if err := plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildReplacePlan", err.Error(), 0)
		return nil, err
	}

	if err := r.generateReplaceSql(plan, stmt); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #12
0
// handleUseDB handles "USE <db>". It records the database name on the
// connection and replies OK. An empty name is returned as an error.
func (c *ClientConn) handleUseDB(stmt *sqlparser.UseDB) error {
	if name := string(stmt.DB); name != "" {
		c.db = name
		return c.writeOK(nil)
	}
	return fmt.Errorf("must have database, not %s", sqlparser.String(stmt))
}
예제 #13
0
// handleServerCmd executes an admin server command.
//
// rows must be a VALUES list. Its first tuple holds, in cmdServerOrder order,
// an operation, a key and a value, each stripped of single quotes. The
// operation is lowercased before dispatch, and only the show operation
// produces a result set. Malformed input (not a VALUES list, no rows, a row
// that is not a value tuple, or the wrong arity) returns
// errors.ErrCmdUnsupport instead of panicking. An unknown operation is logged
// and also returns errors.ErrCmdUnsupport. On error the returned result set is
// always nil.
func (c *ClientConn) handleServerCmd(rows sqlparser.InsertRows) (*mysql.Resultset, error) {
	vals, ok := rows.(sqlparser.Values)
	if !ok || len(vals) == 0 {
		return nil, errors.ErrCmdUnsupport
	}

	tuple, ok := vals[0].(sqlparser.ValTuple)
	if !ok || len(tuple) != len(cmdServerOrder) {
		return nil, errors.ErrCmdUnsupport
	}

	// unquote renders the i-th tuple value and strips SQL single quotes.
	unquote := func(i int) string {
		return strings.Trim(sqlparser.String(tuple[i]), "'")
	}
	opt, k, v := unquote(0), unquote(1), unquote(2)

	var (
		result *mysql.Resultset
		err    error
	)
	switch strings.ToLower(opt) {
	case ADMIN_OPT_SHOW:
		result, err = c.handleAdminShow(k, v)
	case ADMIN_OPT_CHANGE:
		err = c.handleAdminChange(k, v)
	case ADMIN_OPT_ADD:
		err = c.handleAdminAdd(k, v)
	case ADMIN_OPT_DEL:
		err = c.handleAdminDelete(k, v)
	case ADMIN_SAVE_CONFIG:
		err = c.handleAdminSave(k, v)
	default:
		err = errors.ErrCmdUnsupport
		golog.Error("ClientConn", "handleServerCmd", err.Error(),
			c.connectionId, "opt", opt)
	}
	if err != nil {
		return nil, err
	}
	return result, nil
}
예제 #14
0
파일: router.go 프로젝트: velsai/kingshard
// buildSelectPlan routes a SELECT statement.
//
// The sharding rule is chosen by the first FROM entry: its table expression
// when it is an aliased table, otherwise its full text. A SELECT without a
// WHERE clause falls back to the router's DefaultRule. After the route is
// computed, a sharded rule with no target table is rejected with
// errors.ErrNoCriteria, and the per-shard SQL is generated. An empty
// RouteTableIndexs means the table is not split, and the query goes to the
// default node.
func (r *Router) buildSelectPlan(statement sqlparser.Statement) (*Plan, error) {
	sel := statement.(*sqlparser.Select)

	var tableName string
	switch t := sel.From[0].(type) {
	case *sqlparser.AliasedTableExpr:
		tableName = sqlparser.String(t.Expr)
	default:
		tableName = sqlparser.String(t)
	}

	plan := &Plan{Rule: r.GetRule(tableName)}
	if sel.Where == nil {
		plan.Rule = r.DefaultRule
	} else {
		plan.Criteria = sel.Where.Expr // routing condition
	}
	plan.TableIndexs = makeList(0, len(plan.Rule.TableToNode))

	if err := plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildSelectPlan", err.Error(), 0)
		return nil, err
	}

	if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 {
		golog.Error("Route", "BuildSelectPlan", errors.ErrNoCriteria.Error(), 0)
		return nil, errors.ErrNoCriteria
	}

	if err := r.generateSelectSql(plan, sel); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #15
0
파일: router.go 프로젝트: snower/kingshard
// generateInsertSql fills plan.RewrittenSqls with the INSERT statements to
// run, keyed by node name.
//
// When no sub-table was routed, the statement is formatted verbatim for
// r.Nodes[0]. Otherwise one INSERT is built per routed sub-table. The table
// name gets a "_%04d" suffix, placed inside the closing backtick when the name
// is backtick-quoted, and the VALUES are taken from plan.Rows for that
// sub-table. Returns errors.ErrStmtConvert if stmt is not an INSERT, and
// errors.ErrNoRouteNode if no node was routed.
func (r *Router) generateInsertSql(plan *Plan, stmt sqlparser.Statement) error {
	ins, isInsert := stmt.(*sqlparser.Insert)
	if !isInsert {
		return errors.ErrStmtConvert
	}
	if len(plan.RouteNodeIndexs) == 0 {
		return errors.ErrNoRouteNode
	}

	sqls := make(map[string][]string)
	tableCount := len(plan.RouteTableIndexs)
	if tableCount == 0 {
		buf := sqlparser.NewTrackedBuffer(nil)
		stmt.Format(buf)
		sqls[r.Nodes[0]] = []string{buf.String()}
		plan.RewrittenSqls = sqls
		return nil
	}

	for _, tableIndex := range plan.RouteTableIndexs {
		nodeName := r.Nodes[plan.Rule.TableToNode[tableIndex]]

		buf := sqlparser.NewTrackedBuffer(nil)
		buf.Fprintf("insert %vinto ", ins.Comments)
		table := sqlparser.String(ins.Table)
		if last := len(table) - 1; table[last] == '`' {
			fmt.Fprintf(buf, "%s_%04d`", table[:last], tableIndex)
		} else {
			fmt.Fprintf(buf, "%s_%04d", table, tableIndex)
		}
		buf.Fprintf("%v %v%v", ins.Columns, plan.Rows[tableIndex], ins.OnDup)

		if _, seen := sqls[nodeName]; !seen {
			sqls[nodeName] = make([]string, 0, tableCount)
		}
		sqls[nodeName] = append(sqls[nodeName], buf.String())
	}
	plan.RewrittenSqls = sqls
	return nil
}
예제 #16
0
// handleSetNames handles "SET NAMES <charset>".
//
// The charset name is unquoted and lowercased. "null" is accepted as a no-op.
// Any other name must be a key of mysql.CharsetIds, otherwise an error is
// returned. On success the connection's charset and the matching collation id
// are stored, and OK is written.
func (c *ClientConn) handleSetNames(val sqlparser.ValExpr) error {
	charset := strings.ToLower(strings.Trim(sqlparser.String(val), "'`\""))
	if charset == "null" {
		return c.writeOK(nil)
	}

	cid, known := mysql.CharsetIds[charset]
	if !known {
		return fmt.Errorf("invalid charset %s", charset)
	}
	c.charset, c.collation = charset, cid
	return c.writeOK(nil)
}
예제 #17
0
// checkCmdOrder verifies that the column list of an admin statement matches,
// position by position, the columns expected for region.
//
// Only NodeRegion is recognised here; any other region yields ErrCmdUnsupport.
// A column list longer than the expected order is rejected with
// ErrCmdUnsupport; previously this indexed past the end of cmdOrder and
// panicked. A shorter list is accepted as long as every column it names is in
// its expected position.
func (c *ClientConn) checkCmdOrder(region string, columns sqlparser.Columns) error {
	var cmdOrder []string
	switch region {
	case NodeRegion:
		cmdOrder = cmdNodeOrder
	default:
		return ErrCmdUnsupport
	}

	node := sqlparser.SelectExprs(columns)
	if len(node) > len(cmdOrder) {
		return ErrCmdUnsupport
	}
	for i, expr := range node {
		if sqlparser.String(expr) != cmdOrder[i] {
			return ErrCmdUnsupport
		}
	}
	return nil
}
예제 #18
0
파일: router.go 프로젝트: flike/kingshard
// buildInsertPlan routes an INSERT ... VALUES statement for database db.
//
// INSERT ... SELECT is rejected with errors.ErrSelectInInsert. An explicit
// column list is required (errors.ErrIRNoColumns) because GetIRKeyIndex
// consumes it. The sharding rule comes from the target table. An ON DUPLICATE
// KEY UPDATE clause is checked with Rule.checkUpdateExprs. Routing is computed
// from the inserted values, and then the per-shard INSERT statements are
// generated.
func (r *Router) buildInsertPlan(db string, statement sqlparser.Statement) (*Plan, error) {
	plan := &Plan{Rows: make(map[int]sqlparser.Values)}
	ins := statement.(*sqlparser.Insert)
	if _, isSelect := ins.Rows.(sqlparser.SelectStatement); isSelect {
		return nil, errors.ErrSelectInInsert
	}
	if ins.Columns == nil {
		return nil, errors.ErrIRNoColumns
	}

	// Sharding rule of the target table.
	plan.Rule = r.GetRule(db, sqlparser.String(ins.Table))

	if err := plan.GetIRKeyIndex(ins.Columns); err != nil {
		return nil, err
	}
	if ins.OnDup != nil {
		if err := plan.Rule.checkUpdateExprs(sqlparser.UpdateExprs(ins.OnDup)); err != nil {
			return nil, err
		}
	}

	plan.Criteria = plan.checkValuesType(ins.Rows.(sqlparser.Values))

	if err := plan.calRouteIndexs(); err != nil {
		golog.Error("Route", "BuildInsertPlan", err.Error(), 0)
		return nil, err
	}
	if err := r.generateInsertSql(plan, ins); err != nil {
		return nil, err
	}
	return plan, nil
}
예제 #19
0
// handleUseDB handles "USE <db>".
//
// The database is first selected (co.UseDB) on a master connection of the
// default rule's first node. Only if that succeeds is the name recorded on the
// client connection and OK written. Returns an error for an empty name,
// ER_NO_DB_ERROR when the connection has no schema, an error when the node
// cannot be found, and any error from acquiring the connection or from UseDB.
func (c *ClientConn) handleUseDB(stmt *sqlparser.UseDB) error {
	if len(stmt.DB) == 0 {
		return fmt.Errorf("must have database, not %s", sqlparser.String(stmt))
	}
	if c.schema == nil {
		return mysql.NewDefaultError(mysql.ER_NO_DB_ERROR)
	}

	nodeName := c.schema.rule.DefaultRule.Nodes[0]

	n := c.proxy.GetNode(nodeName)
	if n == nil {
		return fmt.Errorf("node %s not found", nodeName)
	}

	co, err := n.GetMasterConn()
	if err != nil {
		return err
	}
	// Defer the release only after a successful acquisition. On error there
	// is no usable connection to hand to closeConn.
	defer c.closeConn(co, false)

	if err = co.UseDB(string(stmt.DB)); err != nil {
		return err
	}
	c.db = string(stmt.DB)
	return c.writeOK(nil)
}
예제 #20
0
파일: router.go 프로젝트: velsai/kingshard
// generateSelectSql fills plan.RewrittenSqls with the SELECT statements to
// run, keyed by node name.
//
// When no sub-table was routed, the statement is formatted verbatim for
// r.Nodes[0]. Otherwise one SELECT is built per routed sub-table. In it, the
// first FROM entry's table gets a "_%04d" suffix, keeping its alias if it has
// one. WHERE, GROUP BY, HAVING, ORDER BY, LIMIT and the lock clause are copied
// from the original statement. Returns errors.ErrStmtConvert if stmt is not a
// SELECT, and errors.ErrNoRouteNode if no node was routed.
func (r *Router) generateSelectSql(plan *Plan, stmt sqlparser.Statement) error {
	sel, isSelect := stmt.(*sqlparser.Select)
	if !isSelect {
		return errors.ErrStmtConvert
	}
	if len(plan.RouteNodeIndexs) == 0 {
		return errors.ErrNoRouteNode
	}

	sqls := make(map[string][]string)
	tableCount := len(plan.RouteTableIndexs)
	if tableCount == 0 {
		buf := sqlparser.NewTrackedBuffer(nil)
		stmt.Format(buf)
		sqls[r.Nodes[0]] = []string{buf.String()}
		plan.RewrittenSqls = sqls
		return nil
	}

	for _, tableIndex := range plan.RouteTableIndexs {
		buf := sqlparser.NewTrackedBuffer(nil)
		buf.Fprintf("select %v%s%v from ", sel.Comments, sel.Distinct, sel.SelectExprs)

		// Write the suffixed table of the first FROM entry.
		ate, isAliased := sel.From[0].(*sqlparser.AliasedTableExpr)
		switch {
		case isAliased && len(ate.As) != 0:
			fmt.Fprintf(buf, "%s_%04d AS %s", sqlparser.String(ate.Expr), tableIndex, string(ate.As))
		case isAliased:
			fmt.Fprintf(buf, "%s_%04d", sqlparser.String(ate.Expr), tableIndex)
		default:
			fmt.Fprintf(buf, "%s_%04d", sqlparser.String(sel.From[0]), tableIndex)
		}

		buf.Fprintf("%v%v%v%v%v%s", sel.Where, sel.GroupBy, sel.Having, sel.OrderBy, sel.Limit, sel.Lock)

		nodeName := r.Nodes[plan.Rule.TableToNode[tableIndex]]
		if _, seen := sqls[nodeName]; !seen {
			sqls[nodeName] = make([]string, 0, tableCount)
		}
		sqls[nodeName] = append(sqls[nodeName], buf.String())
	}
	plan.RewrittenSqls = sqls
	return nil
}
예제 #21
0
파일: router.go 프로젝트: xww/kingshard
// rewriteSelectSql renders node as the SELECT to run against sub-table
// tableIndex.
//
// The first FROM entry's table gets a "_%04d" suffix. An aliased table keeps
// its alias. For a join, only the left table is renamed; the join keyword,
// right-hand side and ON condition are copied as-is. Any further FROM entries
// (comma joins) are appended unchanged, so they are not lost. Every GROUP BY
// expression is also appended to the end of the select list. WHERE, GROUP BY,
// HAVING, ORDER BY and the lock clause are copied; LIMIT is not emitted. plan
// is not used by this implementation.
func (r *Router) rewriteSelectSql(plan *Plan, node *sqlparser.Select, tableIndex int) string {
	buf := sqlparser.NewTrackedBuffer(nil)
	buf.Fprintf("select %v%s%v",
		node.Comments,
		node.Distinct,
		node.SelectExprs,
	)
	// Append the GROUP BY columns to the end of the select list.
	for _, n := range node.GroupBy {
		buf.Fprintf("%s%v", ",", n)
	}
	buf.Fprintf(" from ")

	switch v := (node.From[0]).(type) {
	case *sqlparser.AliasedTableExpr:
		if len(v.As) != 0 {
			fmt.Fprintf(buf, "%s_%04d AS %s",
				sqlparser.String(v.Expr),
				tableIndex,
				string(v.As),
			)
		} else {
			fmt.Fprintf(buf, "%s_%04d",
				sqlparser.String(v.Expr),
				tableIndex,
			)
		}
	case *sqlparser.JoinTableExpr:
		if ate, ok := (v.LeftExpr).(*sqlparser.AliasedTableExpr); ok {
			if len(ate.As) != 0 {
				fmt.Fprintf(buf, "%s_%04d AS %s",
					sqlparser.String(ate.Expr),
					tableIndex,
					string(ate.As),
				)
			} else {
				fmt.Fprintf(buf, "%s_%04d",
					sqlparser.String(ate.Expr),
					tableIndex,
				)
			}
		} else {
			fmt.Fprintf(buf, "%s_%04d",
				sqlparser.String(v.LeftExpr),
				tableIndex,
			)
		}
		buf.Fprintf(" %s %v", v.Join, v.RightExpr)
		if v.On != nil {
			buf.Fprintf(" on %v", v.On)
		}
	default:
		fmt.Fprintf(buf, "%s_%04d",
			sqlparser.String(node.From[0]),
			tableIndex,
		)
	}
	// Keep the remaining FROM entries (e.g. "from a, b"). Only the first entry
	// is renamed to the sub-table.
	for _, expr := range node.From[1:] {
		buf.Fprintf("%s%v", ", ", expr)
	}

	buf.Fprintf("%v%v%v%v%s",
		node.Where,
		node.GroupBy,
		node.Having,
		node.OrderBy,
		node.Lock,
	)
	return buf.String()
}
예제 #22
0
파일: router.go 프로젝트: flike/kingshard
// rewriteSelectSql renders node as the SELECT statement to send to sub-table
// tableIndex of the sharded table plan.Rule.Table.
//
// Rewrites applied, in output order:
//   - In the select list, "T.*" and "T.col [as a]" (where T equals
//     plan.Rule.Table) are re-qualified as "T_NNNN", NNNN being tableIndex
//     formatted with %04d. Other select expressions are copied unchanged.
//   - Every GROUP BY expression is appended to the end of the select list.
//   - The first FROM entry gets the "_NNNN" suffix and keeps its alias. For a
//     join, only the left table is renamed; the join keyword, right side and
//     ON condition are copied. Remaining FROM entries are copied as-is.
//   - LIMIT is replaced with the result of node.Limit.RewriteLimit(). If that
//     returns an error, the original LIMIT is used.
//   - WHERE is formatted after plan.rewriteWhereIn(tableIndex) has run. If that
//     call returned a non-nil old right-hand side, it is written back to
//     plan.InRightToReplace.Right before returning.
//
// NOTE(review): the error from rewriteWhereIn is discarded. This relies on
// oldright being nil whenever nothing was replaced; confirm in rewriteWhereIn.
func (r *Router) rewriteSelectSql(plan *Plan, node *sqlparser.Select, tableIndex int) string {
	buf := sqlparser.NewTrackedBuffer(nil)
	buf.Fprintf("select %v%s",
		node.Comments,
		node.Distinct,
	)

	// prefix separates select expressions; it is empty before the first one.
	var prefix string
	// Rewrite each select expression.
	for _, expr := range node.SelectExprs {
		switch v := expr.(type) {
		case *sqlparser.StarExpr:
			// shardTable.* must become shardTable_NNNN.* to match the renamed
			// table in FROM.
			if string(v.TableName) == plan.Rule.Table {
				fmt.Fprintf(buf, "%s%s_%04d.*",
					prefix,
					plan.Rule.Table,
					tableIndex,
				)
			} else {
				buf.Fprintf("%s%v", prefix, expr)
			}
		case *sqlparser.NonStarExpr:
			// Rewrite "shardTable.column as a" into
			// "shardTable_NNNN.column as a".
			if colName, ok := v.Expr.(*sqlparser.ColName); ok {
				if string(colName.Qualifier) == plan.Rule.Table {
					fmt.Fprintf(buf, "%s%s_%04d.%s",
						prefix,
						plan.Rule.Table,
						tableIndex,
						string(colName.Name),
					)
				} else {
					buf.Fprintf("%s%v", prefix, colName)
				}
				// Only the column was written above, so re-emit its alias.
				if v.As != nil {
					buf.Fprintf(" as %s", v.As)
				}
			} else {
				buf.Fprintf("%s%v", prefix, expr)
			}
		default:
			buf.Fprintf("%s%v", prefix, expr)
		}
		prefix = ", "
	}
	// Append the GROUP BY columns to the end of the select list.
	if len(node.GroupBy) != 0 {
		prefix = ","
		for _, n := range node.GroupBy {
			buf.Fprintf("%s%v", prefix, n)
		}
	}
	buf.Fprintf(" from ")
	// Rename the first FROM entry to the sub-table; keep any alias.
	switch v := (node.From[0]).(type) {
	case *sqlparser.AliasedTableExpr:
		if len(v.As) != 0 {
			fmt.Fprintf(buf, "%s_%04d as %s",
				sqlparser.String(v.Expr),
				tableIndex,
				string(v.As),
			)
		} else {
			fmt.Fprintf(buf, "%s_%04d",
				sqlparser.String(v.Expr),
				tableIndex,
			)
		}
	case *sqlparser.JoinTableExpr:
		// Only the left side of the join is the sharded table.
		if ate, ok := (v.LeftExpr).(*sqlparser.AliasedTableExpr); ok {
			if len(ate.As) != 0 {
				fmt.Fprintf(buf, "%s_%04d as %s",
					sqlparser.String(ate.Expr),
					tableIndex,
					string(ate.As),
				)
			} else {
				fmt.Fprintf(buf, "%s_%04d",
					sqlparser.String(ate.Expr),
					tableIndex,
				)
			}
		} else {
			fmt.Fprintf(buf, "%s_%04d",
				sqlparser.String(v.LeftExpr),
				tableIndex,
			)
		}
		buf.Fprintf(" %s %v", v.Join, v.RightExpr)
		if v.On != nil {
			buf.Fprintf(" on %v", v.On)
		}
	default:
		fmt.Fprintf(buf, "%s_%04d",
			sqlparser.String(node.From[0]),
			tableIndex,
		)
	}
	// Append the other FROM entries unchanged.
	prefix = ", "
	for i := 1; i < len(node.From); i++ {
		buf.Fprintf("%s%v", prefix, node.From[i])
	}

	newLimit, err := node.Limit.RewriteLimit()
	if err != nil {
		// Fall back to the original LIMIT if it cannot be rewritten.
		newLimit = node.Limit
	}
	// Possibly swap the IN list's right-hand side for this sub-table; the
	// returned old value (if any) is restored after formatting below.
	oldright, err := plan.rewriteWhereIn(tableIndex)

	buf.Fprintf("%v%v%v%v%v%s",
		node.Where,
		node.GroupBy,
		node.Having,
		node.OrderBy,
		newLimit,
		node.Lock,
	)
	// Put the original IN right-hand side back so plan is unchanged for the
	// next sub-table.
	if oldright != nil {
		plan.InRightToReplace.Right = oldright
	}
	return buf.String()
}