func makeSpan(tableDesc *sqlbase.TableDescriptor, i, j int) roachpb.Span { makeKey := func(val int) roachpb.Key { key, err := sqlbase.MakePrimaryIndexKey(tableDesc, val) if err != nil { panic(err) } return key } return roachpb.Span{ Key: makeKey(i), EndKey: makeKey(j), } }
// splitRangeAtVal splits the range for a table with schema // `CREATE TABLE test (k INT PRIMARY KEY)` at row with value pk (the row will be // the first on the right of the split). func splitRangeAtVal( ts *server.TestServer, tableDesc *sqlbase.TableDescriptor, pk int, ) (roachpb.RangeDescriptor, roachpb.RangeDescriptor, error) { if len(tableDesc.Indexes) != 0 { return roachpb.RangeDescriptor{}, roachpb.RangeDescriptor{}, errors.Errorf("expected table with just a PK, got: %+v", tableDesc) } pik, err := sqlbase.MakePrimaryIndexKey(tableDesc, pk) if err != nil { return roachpb.RangeDescriptor{}, roachpb.RangeDescriptor{}, err } startKey := keys.MakeRowSentinelKey(pik) leftRange, rightRange, err := ts.SplitRange(startKey) if err != nil { return roachpb.RangeDescriptor{}, roachpb.RangeDescriptor{}, errors.Wrapf(err, "failed to split at row: %d", pk) } return leftRange, rightRange, nil }
func TestDistSQLJoinAndAgg(t *testing.T) {
	defer leaktest.AfterTest(t)()
	// This test sets up a distributed join between two tables:
	//  - a NumToSquare table of size N that maps integers from 1 to n to their
	//    squares
	//  - a NumToStr table of size N^2 that maps integers to their string
	//    representations. This table is split and distributed to all the nodes.
	const n = 100
	const numNodes = 5
	tc := serverutils.StartTestCluster(t, numNodes, base.TestClusterArgs{
		ReplicationMode: base.ReplicationManual,
		ServerArgs: base.TestServerArgs{
			UseDatabase: "test",
		},
	})
	defer tc.Stopper().Stop()
	cdb := tc.Server(0).KVClient().(*client.DB)
	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToSquare", "x INT PRIMARY KEY, xsquared INT", n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, func(row int) parser.Datum {
			return parser.NewDInt(parser.DInt(row * row))
		}),
	)
	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToStr", "y INT PRIMARY KEY, str STRING", n*n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, sqlutils.RowEnglishFn),
	)
	// Split the table into multiple ranges, with each range having a single
	// replica on a certain node. This forces the query to be distributed.
	//
	// TODO(radu): this approach should be generalized into test infrastructure
	// (perhaps by adding functionality to logic tests).
	// TODO(radu): we should verify that the plan is indeed distributed as
	// intended.
	descNumToStr := sqlbase.GetTableDescriptor(cdb, "test", "NumToStr")

	// split introduces a split and moves the right range to a given node.
	split := func(val int, targetNode int) {
		pik, err := sqlbase.MakePrimaryIndexKey(descNumToStr, val)
		if err != nil {
			t.Fatal(err)
		}
		splitKey := keys.MakeRowSentinelKey(pik)
		_, rightRange, err := tc.Server(0).SplitRange(splitKey)
		if err != nil {
			t.Fatal(err)
		}
		splitKey = rightRange.StartKey.AsRawKey()
		rightRange, err = tc.AddReplicas(splitKey, tc.Target(targetNode))
		if err != nil {
			t.Fatal(err)
		}
		// This transfer is necessary to avoid waiting for the lease to expire when
		// removing the first replica.
		if err := tc.TransferRangeLease(rightRange, tc.Target(targetNode)); err != nil {
			t.Fatal(err)
		}
		if _, err := tc.RemoveReplicas(splitKey, tc.Target(0)); err != nil {
			t.Fatal(err)
		}
	}
	// split moves the right range, so we split things back to front.
	for i := numNodes - 1; i > 0; i-- {
		split(n*n/numNodes*i, i)
	}

	r := sqlutils.MakeSQLRunner(t, tc.ServerConn(0))
	// Force a single connection so the session setting below applies to every
	// subsequent query on this runner.
	r.DB.SetMaxOpenConns(1)
	r.Exec("SET DIST_SQL = ALWAYS")
	res := r.QueryStr("SELECT x, str FROM NumToSquare JOIN NumToStr ON y = xsquared")
	// Verify that res contains one entry for each integer, with the string
	// representation of its square, e.g.:
	//  [1, one]
	//  [2, four]
	//  [3, nine]
	//  [4, one-six]
	// (but not necessarily in order).
	if len(res) != n {
		t.Fatalf("expected %d rows, got %d", n, len(res))
	}
	resMap := make(map[int]string)
	for _, row := range res {
		if len(row) != 2 {
			t.Fatalf("invalid row %v", row)
		}
		n, err := strconv.Atoi(row[0])
		if err != nil {
			t.Fatalf("error parsing row %v: %s", row, err)
		}
		resMap[n] = row[1]
	}
	for i := 1; i <= n; i++ {
		if resMap[i] != sqlutils.IntToEnglish(i*i) {
			t.Errorf("invalid string for %d: %s", i, resMap[i])
		}
	}

	// checkRes verifies that res is a single row with a single column equal to
	// the string representation of exp.
	checkRes := func(exp int) bool {
		return len(res) == 1 && len(res[0]) == 1 && res[0][0] == strconv.Itoa(exp)
	}

	// Sum the numbers in the NumToStr table.
	res = r.QueryStr("SELECT SUM(y) FROM NumToStr")
	if exp := n * n * (n*n + 1) / 2; !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}

	// Count the rows in the NumToStr table.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr")
	if !checkRes(n * n) {
		t.Errorf("expected [[%d]], got %s", n*n, res)
	}

	// Count how many numbers contain the digit 5; the digit-by-digit loop below
	// computes the expected count independently of the SQL LIKE filter.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr WHERE str LIKE '%five%'")
	exp := 0
	for i := 1; i <= n*n; i++ {
		for x := i; x > 0; x /= 10 {
			if x%10 == 5 {
				exp++
				break
			}
		}
	}
	if !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}
}
// checkDistAggregationInfo tests that a flow with multiple local stages and a // final stage (in accordance with per DistAggregationInfo) gets the same result // with a naive aggregation flow that has a single non-distributed stage. // // Both types of flows are set up and ran against the first numRows of the given // table. We assume the table's first column is the primary key, with values // from 1 to numRows. A non-PK column that works with the function is chosen. func checkDistAggregationInfo( t *testing.T, srv serverutils.TestServerInterface, tableDesc *sqlbase.TableDescriptor, colIdx int, numRows int, fn distsqlrun.AggregatorSpec_Func, info DistAggregationInfo, ) { colType := tableDesc.Columns[colIdx].Type makeTableReader := func(startPK, endPK int, streamID int) distsqlrun.ProcessorSpec { tr := distsqlrun.TableReaderSpec{ Table: *tableDesc, OutputColumns: []uint32{uint32(colIdx)}, Spans: make([]distsqlrun.TableReaderSpan, 1), } var err error tr.Spans[0].Span.Key, err = sqlbase.MakePrimaryIndexKey(tableDesc, startPK) if err != nil { t.Fatal(err) } tr.Spans[0].Span.EndKey, err = sqlbase.MakePrimaryIndexKey(tableDesc, endPK) if err != nil { t.Fatal(err) } return distsqlrun.ProcessorSpec{ Core: distsqlrun.ProcessorCoreUnion{TableReader: &tr}, Output: []distsqlrun.OutputRouterSpec{{ Type: distsqlrun.OutputRouterSpec_PASS_THROUGH, Streams: []distsqlrun.StreamEndpointSpec{ {Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(streamID)}, }, }}, } } // First run a flow that aggregates all the rows without any local stages. 
rowsNonDist := runTestFlow( t, srv, makeTableReader(1, numRows+1, 0), distsqlrun.ProcessorSpec{ Input: []distsqlrun.InputSyncSpec{{ Type: distsqlrun.InputSyncSpec_UNORDERED, ColumnTypes: []sqlbase.ColumnType{colType}, Streams: []distsqlrun.StreamEndpointSpec{ {Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: 0}, }, }}, Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{ Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: fn, ColIdx: 0}}, }}, Output: []distsqlrun.OutputRouterSpec{{ Type: distsqlrun.OutputRouterSpec_PASS_THROUGH, Streams: []distsqlrun.StreamEndpointSpec{ {Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE}, }, }}, }, ) // Now run a flow with 4 separate table readers, each with its own local // stage, all feeding into a single final stage. numParallel := 4 // The type outputted by the local stage can be different than the input type // (e.g. DECIMAL instead of INT). _, intermediaryType, err := distsqlrun.GetAggregateInfo(fn, colType) if err != nil { t.Fatal(err) } if numParallel < numRows { numParallel = numRows } finalProc := distsqlrun.ProcessorSpec{ Input: []distsqlrun.InputSyncSpec{{ Type: distsqlrun.InputSyncSpec_UNORDERED, ColumnTypes: []sqlbase.ColumnType{intermediaryType}, }}, Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{ Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.FinalStage, ColIdx: 0}}, }}, Output: []distsqlrun.OutputRouterSpec{{ Type: distsqlrun.OutputRouterSpec_PASS_THROUGH, Streams: []distsqlrun.StreamEndpointSpec{ {Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE}, }, }}, } var procs []distsqlrun.ProcessorSpec for i := 0; i < numParallel; i++ { tr := makeTableReader(1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel, 2*i) agg := distsqlrun.ProcessorSpec{ Input: []distsqlrun.InputSyncSpec{{ Type: distsqlrun.InputSyncSpec_UNORDERED, ColumnTypes: []sqlbase.ColumnType{colType}, Streams: []distsqlrun.StreamEndpointSpec{ {Type: 
distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2 * i)}, }, }}, Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{ Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.LocalStage, ColIdx: 0}}, }}, Output: []distsqlrun.OutputRouterSpec{{ Type: distsqlrun.OutputRouterSpec_PASS_THROUGH, Streams: []distsqlrun.StreamEndpointSpec{ {Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2*i + 1)}, }, }}, } procs = append(procs, tr, agg) finalProc.Input[0].Streams = append(finalProc.Input[0].Streams, distsqlrun.StreamEndpointSpec{ Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2*i + 1), }) } procs = append(procs, finalProc) rowsDist := runTestFlow(t, srv, procs...) if len(rowsDist[0]) != len(rowsNonDist[0]) { t.Errorf("different row lengths (dist: %d non-dist: %d)", len(rowsDist[0]), len(rowsNonDist[0])) } else { for i := range rowsDist[0] { tDist := rowsDist[0][i].Type.String() tNonDist := rowsNonDist[0][i].Type.String() if tDist != tNonDist { t.Errorf("different type for column %d (dist: %s non-dist: %s)", i, tDist, tNonDist) } } } if rowsDist.String() != rowsNonDist.String() { t.Errorf("different results\nw/o local stage: %s\nwith local stage: %s", rowsNonDist, rowsDist) } }