func dumpTable(w io.Writer, conn *sqlConn, origDBName, origTableName string) error { const limit = 100 // Escape names since they can't be used in placeholders. dbname := parser.Name(origDBName).String() tablename := parser.Name(origTableName).String() if err := conn.Exec(fmt.Sprintf("SET DATABASE = %s", dbname), nil); err != nil { return err } // Fetch all table metadata in a transaction and its time to guarantee it // doesn't change between the various SHOW statements. if err := conn.Exec("BEGIN", nil); err != nil { return err } vals, err := conn.QueryRow("SELECT cluster_logical_timestamp()", nil) if err != nil { return err } clusterTS := string(vals[0].([]byte)) // A previous version of the code did a SELECT on system.descriptor. This // required the SELECT privilege to the descriptor table, which only root // has. Allowing non-root to do this would let users see other users' table // descriptors which is a problem in multi-tenancy. // Fetch column types. rows, err := conn.Query(fmt.Sprintf("SHOW COLUMNS FROM %s", tablename), nil) if err != nil { return err } vals = make([]driver.Value, 2) coltypes := make(map[string]string) for { if err := rows.Next(vals); err == io.EOF { break } else if err != nil { return err } nameI, typI := vals[0], vals[1] name, ok := nameI.(string) if !ok { return fmt.Errorf("unexpected value: %T", nameI) } typ, ok := typI.(string) if !ok { return fmt.Errorf("unexpected value: %T", typI) } coltypes[name] = typ } if err := rows.Close(); err != nil { return err } // index holds the names, in order, of the primary key columns. var index []string // Primary index is always the first index returned by SHOW INDEX. rows, err = conn.Query(fmt.Sprintf("SHOW INDEX FROM %s", tablename), nil) if err != nil { return err } vals = make([]driver.Value, 5) var primaryIndex string // Find the primary index columns. for { if err := rows.Next(vals); err == io.EOF { break } else if err != nil { return err } b, ok := vals[1].(string) if !ok { return fmt.Errorf("unexpected value: %T", vals[1]) } if primaryIndex == "" { primaryIndex = b } else if primaryIndex != b { break } b, ok = vals[4].(string) if !ok { return fmt.Errorf("unexpected value: %T", vals[4]) } index = append(index, parser.Name(b).String()) } if err := rows.Close(); err != nil { return err } if len(index) == 0 { return fmt.Errorf("no primary key index found") } indexes := strings.Join(index, ", ") // Build the SELECT query. var sbuf bytes.Buffer fmt.Fprintf(&sbuf, "SELECT %s, * FROM %s@%s AS OF SYSTEM TIME %s", indexes, tablename, primaryIndex, clusterTS) var wbuf bytes.Buffer fmt.Fprintf(&wbuf, " WHERE ROW (%s) > ROW (", indexes) for i := range index { if i > 0 { wbuf.WriteString(", ") } fmt.Fprintf(&wbuf, "$%d", i+1) } wbuf.WriteString(")") // No WHERE clause first time, so add a place to inject it. fmt.Fprintf(&sbuf, "%%s ORDER BY %s LIMIT %d", indexes, limit) bs := sbuf.String() vals, err = conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE %s", tablename), nil) if err != nil { return err } create := vals[1].(string) if _, err := w.Write([]byte(create)); err != nil { return err } if _, err := w.Write([]byte(";\n")); err != nil { return err } if err := conn.Exec("COMMIT", nil); err != nil { return err } // pk holds the last values of the fetched primary keys var pk []driver.Value q := fmt.Sprintf(bs, "") for { rows, err := conn.Query(q, pk) if err != nil { return err } cols := rows.Columns() pkcols := cols[:len(index)] cols = cols[len(index):] inserts := make([][]string, 0, limit) i := 0 for i < limit { vals := make([]driver.Value, len(cols)+len(pkcols)) if err := rows.Next(vals); err == io.EOF { break } else if err != nil { return err } if pk == nil { q = fmt.Sprintf(bs, wbuf.String()) } pk = vals[:len(index)] vals = vals[len(index):] ivals := make([]string, len(vals)) // Values need to be correctly encoded for INSERT statements in a text file. for si, sv := range vals { switch t := sv.(type) { case nil: ivals[si] = "NULL" case bool: ivals[si] = parser.MakeDBool(parser.DBool(t)).String() case int64: ivals[si] = parser.NewDInt(parser.DInt(t)).String() case float64: ivals[si] = parser.NewDFloat(parser.DFloat(t)).String() case string: ivals[si] = parser.NewDString(t).String() case []byte: switch ct := coltypes[cols[si]]; ct { case "INTERVAL": ivals[si] = fmt.Sprintf("'%s'", t) case "BYTES": ivals[si] = parser.NewDBytes(parser.DBytes(t)).String() default: // STRING and DECIMAL types can have optional length // suffixes, so only examine the prefix of the type. if strings.HasPrefix(coltypes[cols[si]], "STRING") { ivals[si] = parser.NewDString(string(t)).String() } else if strings.HasPrefix(coltypes[cols[si]], "DECIMAL") { ivals[si] = string(t) } else { panic(errors.Errorf("unknown []byte type: %s, %v: %s", t, cols[si], coltypes[cols[si]])) } } case time.Time: var d parser.Datum ct := coltypes[cols[si]] switch ct { case "DATE": d = parser.NewDDateFromTime(t, time.UTC) case "TIMESTAMP": d = parser.MakeDTimestamp(t, time.Nanosecond) case "TIMESTAMP WITH TIME ZONE": d = parser.MakeDTimestampTZ(t, time.Nanosecond) default: panic(errors.Errorf("unknown timestamp type: %s, %v: %s", t, cols[si], coltypes[cols[si]])) } ivals[si] = fmt.Sprintf("'%s'", d) default: panic(errors.Errorf("unknown field type: %T (%s)", t, cols[si])) } } inserts = append(inserts, ivals) i++ } for si, sv := range pk { b, ok := sv.([]byte) if ok && strings.HasPrefix(coltypes[pkcols[si]], "STRING") { // Primary key strings need to be converted to a go string, but not SQL // encoded since they aren't being written to a text file. pk[si] = string(b) } } if err := rows.Close(); err != nil { return err } if i == 0 { break } fmt.Fprintf(w, "\nINSERT INTO %s VALUES", tablename) for idx, values := range inserts { if idx > 0 { fmt.Fprint(w, ",") } fmt.Fprint(w, "\n\t(") for vi, v := range values { if vi > 0 { fmt.Fprint(w, ", ") } fmt.Fprint(w, v) } fmt.Fprint(w, ")") } fmt.Fprintln(w, ";") if i < limit { break } } return nil }
// decodeOidDatum decodes bytes with specified Oid and format code into // a datum. func decodeOidDatum(id oid.Oid, code formatCode, b []byte) (parser.Datum, error) { var d parser.Datum switch id { case oid.T_bool: switch code { case formatText: v, err := strconv.ParseBool(string(b)) if err != nil { return d, err } d = parser.MakeDBool(parser.DBool(v)) case formatBinary: switch b[0] { case 0: d = parser.MakeDBool(false) case 1: d = parser.MakeDBool(true) default: return d, errors.Errorf("unsupported binary bool: %q", b) } default: return d, errors.Errorf("unsupported bool format code: %d", code) } case oid.T_int2: switch code { case formatText: i, err := strconv.ParseInt(string(b), 10, 64) if err != nil { return d, err } d = parser.NewDInt(parser.DInt(i)) case formatBinary: if len(b) < 2 { return d, errors.Errorf("int2 requires 2 bytes for binary format") } i := int16(binary.BigEndian.Uint16(b)) d = parser.NewDInt(parser.DInt(i)) default: return d, errors.Errorf("unsupported int2 format code: %d", code) } case oid.T_int4: switch code { case formatText: i, err := strconv.ParseInt(string(b), 10, 64) if err != nil { return d, err } d = parser.NewDInt(parser.DInt(i)) case formatBinary: if len(b) < 4 { return d, errors.Errorf("int4 requires 4 bytes for binary format") } i := int32(binary.BigEndian.Uint32(b)) d = parser.NewDInt(parser.DInt(i)) default: return d, errors.Errorf("unsupported int4 format code: %d", code) } case oid.T_int8: switch code { case formatText: i, err := strconv.ParseInt(string(b), 10, 64) if err != nil { return d, err } d = parser.NewDInt(parser.DInt(i)) case formatBinary: if len(b) < 8 { return d, errors.Errorf("int8 requires 8 bytes for binary format") } i := int64(binary.BigEndian.Uint64(b)) d = parser.NewDInt(parser.DInt(i)) default: return d, errors.Errorf("unsupported int8 format code: %d", code) } case oid.T_float4: switch code { case formatText: f, err := strconv.ParseFloat(string(b), 64) if err != nil { return d, err } d = parser.NewDFloat(parser.DFloat(f)) case formatBinary: if len(b) < 4 { return d, errors.Errorf("float4 requires 4 bytes for binary format") } f := math.Float32frombits(binary.BigEndian.Uint32(b)) d = parser.NewDFloat(parser.DFloat(f)) default: return d, errors.Errorf("unsupported float4 format code: %d", code) } case oid.T_float8: switch code { case formatText: f, err := strconv.ParseFloat(string(b), 64) if err != nil { return d, err } d = parser.NewDFloat(parser.DFloat(f)) case formatBinary: if len(b) < 8 { return d, errors.Errorf("float8 requires 8 bytes for binary format") } f := math.Float64frombits(binary.BigEndian.Uint64(b)) d = parser.NewDFloat(parser.DFloat(f)) default: return d, errors.Errorf("unsupported float8 format code: %d", code) } case oid.T_numeric: switch code { case formatText: dd := &parser.DDecimal{} if _, ok := dd.SetString(string(b)); !ok { return nil, errors.Errorf("could not parse string %q as decimal", b) } d = dd case formatBinary: r := bytes.NewReader(b) alloc := struct { pgNum pgNumeric i16 int16 dd parser.DDecimal }{} for _, ptr := range []interface{}{ &alloc.pgNum.ndigits, &alloc.pgNum.weight, &alloc.pgNum.sign, &alloc.pgNum.dscale, } { if err := binary.Read(r, binary.BigEndian, ptr); err != nil { return d, err } } if alloc.pgNum.ndigits > 0 { decDigits := make([]byte, 0, alloc.pgNum.ndigits*pgDecDigits) nextDigit := func() error { if err := binary.Read(r, binary.BigEndian, &alloc.i16); err != nil { return err } numZeroes := pgDecDigits for i16 := alloc.i16; i16 > 0; i16 /= 10 { numZeroes-- } for ; numZeroes > 0; numZeroes-- { decDigits = append(decDigits, '0') } return nil } for i := int16(0); i < alloc.pgNum.ndigits-1; i++ { if err := nextDigit(); err != nil { return d, err } if alloc.i16 > 0 { decDigits = strconv.AppendUint(decDigits, uint64(alloc.i16), 10) } } // The last digit may contain padding, which we need to deal with. if err := nextDigit(); err != nil { return d, err } dscale := (alloc.pgNum.ndigits - (alloc.pgNum.weight + 1)) * pgDecDigits if overScale := dscale - alloc.pgNum.dscale; overScale > 0 { dscale -= overScale for i := int16(0); i < overScale; i++ { alloc.i16 /= 10 } } decDigits = strconv.AppendUint(decDigits, uint64(alloc.i16), 10) decString := string(decDigits) if _, ok := alloc.dd.UnscaledBig().SetString(decString, 10); !ok { return nil, errors.Errorf("could not parse string %q as decimal", decString) } alloc.dd.SetScale(inf.Scale(dscale)) } switch alloc.pgNum.sign { case pgNumericPos: case pgNumericNeg: alloc.dd.Neg(&alloc.dd.Dec) default: return d, errors.Errorf("unsupported numeric sign: %d", alloc.pgNum.sign) } d = &alloc.dd default: return d, errors.Errorf("unsupported numeric format code: %d", code) } case oid.T_text, oid.T_varchar: switch code { case formatText, formatBinary: d = parser.NewDString(string(b)) default: return d, errors.Errorf("unsupported text format code: %d", code) } case oid.T_bytea: switch code { case formatText: // http://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667 // Code cribbed from github.com/lib/pq. // We only support hex encoding. if len(b) >= 2 && bytes.Equal(b[:2], []byte("\\x")) { b = b[2:] // trim off leading "\\x" result := make([]byte, hex.DecodedLen(len(b))) _, err := hex.Decode(result, b) if err != nil { return d, err } d = parser.NewDBytes(parser.DBytes(result)) } else { return d, errors.Errorf("unsupported bytea encoding: %q", b) } case formatBinary: d = parser.NewDBytes(parser.DBytes(b)) default: return d, errors.Errorf("unsupported bytea format code: %d", code) } case oid.T_timestamp: switch code { case formatText: ts, err := parseTs(string(b)) if err != nil { return d, errors.Errorf("could not parse string %q as timestamp", b) } d = parser.MakeDTimestamp(ts, time.Microsecond) case formatBinary: if len(b) < 8 { return d, errors.Errorf("timestamp requires 8 bytes for binary format") } i := int64(binary.BigEndian.Uint64(b)) d = parser.MakeDTimestamp(pgBinaryToTime(i), time.Microsecond) default: return d, errors.Errorf("unsupported timestamp format code: %d", code) } case oid.T_timestamptz: switch code { case formatText: ts, err := parseTs(string(b)) if err != nil { return d, errors.Errorf("could not parse string %q as timestamp", b) } d = parser.MakeDTimestampTZ(ts, time.Microsecond) case formatBinary: if len(b) < 8 { return d, errors.Errorf("timestamptz requires 8 bytes for binary format") } i := int64(binary.BigEndian.Uint64(b)) d = parser.MakeDTimestampTZ(pgBinaryToTime(i), time.Microsecond) default: return d, errors.Errorf("unsupported timestamptz format code: %d", code) } case oid.T_date: switch code { case formatText: ts, err := parseTs(string(b)) if err != nil { res, err := parser.ParseDDate(string(b), time.UTC) if err != nil { return d, errors.Errorf("could not parse string %q as date", b) } d = res } else { daysSinceEpoch := ts.Unix() / secondsInDay d = parser.NewDDate(parser.DDate(daysSinceEpoch)) } case formatBinary: if len(b) < 4 { return d, errors.Errorf("date requires 4 bytes for binary format") } i := int32(binary.BigEndian.Uint32(b)) d = pgBinaryToDate(i) default: return d, errors.Errorf("unsupported date format code: %d", code) } case oid.T_interval: switch code { case formatText: d, err := parser.ParseDInterval(string(b)) if err != nil { return d, errors.Errorf("could not parse string %q as interval", b) } return d, nil default: return d, errors.Errorf("unsupported interval format code: %d", code) } default: return d, errors.Errorf("unsupported OID: %v", id) } return d, nil }