func (qs *QuerySplitter) parseInt(pkMinMax *mproto.QueryResult) ([]sqltypes.Value, error) { boundaries := []sqltypes.Value{} minNumeric := sqltypes.MakeNumeric(pkMinMax.Rows[0][0].Raw()) maxNumeric := sqltypes.MakeNumeric(pkMinMax.Rows[0][1].Raw()) if pkMinMax.Rows[0][0].Raw()[0] == '-' { // signed values, use int64 min, err := minNumeric.ParseInt64() if err != nil { return nil, err } max, err := maxNumeric.ParseInt64() if err != nil { return nil, err } interval := (max - min) / int64(qs.splitCount) if interval == 0 { return nil, err } qs.rowCount = interval for i := int64(1); i < int64(qs.splitCount); i++ { v, err := sqltypes.BuildValue(min + interval*i) if err != nil { return nil, err } boundaries = append(boundaries, v) } return boundaries, nil } // unsigned values, use uint64 min, err := minNumeric.ParseUint64() if err != nil { return nil, err } max, err := maxNumeric.ParseUint64() if err != nil { return nil, err } interval := (max - min) / uint64(qs.splitCount) if interval == 0 { return nil, err } qs.rowCount = int64(interval) for i := uint64(1); i < uint64(qs.splitCount); i++ { v, err := sqltypes.BuildValue(min + interval*i) if err != nil { return nil, err } boundaries = append(boundaries, v) } return boundaries, nil }
func (qs *QuerySplitter) parseFloat(pkMinMax *mproto.QueryResult) ([]sqltypes.Value, error) { boundaries := []sqltypes.Value{} min, err := strconv.ParseFloat(pkMinMax.Rows[0][0].String(), 64) if err != nil { return nil, err } max, err := strconv.ParseFloat(pkMinMax.Rows[0][1].String(), 64) if err != nil { return nil, err } interval := (max - min) / float64(qs.splitCount) if interval == 0 { return nil, err } qs.rowCount = int64(interval) for i := 1; i < qs.splitCount; i++ { boundary := min + interval*float64(i) v, err := sqltypes.BuildValue(boundary) if err != nil { return nil, err } boundaries = append(boundaries, v) } return boundaries, nil }
func (rci *RowcacheInvalidator) handleDMLEvent(event *blproto.StreamEvent) { invalidations := int64(0) tableInfo := rci.qe.schemaInfo.GetTable(event.TableName) if tableInfo == nil { panic(NewTabletError(FAIL, "Table %s not found", event.TableName)) } if tableInfo.CacheType == schema.CACHE_NONE { return } sqlTypeKeys := make([]sqltypes.Value, 0, len(event.PKColNames)) for _, pkTuple := range event.PKValues { sqlTypeKeys = sqlTypeKeys[:0] for _, pkVal := range pkTuple { key, err := sqltypes.BuildValue(pkVal) if err != nil { log.Errorf("Error building invalidation key for %#v: '%v'", event, err) internalErrors.Add("Invalidation", 1) return } sqlTypeKeys = append(sqlTypeKeys, key) } newKey := validateKey(tableInfo, buildKey(sqlTypeKeys)) if newKey == "" { continue } tableInfo.Cache.Delete(newKey) invalidations++ } tableInfo.invalidations.Add(invalidations) }
func TestSplitQuery(t *testing.T) { schemaInfo := getSchemaInfo() query := &proto.BoundQuery{ Sql: "select * from test_table where count > :count", } splitter := NewQuerySplitter(query, 3, schemaInfo) splitter.validateQuery() min, _ := sqltypes.BuildValue(0) max, _ := sqltypes.BuildValue(300) minField := mproto.Field{ Name: "min", Type: mproto.VT_LONGLONG, } maxField := mproto.Field{ Name: "min", Type: mproto.VT_LONGLONG, } fields := []mproto.Field{minField, maxField} row := []sqltypes.Value{min, max} rows := [][]sqltypes.Value{row} pkMinMax := &mproto.QueryResult{ Rows: rows, Fields: fields, } splits := splitter.split(pkMinMax) got := []string{} for _, split := range splits { if split.RowCount != 100 { t.Errorf("wrong RowCount, got: %v, want: %v", split.RowCount, 100) } got = append(got, split.Query.Sql) } want := []string{ "select * from test_table where count > :count and id < 100", "select * from test_table where count > :count and id >= 100 and id < 200", "select * from test_table where count > :count and id >= 200", } if !reflect.DeepEqual(got, want) { t.Errorf("wrong splits, got: %v, want: %v", got, want) } }
func TestBuildStreamComment(t *testing.T) { pk1 := "pk1" pk2 := "pk2" tableInfo := createTableInfo("Table", map[string]string{pk1: "int", pk2: "varchar(128)", "col1": "int"}, []string{pk1, pk2}) // set pk2 = 'xyz' where pk1=1 and pk2 = 'abc' bindVars := map[string]interface{}{} pk1Val, _ := sqltypes.BuildValue(1) pk2Val, _ := sqltypes.BuildValue("abc") pkValues := []interface{}{pk1Val, pk2Val} pkList, _ := buildValueList(&tableInfo, pkValues, bindVars) pk2SecVal, _ := sqltypes.BuildValue("xyz") secondaryPKValues := []interface{}{nil, pk2SecVal} secondaryList, _ := buildSecondaryList(&tableInfo, pkList, secondaryPKValues, bindVars) want := []byte(" /* _stream Table (pk1 pk2 ) (1 'YWJj' ) (1 'eHl6' ); */") got := buildStreamComment(&tableInfo, pkList, secondaryList) if !reflect.DeepEqual(got, want) { t.Errorf("case 1 failed, got %v, want %v", got, want) } }
func TestBuildSecondaryList(t *testing.T) { pk1 := "pk1" pk2 := "pk2" tableInfo := createTableInfo("Table", map[string]string{pk1: "int", pk2: "varchar(128)", "col1": "int"}, []string{pk1, pk2}) // set pk2 = 'xyz' where pk1=1 and pk2 = 'abc' bindVars := map[string]interface{}{} pk1Val, _ := sqltypes.BuildValue(1) pk2Val, _ := sqltypes.BuildValue("abc") pkValues := []interface{}{pk1Val, pk2Val} pkList, _ := buildValueList(&tableInfo, pkValues, bindVars) pk2SecVal, _ := sqltypes.BuildValue("xyz") secondaryPKValues := []interface{}{nil, pk2SecVal} // want [[1 xyz]] want := [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2SecVal}} got, _ := buildSecondaryList(&tableInfo, pkList, secondaryPKValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("case 1 failed, got %v, want %v", got, want) } }
func EncodeValue(buf *bytes.Buffer, value interface{}) error { switch bindVal := value.(type) { case nil: buf.WriteString("null") case []sqltypes.Value: for i := 0; i < len(bindVal); i++ { if i != 0 { buf.WriteString(", ") } if err := EncodeValue(buf, bindVal[i]); err != nil { return err } } case [][]sqltypes.Value: for i := 0; i < len(bindVal); i++ { if i != 0 { buf.WriteString(", ") } buf.WriteByte('(') if err := EncodeValue(buf, bindVal[i]); err != nil { return err } buf.WriteByte(')') } case []interface{}: buf.WriteByte('(') for i, v := range bindVal { if i != 0 { buf.WriteString(", ") } if err := EncodeValue(buf, v); err != nil { return err } } buf.WriteByte(')') case TupleEqualityList: if err := bindVal.Encode(buf); err != nil { return err } default: v, err := sqltypes.BuildValue(bindVal) if err != nil { return err } v.EncodeSql(buf) } return nil }
func TestGetSplitBoundaries(t *testing.T) { min, _ := sqltypes.BuildValue(10) max, _ := sqltypes.BuildValue(60) row := []sqltypes.Value{min, max} rows := [][]sqltypes.Value{row} minField := mproto.Field{Name: "min", Type: mproto.VT_LONGLONG} maxField := mproto.Field{Name: "max", Type: mproto.VT_LONGLONG} fields := []mproto.Field{minField, maxField} pkMinMax := &mproto.QueryResult{ Fields: fields, Rows: rows, } splitter := &QuerySplitter{} splitter.splitCount = 5 boundaries := splitter.getSplitBoundaries(pkMinMax) if len(boundaries) != splitter.splitCount-1 { t.Errorf("wrong number of boundaries got: %v, want: %v", len(boundaries), splitter.splitCount-1) } got := splitter.getSplitBoundaries(pkMinMax) want := []sqltypes.Value{buildVal(20), buildVal(30), buildVal(40), buildVal(50)} if !reflect.DeepEqual(got, want) { t.Errorf("incorrect boundaries, got: %v, want: %v", got, want) } // Test negative min value min, _ = sqltypes.BuildValue(-100) max, _ = sqltypes.BuildValue(100) row = []sqltypes.Value{min, max} rows = [][]sqltypes.Value{row} pkMinMax.Rows = rows got = splitter.getSplitBoundaries(pkMinMax) want = []sqltypes.Value{buildVal(-60), buildVal(-20), buildVal(20), buildVal(60)} if !reflect.DeepEqual(got, want) { t.Errorf("incorrect boundaries, got: %v, want: %v", got, want) } // Test float min max min, _ = sqltypes.BuildValue(10.5) max, _ = sqltypes.BuildValue(60.5) row = []sqltypes.Value{min, max} rows = [][]sqltypes.Value{row} minField = mproto.Field{Name: "min", Type: mproto.VT_DOUBLE} maxField = mproto.Field{Name: "max", Type: mproto.VT_DOUBLE} fields = []mproto.Field{minField, maxField} pkMinMax.Rows = rows pkMinMax.Fields = fields got = splitter.getSplitBoundaries(pkMinMax) want = []sqltypes.Value{buildVal(20.5), buildVal(30.5), buildVal(40.5), buildVal(50.5)} if !reflect.DeepEqual(got, want) { t.Errorf("incorrect boundaries, got: %v, want: %v", got, want) } }
func resolveListArg(col *schema.TableColumn, key string, bindVars map[string]interface{}) ([]sqltypes.Value, error) { val, _, err := sqlparser.FetchBindVar(key, bindVars) if err != nil { return nil, NewTabletError(FAIL, "%v", err) } list := val.([]interface{}) resolved := make([]sqltypes.Value, len(list)) for i, v := range list { sqlval, err := sqltypes.BuildValue(v) if err != nil { return nil, NewTabletError(FAIL, "%v", err) } if err = validateValue(col, sqlval); err != nil { return nil, err } resolved[i] = sqlval } return resolved, nil }
func getSchemaInfo() *SchemaInfo { table := &schema.Table{ Name: "test_table", } zero, _ := sqltypes.BuildValue(0) table.AddColumn("id", "int", zero, "") table.AddColumn("count", "int", zero, "") table.PKColumns = []int{0} tables := make(map[string]*TableInfo, 1) tables["test_table"] = &TableInfo{Table: table} tableNoPK := &schema.Table{ Name: "test_table_no_pk", } tableNoPK.AddColumn("id", "int", zero, "") tableNoPK.PKColumns = []int{} tables["test_table_no_pk"] = &TableInfo{Table: tableNoPK} return &SchemaInfo{tables: tables} }
func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[string]interface{}) (result sqltypes.Value, err error) { switch v := value.(type) { case string: val, _, err := sqlparser.FetchBindVar(v, bindVars) if err != nil { return result, NewTabletError(FAIL, "%v", err) } sqlval, err := sqltypes.BuildValue(val) if err != nil { return result, NewTabletError(FAIL, "%v", err) } result = sqlval case sqltypes.Value: result = v default: panic(fmt.Sprintf("incompatible value type %v", v)) } if err = validateValue(col, result); err != nil { return result, err } return result, nil }
func TestBuildValuesList(t *testing.T) { tableInfo := createTableInfo("Table", map[string]string{"pk1": "int", "pk2": "varbinary(128)", "col1": "int"}, []string{"pk1", "pk2"}) // simple PK clause. e.g. where pk1 = 1 bindVars := map[string]interface{}{} pk1Val, _ := sqltypes.BuildValue(1) pkValues := []interface{}{pk1Val} // want [[1]] want := [][]sqltypes.Value{[]sqltypes.Value{pk1Val}} got, _ := buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // simple PK clause with bindVars. e.g. where pk1 = :pk1 bindVars["pk1"] = 1 pkValues = []interface{}{":pk1"} // want [[1]] want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // null value bindVars["pk1"] = nil pkValues = []interface{}{":pk1"} // want [[1]] want = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.Value{}}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // invalid value bindVars["pk1"] = struct{}{} pkValues = []interface{}{":pk1"} wantErr := "error: unsupported bind variable type struct {}: {}" got, err := buildValueList(&tableInfo, pkValues, bindVars) if err == nil || err.Error() != wantErr { t.Errorf("got %v, want %v", err, wantErr) } // type mismatch int bindVars["pk1"] = "str" pkValues = []interface{}{":pk1"} wantErr = "error: type mismatch, expecting numeric type for str" got, err = buildValueList(&tableInfo, pkValues, bindVars) if err == nil || err.Error() != wantErr { t.Errorf("got %v, want %v", err, wantErr) } // type mismatch binary bindVars["pk1"] = 1 bindVars["pk2"] = 1 pkValues = []interface{}{":pk1", ":pk2"} wantErr = "error: type mismatch, expecting string type for 1" got, err = buildValueList(&tableInfo, pkValues, bindVars) t.Logf("%v", got) if err == nil || err.Error() != wantErr { t.Errorf("got %v, want %v", err, wantErr) } // composite PK clause. e.g. where pk1 = 1 and pk2 = "abc" pk2Val, _ := sqltypes.BuildValue("abc") pkValues = []interface{}{pk1Val, pk2Val} // want [[1 abc]] want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val, pk2Val}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // multi row composite PK insert // e.g. insert into Table(pk1,pk2) values (1, "abc"), (2, "xyz") pk1Val2, _ := sqltypes.BuildValue(2) pk2Val2, _ := sqltypes.BuildValue("xyz") pkValues = []interface{}{ []interface{}{pk1Val, pk1Val2}, []interface{}{pk2Val, pk2Val2}, } // want [[1 abc][2 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, []sqltypes.Value{pk1Val2, pk2Val2}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // composite PK IN clause // e.g. where pk1 = 1 and pk2 IN ("abc", "xyz") pkValues = []interface{}{ pk1Val, []interface{}{pk2Val, pk2Val2}, } // want [[1 abc][1 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, []sqltypes.Value{pk1Val, pk2Val2}, } got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // list arg // e.g. where pk1 = 1 and pk2 IN ::list bindVars = map[string]interface{}{ "list": []interface{}{ "abc", "xyz", }, } pkValues = []interface{}{ pk1Val, "::list", } // want [[1 abc][1 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, []sqltypes.Value{pk1Val, pk2Val2}, } // list arg one value // e.g. where pk1 = 1 and pk2 IN ::list bindVars = map[string]interface{}{ "list": []interface{}{ "abc", }, } pkValues = []interface{}{ pk1Val, "::list", } // want [[1 abc][1 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, } got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } // list arg empty list bindVars = map[string]interface{}{ "list": []interface{}{}, } pkValues = []interface{}{ pk1Val, "::list", } wantErr = "error: empty list supplied for list" got, err = buildValueList(&tableInfo, pkValues, bindVars) if err == nil || err.Error() != wantErr { t.Errorf("got %v, want %v", err, wantErr) } // list arg for non-list bindVars = map[string]interface{}{ "list": []interface{}{}, } pkValues = []interface{}{ pk1Val, ":list", } wantErr = "error: unexpected arg type []interface {} for key list" got, err = buildValueList(&tableInfo, pkValues, bindVars) if err == nil || err.Error() != wantErr { t.Errorf("got %v, want %v", err, wantErr) } }
func TestGetWhereClause(t *testing.T) { splitter := &QuerySplitter{} sql := "select * from test_table where count > :count" statement, _ := sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) splitter.pkCol = "id" // no boundary case, start = end = nil, should not change the where clause nilValue := sqltypes.Value{} clause := splitter.getWhereClause(nilValue, nilValue) want := " where count > :count" got := sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set lower bound, should add the lower bound condition to where clause start, _ := sqltypes.BuildValue(20) clause = splitter.getWhereClause(start, nilValue) want = " where count > :count and id >= 20" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set upper bound, should add the upper bound condition to where clause end, _ := sqltypes.BuildValue(40) clause = splitter.getWhereClause(nilValue, end) want = " where count > :count and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where count > :count and id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Original query with no where clause sql = "select * from test_table" statement, _ = sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) // no boundary case, start = end = nil should return no where clause clause = splitter.getWhereClause(nilValue, nilValue) want = "" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } }
func buildVal(val interface{}) sqltypes.Value { v, _ := sqltypes.BuildValue(val) return v }