func TestClusterFlow(t *testing.T) {
	defer leaktest.AfterTest(t)()
	const numRows = 100

	args := base.TestClusterArgs{ReplicationMode: base.ReplicationManual}
	tc := serverutils.StartTestCluster(t, 3, args)
	defer tc.Stopper().Stop()

	sumDigitsFn := func(row int) parser.Datum {
		sum := 0
		for row > 0 {
			sum += row % 10
			row /= 10
		}
		return parser.NewDInt(parser.DInt(sum))
	}

	sqlutils.CreateTable(t, tc.ServerConn(0), "t",
		"num INT PRIMARY KEY, digitsum INT, numstr STRING, INDEX s (digitsum)",
		numRows,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, sumDigitsFn, sqlutils.RowEnglishFn))

	kvDB := tc.Server(0).KVClient().(*client.DB)
	desc := sqlbase.GetTableDescriptor(kvDB, "test", "t")
	makeIndexSpan := func(start, end int) TableReaderSpan {
		var span roachpb.Span
		prefix := roachpb.Key(sqlbase.MakeIndexKeyPrefix(desc, desc.Indexes[0].ID))
		span.Key = append(prefix, encoding.EncodeVarintAscending(nil, int64(start))...)
		span.EndKey = append(span.EndKey, prefix...)
		span.EndKey = append(span.EndKey, encoding.EncodeVarintAscending(nil, int64(end))...)
		return TableReaderSpan{Span: span}
	}

	// Set up table readers on three hosts feeding data into a join reader on
	// the third host. This is a basic test for the distributed flow
	// infrastructure, including local and remote streams.
	//
	// Note that the ranges won't necessarily be local to the table readers, but
	// that doesn't matter for the purposes of this test.
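
	// A sketch of the flow wired up below (inferred from the specs; not part
	// of the original comments):
	//
	//   node 0: TableReader tr1 (digitsum span [0,8))    --mirror--> stream 0 --mailbox--> node 2
	//   node 1: TableReader tr2 (digitsum span [8,12))   --mirror--> stream 1 --mailbox--> node 2
	//   node 2: TableReader tr3 (digitsum span [12,100)) --mirror--> stream 2 (local)
	//
	//   node 2: ordered sync (column 1, the digit sum, ASC) over streams 0, 1, 2
	//             -> JoinReader (outputs column 2, numstr)
	//             -> mirror router -> SimpleResponse mailbox, streamed back to the caller.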

	// Start a span (useful to look at spans using LightStep).
	sp, err := tracing.JoinOrNew(tracing.NewTracer(), nil, "cluster test")
	if err != nil {
		t.Fatal(err)
	}
	ctx := opentracing.ContextWithSpan(context.Background(), sp)
	defer sp.Finish()

	tr1 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(0, 8)},
	}

	tr2 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(8, 12)},
	}

	tr3 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(12, 100)},
	}

	jr := JoinReaderSpec{
		Table:         *desc,
		OutputColumns: []uint32{2},
	}

	txn := client.NewTxn(ctx, *kvDB)
	fid := FlowID{uuid.MakeV4()}

	req1 := &SetupFlowRequest{Txn: txn.Proto}
	req1.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{{
			Core: ProcessorCoreUnion{TableReader: &tr1},
			Output: []OutputRouterSpec{{
				Type: OutputRouterSpec_MIRROR,
				Streams: []StreamEndpointSpec{
					{StreamID: 0, Mailbox: &MailboxSpec{TargetAddr: tc.Server(2).ServingAddr()}},
				},
			}},
		}},
	}

	req2 := &SetupFlowRequest{Txn: txn.Proto}
	req2.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{{
			Core: ProcessorCoreUnion{TableReader: &tr2},
			Output: []OutputRouterSpec{{
				Type: OutputRouterSpec_MIRROR,
				Streams: []StreamEndpointSpec{
					{StreamID: 1, Mailbox: &MailboxSpec{TargetAddr: tc.Server(2).ServingAddr()}},
				},
			}},
		}},
	}

	req3 := &SetupFlowRequest{Txn: txn.Proto}
	req3.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{
			{
				Core: ProcessorCoreUnion{TableReader: &tr3},
				Output: []OutputRouterSpec{{
					Type: OutputRouterSpec_MIRROR,
					Streams: []StreamEndpointSpec{
						{StreamID: StreamID(2)},
					},
				}},
			},
			{
				Input: []InputSyncSpec{{
					Type:     InputSyncSpec_ORDERED,
					Ordering: Ordering{Columns: []Ordering_Column{{1, Ordering_Column_ASC}}},
					Streams: []StreamEndpointSpec{
						{StreamID: 0, Mailbox: &MailboxSpec{}},
						{StreamID: 1, Mailbox: &MailboxSpec{}},
						{StreamID: StreamID(2)},
					},
				}},
				Core: ProcessorCoreUnion{JoinReader: &jr},
				Output: []OutputRouterSpec{{
					Type:    OutputRouterSpec_MIRROR,
					Streams: []StreamEndpointSpec{{Mailbox: &MailboxSpec{SimpleResponse: true}}},
				}},
			},
		},
	}

	if err := SetFlowRequestTrace(ctx, req1); err != nil {
		t.Fatal(err)
	}
	if err := SetFlowRequestTrace(ctx, req2); err != nil {
		t.Fatal(err)
	}
	if err := SetFlowRequestTrace(ctx, req3); err != nil {
		t.Fatal(err)
	}

	var clients []DistSQLClient
	for i := 0; i < 3; i++ {
		s := tc.Server(i)
		conn, err := s.RPCContext().GRPCDial(s.ServingAddr())
		if err != nil {
			t.Fatal(err)
		}
		clients = append(clients, NewDistSQLClient(conn))
	}

	if log.V(1) {
		log.Infof(ctx, "Setting up flow on 0")
	}
	if resp, err := clients[0].SetupFlow(ctx, req1); err != nil {
		t.Fatal(err)
	} else if resp.Error != nil {
		t.Fatal(resp.Error)
	}

	if log.V(1) {
		log.Infof(ctx, "Setting up flow on 1")
	}
	if resp, err := clients[1].SetupFlow(ctx, req2); err != nil {
		t.Fatal(err)
	} else if resp.Error != nil {
		t.Fatal(resp.Error)
	}

	if log.V(1) {
		log.Infof(ctx, "Running flow on 2")
	}
	stream, err := clients[2].RunSimpleFlow(ctx, req3)
	if err != nil {
		t.Fatal(err)
	}

	var decoder StreamDecoder
	var rows sqlbase.EncDatumRows
	for {
		msg, err := stream.Recv()
		if err != nil {
			if err == io.EOF {
				break
			}
			t.Fatal(err)
		}
		err = decoder.AddMessage(msg)
		if err != nil {
			t.Fatal(err)
		}
		rows = testGetDecodedRows(t, &decoder, rows)
	}
	if done, trailerErr := decoder.IsDone(); !done {
		t.Fatal("stream not done")
	} else if trailerErr != nil {
		t.Fatal("error in the stream trailer:", trailerErr)
	}

	// The result should be all the numbers in string form, ordered by the
	// digit sum (and then by number), e.g. digit sum 1: 'one', 'one-zero',
	// 'one-zero-zero'; digit sum 2: 'two', 'one-one', 'two-zero'; and so on.
	var results []string
	for sum := 1; sum <= 50; sum++ {
		for i := 1; i <= numRows; i++ {
			if int(*sumDigitsFn(i).(*parser.DInt)) == sum {
				results = append(results, fmt.Sprintf("['%s']", sqlutils.IntToEnglish(i)))
			}
		}
	}
	expected := strings.Join(results, " ")
	expected = "[" + expected + "]"
	if rowStr := rows.String(); rowStr != expected {
		t.Errorf("Result: %s\n Expected: %s\n", rowStr, expected)
	}
}
func TestDistSQLJoinAndAgg(t *testing.T) {
	defer leaktest.AfterTest(t)()

	// This test sets up a distributed join between two tables:
	//  - a NumToSquare table of size N that maps integers from 1 to N to their
	//    squares
	//  - a NumToStr table of size N^2 that maps integers to their string
	//    representations. This table is split and distributed to all the nodes.
	const n = 100
	const numNodes = 5

	tc := serverutils.StartTestCluster(t, numNodes, base.TestClusterArgs{
		ReplicationMode: base.ReplicationManual,
		ServerArgs: base.TestServerArgs{
			UseDatabase: "test",
		},
	})
	defer tc.Stopper().Stop()
	cdb := tc.Server(0).KVClient().(*client.DB)

	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToSquare", "x INT PRIMARY KEY, xsquared INT",
		n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, func(row int) parser.Datum {
			return parser.NewDInt(parser.DInt(row * row))
		}),
	)

	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToStr", "y INT PRIMARY KEY, str STRING",
		n*n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, sqlutils.RowEnglishFn),
	)

	// Split the table into multiple ranges, with each range having a single
	// replica on a certain node. This forces the query to be distributed.
	//
	// TODO(radu): this approach should be generalized into test infrastructure
	// (perhaps by adding functionality to logic tests).
	// TODO(radu): we should verify that the plan is indeed distributed as
	// intended.
	descNumToStr := sqlbase.GetTableDescriptor(cdb, "test", "NumToStr")

	// split introduces a split and moves the right range to a given node.
	split := func(val int, targetNode int) {
		pik, err := sqlbase.MakePrimaryIndexKey(descNumToStr, val)
		if err != nil {
			t.Fatal(err)
		}
		splitKey := keys.MakeRowSentinelKey(pik)
		_, rightRange, err := tc.Server(0).SplitRange(splitKey)
		if err != nil {
			t.Fatal(err)
		}
		splitKey = rightRange.StartKey.AsRawKey()
		rightRange, err = tc.AddReplicas(splitKey, tc.Target(targetNode))
		if err != nil {
			t.Fatal(err)
		}
		// This transfer is necessary to avoid waiting for the lease to expire when
		// removing the first replica.
		if err := tc.TransferRangeLease(rightRange, tc.Target(targetNode)); err != nil {
			t.Fatal(err)
		}
		if _, err := tc.RemoveReplicas(splitKey, tc.Target(0)); err != nil {
			t.Fatal(err)
		}
	}
	// split moves the right range, so we split things back to front.
	for i := numNodes - 1; i > 0; i-- {
		split(n*n/numNodes*i, i)
	}
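
	// With n = 100 and numNodes = 5, the loop above splits NumToStr at primary
	// keys 8000, 6000, 4000 and 2000 (back to front), leaving roughly a fifth
	// of the table on each node: [1,2000) on node 0, [2000,4000) on node 1, and
	// so on, up to [8000,10000] on node 4. (This layout is inferred from the
	// code above; it is not asserted by the test.)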

	r := sqlutils.MakeSQLRunner(t, tc.ServerConn(0))
	r.DB.SetMaxOpenConns(1)
	r.Exec("SET DIST_SQL = ALWAYS")
	res := r.QueryStr("SELECT x, str FROM NumToSquare JOIN NumToStr ON y = xsquared")

	// Verify that res contains one entry for each integer, with the string
	// representation of its square, e.g.:
	//  [1, one]
	//  [2, four]
	//  [3, nine]
	//  [4, one-six]
	// (but not necessarily in order).
	if len(res) != n {
		t.Fatalf("expected %d rows, got %d", n, len(res))
	}
	resMap := make(map[int]string)
	for _, row := range res {
		if len(row) != 2 {
			t.Fatalf("invalid row %v", row)
		}
		n, err := strconv.Atoi(row[0])
		if err != nil {
			t.Fatalf("error parsing row %v: %s", row, err)
		}
		resMap[n] = row[1]
	}
	for i := 1; i <= n; i++ {
		if resMap[i] != sqlutils.IntToEnglish(i*i) {
			t.Errorf("invalid string for %d: %s", i, resMap[i])
		}
	}

	checkRes := func(exp int) bool {
		return len(res) == 1 && len(res[0]) == 1 && res[0][0] == strconv.Itoa(exp)
	}

	// Sum the numbers in the NumToStr table.
	res = r.QueryStr("SELECT SUM(y) FROM NumToStr")
	if exp := n * n * (n*n + 1) / 2; !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}

	// Count the rows in the NumToStr table.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr")
	if !checkRes(n * n) {
		t.Errorf("expected [[%d]], got %s", n*n, res)
	}

	// Count how many numbers contain the digit 5.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr WHERE str LIKE '%five%'")
	exp := 0
	for i := 1; i <= n*n; i++ {
		for x := i; x > 0; x /= 10 {
			if x%10 == 5 {
				exp++
				break
			}
		}
	}
	if !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}
}
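
// The two tests above re-derive their expected values with small digit loops
// (sumDigitsFn in TestClusterFlow and the LIKE '%five%' count in
// TestDistSQLJoinAndAgg). A minimal standalone sketch of the same arithmetic
// follows; these helpers are hypothetical and are not used by the tests above.

// digitSum returns the sum of the decimal digits of n.
func digitSum(n int) int {
	sum := 0
	for ; n > 0; n /= 10 {
		sum += n % 10
	}
	return sum
}

// containsDigit reports whether the decimal representation of n contains the
// digit d (the tests use d = 5).
func containsDigit(n, d int) bool {
	for ; n > 0; n /= 10 {
		if n%10 == d {
			return true
		}
	}
	return false
}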