예제 #1
0
파일: dump.go 프로젝트: hvaara/cockroach
// dumpTable writes a SQL dump of table origTableName in database origDBName
// to w: first the output of SHOW CREATE TABLE, then the table's rows as
// multi-row INSERT statements of at most 100 rows each.
//
// Rows are fetched in primary key order, one batch per query, and every batch
// query carries AS OF SYSTEM TIME with the cluster timestamp read at the
// start, so all batches read the table as of the same timestamp.
//
// Query failures, result values of an unexpected Go type, columns of an
// unsupported SQL type and write failures on w are all returned as errors.
func dumpTable(w io.Writer, conn *sqlConn, origDBName, origTableName string) error {
	const limit = 100

	// Escape names since they can't be used in placeholders.
	dbname := parser.Name(origDBName).String()
	tablename := parser.Name(origTableName).String()

	if err := conn.Exec(fmt.Sprintf("SET DATABASE = %s", dbname), nil); err != nil {
		return err
	}

	// Fetch all table metadata in a transaction and its time to guarantee it
	// doesn't change between the various SHOW statements.
	// NOTE(review): an error return between BEGIN and COMMIT leaves the
	// transaction open on conn; confirm callers discard conn on error.
	if err := conn.Exec("BEGIN", nil); err != nil {
		return err
	}

	vals, err := conn.QueryRow("SELECT cluster_logical_timestamp()", nil)
	if err != nil {
		return err
	}
	if len(vals) < 1 {
		return fmt.Errorf("no value returned for cluster_logical_timestamp()")
	}
	tsBytes, ok := vals[0].([]byte)
	if !ok {
		return fmt.Errorf("unexpected value: %T", vals[0])
	}
	clusterTS := string(tsBytes)

	// A previous version of the code did a SELECT on system.descriptor. This
	// required the SELECT privilege to the descriptor table, which only root
	// has. Allowing non-root to do this would let users see other users' table
	// descriptors which is a problem in multi-tenancy.

	// Fetch column types.
	coltypes, err := dumpTableColumnTypes(conn, tablename)
	if err != nil {
		return err
	}

	// index holds the names, in order, of the primary key columns.
	primaryIndex, index, err := dumpTablePrimaryIndex(conn, tablename)
	if err != nil {
		return err
	}
	if len(index) == 0 {
		return fmt.Errorf("no primary key index found")
	}
	indexes := strings.Join(index, ", ")

	// Build the batch queries by plain concatenation. The SQL text must not be
	// used as a fmt format string: a '%' inside an identifier would be
	// interpreted as a formatting verb and corrupt the query.
	selectClause := fmt.Sprintf("SELECT %s, * FROM %s@%s AS OF SYSTEM TIME %s", indexes, tablename, primaryIndex, clusterTS)
	orderClause := fmt.Sprintf(" ORDER BY %s LIMIT %d", indexes, limit)
	placeholders := make([]string, len(index))
	for i := range index {
		placeholders[i] = fmt.Sprintf("$%d", i+1)
	}
	whereClause := fmt.Sprintf(" WHERE ROW (%s) > ROW (%s)", indexes, strings.Join(placeholders, ", "))
	firstQuery := selectClause + orderClause
	nextQuery := selectClause + whereClause + orderClause

	vals, err = conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE %s", tablename), nil)
	if err != nil {
		return err
	}
	if len(vals) < 2 {
		return fmt.Errorf("unexpected number of values from SHOW CREATE TABLE: %d", len(vals))
	}
	create, ok := vals[1].(string)
	if !ok {
		return fmt.Errorf("unexpected value: %T", vals[1])
	}
	if _, err := io.WriteString(w, create+";\n"); err != nil {
		return err
	}

	if err := conn.Exec("COMMIT", nil); err != nil {
		return err
	}

	// pk holds the last values of the fetched primary keys.
	var pk []driver.Value
	// No WHERE clause first time; there is no previous key to resume from.
	q := firstQuery
	for {
		inserts, lastPK, err := dumpTableFetchBatch(conn, q, pk, len(index), coltypes, limit)
		if err != nil {
			return err
		}
		if len(inserts) == 0 {
			break
		}
		// Every later batch resumes strictly after the last key seen.
		q = nextQuery
		pk = lastPK

		// Assemble the whole statement in memory so that it reaches w in a
		// single Write whose error is checked.
		var out bytes.Buffer
		fmt.Fprintf(&out, "\nINSERT INTO %s VALUES", tablename)
		for idx, values := range inserts {
			if idx > 0 {
				out.WriteString(",")
			}
			out.WriteString("\n\t(")
			out.WriteString(strings.Join(values, ", "))
			out.WriteString(")")
		}
		out.WriteString(";\n")
		if _, err := w.Write(out.Bytes()); err != nil {
			return err
		}

		// A short batch means the table is exhausted.
		if len(inserts) < limit {
			break
		}
	}
	return nil
}

// dumpTableColumnTypes runs SHOW COLUMNS on the (already escaped) table name
// and returns a map from column name to the column type string. The result
// rows are closed on every return path.
func dumpTableColumnTypes(conn *sqlConn, tablename string) (coltypes map[string]string, err error) {
	rows, err := conn.Query(fmt.Sprintf("SHOW COLUMNS FROM %s", tablename), nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		// Report a Close failure only if nothing else went wrong first.
		if cerr := rows.Close(); cerr != nil && err == nil {
			coltypes, err = nil, cerr
		}
	}()

	vals := make([]driver.Value, 2)
	coltypes = make(map[string]string)
	for {
		if err := rows.Next(vals); err == io.EOF {
			break
		} else if err != nil {
			return nil, err
		}
		name, ok := vals[0].(string)
		if !ok {
			return nil, fmt.Errorf("unexpected value: %T", vals[0])
		}
		typ, ok := vals[1].(string)
		if !ok {
			return nil, fmt.Errorf("unexpected value: %T", vals[1])
		}
		coltypes[name] = typ
	}
	return coltypes, nil
}

// dumpTablePrimaryIndex runs SHOW INDEX on the (already escaped) table name.
// It returns the name of the first index listed, which is taken to be the
// primary index, and the escaped names of that index's columns in the order
// they are listed. The result rows are closed on every return path.
func dumpTablePrimaryIndex(conn *sqlConn, tablename string) (primaryIndex string, index []string, err error) {
	// Primary index is always the first index returned by SHOW INDEX.
	rows, err := conn.Query(fmt.Sprintf("SHOW INDEX FROM %s", tablename), nil)
	if err != nil {
		return "", nil, err
	}
	defer func() {
		// Report a Close failure only if nothing else went wrong first.
		if cerr := rows.Close(); cerr != nil && err == nil {
			primaryIndex, index, err = "", nil, cerr
		}
	}()

	vals := make([]driver.Value, 5)
	for {
		if err := rows.Next(vals); err == io.EOF {
			break
		} else if err != nil {
			return "", nil, err
		}
		name, ok := vals[1].(string)
		if !ok {
			return "", nil, fmt.Errorf("unexpected value: %T", vals[1])
		}
		if primaryIndex == "" {
			primaryIndex = name
		} else if primaryIndex != name {
			// First row of the next index: the primary index is complete.
			break
		}
		col, ok := vals[4].(string)
		if !ok {
			return "", nil, fmt.Errorf("unexpected value: %T", vals[4])
		}
		index = append(index, parser.Name(col).String())
	}
	return primaryIndex, index, nil
}

// dumpTableFetchBatch runs q with the placeholder values pk and reads at most
// limit rows. Each row is expected to start with numPK primary key values
// followed by the table's columns.
//
// It returns the table columns of every row encoded as SQL literals, and the
// primary key values of the last row read (nil when no row was read), ready
// to be passed as placeholders to the next batch query. The result rows are
// closed on every return path.
func dumpTableFetchBatch(
	conn *sqlConn, q string, pk []driver.Value, numPK int, coltypes map[string]string, limit int,
) (inserts [][]string, lastPK []driver.Value, err error) {
	rows, err := conn.Query(q, pk)
	if err != nil {
		return nil, nil, err
	}
	defer func() {
		// Report a Close failure only if nothing else went wrong first.
		if cerr := rows.Close(); cerr != nil && err == nil {
			inserts, lastPK, err = nil, nil, cerr
		}
	}()

	allCols := rows.Columns()
	if len(allCols) < numPK {
		return nil, nil, fmt.Errorf("expected at least %d result columns, got %d", numPK, len(allCols))
	}
	pkcols, cols := allCols[:numPK], allCols[numPK:]

	inserts = make([][]string, 0, limit)
	for len(inserts) < limit {
		// A fresh slice per row: lastPK keeps referring to the last row's values.
		vals := make([]driver.Value, len(allCols))
		if err := rows.Next(vals); err == io.EOF {
			break
		} else if err != nil {
			return nil, nil, err
		}
		lastPK = vals[:numPK]
		vals = vals[numPK:]
		ivals := make([]string, len(vals))
		// Values need to be correctly encoded for INSERT statements in a text file.
		for si, sv := range vals {
			s, err := dumpTableFormatValue(sv, cols[si], coltypes[cols[si]])
			if err != nil {
				return nil, nil, err
			}
			ivals[si] = s
		}
		inserts = append(inserts, ivals)
	}

	for si, sv := range lastPK {
		if b, ok := sv.([]byte); ok && strings.HasPrefix(coltypes[pkcols[si]], "STRING") {
			// Primary key strings need to be converted to a go string, but not SQL
			// encoded since they aren't being written to a text file.
			lastPK[si] = string(b)
		}
	}
	return inserts, lastPK, nil
}

// dumpTableFormatValue encodes one value read from column col, whose SHOW
// COLUMNS type string is ct, as text suitable for the VALUES list of an
// INSERT statement. It returns an error when the Go type of v, or the column
// type for a []byte or time.Time value, is not one it knows how to encode.
func dumpTableFormatValue(v driver.Value, col, ct string) (string, error) {
	switch t := v.(type) {
	case nil:
		return "NULL", nil
	case bool:
		return parser.MakeDBool(parser.DBool(t)).String(), nil
	case int64:
		return parser.NewDInt(parser.DInt(t)).String(), nil
	case float64:
		return parser.NewDFloat(parser.DFloat(t)).String(), nil
	case string:
		return parser.NewDString(t).String(), nil
	case []byte:
		switch {
		case ct == "INTERVAL":
			return fmt.Sprintf("'%s'", t), nil
		case ct == "BYTES":
			return parser.NewDBytes(parser.DBytes(t)).String(), nil
		// STRING and DECIMAL types can have optional length
		// suffixes, so only examine the prefix of the type.
		case strings.HasPrefix(ct, "STRING"):
			return parser.NewDString(string(t)).String(), nil
		case strings.HasPrefix(ct, "DECIMAL"):
			return string(t), nil
		default:
			return "", errors.Errorf("unknown []byte type: %s, %v: %s", t, col, ct)
		}
	case time.Time:
		var d parser.Datum
		switch ct {
		case "DATE":
			d = parser.NewDDateFromTime(t, time.UTC)
		case "TIMESTAMP":
			d = parser.MakeDTimestamp(t, time.Nanosecond)
		case "TIMESTAMP WITH TIME ZONE":
			d = parser.MakeDTimestampTZ(t, time.Nanosecond)
		default:
			return "", errors.Errorf("unknown timestamp type: %s, %v: %s", t, col, ct)
		}
		return fmt.Sprintf("'%s'", d), nil
	default:
		return "", errors.Errorf("unknown field type: %T (%s)", t, col)
	}
}
예제 #2
0
// decodeOidDatum decodes bytes with specified Oid and format code into
// a datum.
//
// Handled OIDs are bool, int2, int4, int8, float4, float8, numeric, text,
// varchar, bytea, timestamp, timestamptz and date in both text and binary
// format, and interval in text format only.
//
// An error is returned for any other OID or format code, and for input that
// is too short or cannot be parsed as the requested type.
func decodeOidDatum(id oid.Oid, code formatCode, b []byte) (parser.Datum, error) {
	var d parser.Datum
	switch id {
	case oid.T_bool:
		switch code {
		case formatText:
			v, err := strconv.ParseBool(string(b))
			if err != nil {
				return d, err
			}
			d = parser.MakeDBool(parser.DBool(v))
		case formatBinary:
			// Guard the b[0] access below: an empty payload must be an error,
			// not an index-out-of-range panic.
			if len(b) < 1 {
				return d, errors.Errorf("bool requires 1 byte for binary format")
			}
			switch b[0] {
			case 0:
				d = parser.MakeDBool(false)
			case 1:
				d = parser.MakeDBool(true)
			default:
				return d, errors.Errorf("unsupported binary bool: %q", b)
			}
		default:
			return d, errors.Errorf("unsupported bool format code: %d", code)
		}
	case oid.T_int2:
		switch code {
		case formatText:
			i, err := strconv.ParseInt(string(b), 10, 64)
			if err != nil {
				return d, err
			}
			d = parser.NewDInt(parser.DInt(i))
		case formatBinary:
			if len(b) < 2 {
				return d, errors.Errorf("int2 requires 2 bytes for binary format")
			}
			// Convert through int16 so the sign bit is preserved.
			i := int16(binary.BigEndian.Uint16(b))
			d = parser.NewDInt(parser.DInt(i))
		default:
			return d, errors.Errorf("unsupported int2 format code: %d", code)
		}
	case oid.T_int4:
		switch code {
		case formatText:
			i, err := strconv.ParseInt(string(b), 10, 64)
			if err != nil {
				return d, err
			}
			d = parser.NewDInt(parser.DInt(i))
		case formatBinary:
			if len(b) < 4 {
				return d, errors.Errorf("int4 requires 4 bytes for binary format")
			}
			// Convert through int32 so the sign bit is preserved.
			i := int32(binary.BigEndian.Uint32(b))
			d = parser.NewDInt(parser.DInt(i))
		default:
			return d, errors.Errorf("unsupported int4 format code: %d", code)
		}
	case oid.T_int8:
		switch code {
		case formatText:
			i, err := strconv.ParseInt(string(b), 10, 64)
			if err != nil {
				return d, err
			}
			d = parser.NewDInt(parser.DInt(i))
		case formatBinary:
			if len(b) < 8 {
				return d, errors.Errorf("int8 requires 8 bytes for binary format")
			}
			i := int64(binary.BigEndian.Uint64(b))
			d = parser.NewDInt(parser.DInt(i))
		default:
			return d, errors.Errorf("unsupported int8 format code: %d", code)
		}
	case oid.T_float4:
		switch code {
		case formatText:
			f, err := strconv.ParseFloat(string(b), 64)
			if err != nil {
				return d, err
			}
			d = parser.NewDFloat(parser.DFloat(f))
		case formatBinary:
			if len(b) < 4 {
				return d, errors.Errorf("float4 requires 4 bytes for binary format")
			}
			f := math.Float32frombits(binary.BigEndian.Uint32(b))
			d = parser.NewDFloat(parser.DFloat(f))
		default:
			return d, errors.Errorf("unsupported float4 format code: %d", code)
		}
	case oid.T_float8:
		switch code {
		case formatText:
			f, err := strconv.ParseFloat(string(b), 64)
			if err != nil {
				return d, err
			}
			d = parser.NewDFloat(parser.DFloat(f))
		case formatBinary:
			if len(b) < 8 {
				return d, errors.Errorf("float8 requires 8 bytes for binary format")
			}
			f := math.Float64frombits(binary.BigEndian.Uint64(b))
			d = parser.NewDFloat(parser.DFloat(f))
		default:
			return d, errors.Errorf("unsupported float8 format code: %d", code)
		}
	case oid.T_numeric:
		switch code {
		case formatText:
			dd := &parser.DDecimal{}
			if _, ok := dd.SetString(string(b)); !ok {
				return nil, errors.Errorf("could not parse string %q as decimal", b)
			}
			d = dd
		case formatBinary:
			r := bytes.NewReader(b)

			// Group the scratch values and the result in one struct.
			alloc := struct {
				pgNum pgNumeric
				i16   int16

				dd parser.DDecimal
			}{}

			// The payload starts with four big-endian header fields, read in
			// this order, followed by ndigits digit groups.
			for _, ptr := range []interface{}{
				&alloc.pgNum.ndigits,
				&alloc.pgNum.weight,
				&alloc.pgNum.sign,
				&alloc.pgNum.dscale,
			} {
				if err := binary.Read(r, binary.BigEndian, ptr); err != nil {
					return d, err
				}
			}

			if alloc.pgNum.ndigits > 0 {
				// Compute the capacity in int: the product can exceed the
				// range of ndigits' own narrow integer type, and a wrapped,
				// negative capacity would make make() panic.
				decDigits := make([]byte, 0, int(alloc.pgNum.ndigits)*int(pgDecDigits))
				// nextDigit reads the next digit group into alloc.i16 and
				// appends the leading zeroes needed to pad it to pgDecDigits
				// decimal characters. The caller appends the group itself.
				nextDigit := func() error {
					if err := binary.Read(r, binary.BigEndian, &alloc.i16); err != nil {
						return err
					}
					numZeroes := pgDecDigits
					for i16 := alloc.i16; i16 > 0; i16 /= 10 {
						numZeroes--
					}
					for ; numZeroes > 0; numZeroes-- {
						decDigits = append(decDigits, '0')
					}
					return nil
				}

				for i := int16(0); i < alloc.pgNum.ndigits-1; i++ {
					if err := nextDigit(); err != nil {
						return d, err
					}
					// A zero group is already fully written as padding.
					if alloc.i16 > 0 {
						decDigits = strconv.AppendUint(decDigits, uint64(alloc.i16), 10)
					}
				}

				// The last digit may contain padding, which we need to deal with.
				if err := nextDigit(); err != nil {
					return d, err
				}
				dscale := (alloc.pgNum.ndigits - (alloc.pgNum.weight + 1)) * pgDecDigits
				if overScale := dscale - alloc.pgNum.dscale; overScale > 0 {
					// Drop the decimal places beyond the declared dscale.
					dscale -= overScale
					for i := int16(0); i < overScale; i++ {
						alloc.i16 /= 10
					}
				}
				decDigits = strconv.AppendUint(decDigits, uint64(alloc.i16), 10)
				decString := string(decDigits)
				if _, ok := alloc.dd.UnscaledBig().SetString(decString, 10); !ok {
					return nil, errors.Errorf("could not parse string %q as decimal", decString)
				}
				alloc.dd.SetScale(inf.Scale(dscale))
			}

			switch alloc.pgNum.sign {
			case pgNumericPos:
			case pgNumericNeg:
				alloc.dd.Neg(&alloc.dd.Dec)
			default:
				return d, errors.Errorf("unsupported numeric sign: %d", alloc.pgNum.sign)
			}

			d = &alloc.dd
		default:
			return d, errors.Errorf("unsupported numeric format code: %d", code)
		}
	case oid.T_text, oid.T_varchar:
		switch code {
		case formatText, formatBinary:
			d = parser.NewDString(string(b))
		default:
			return d, errors.Errorf("unsupported text format code: %d", code)
		}
	case oid.T_bytea:
		switch code {
		case formatText:
			// http://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667
			// Code cribbed from github.com/lib/pq.

			// We only support hex encoding.
			if !bytes.HasPrefix(b, []byte("\\x")) {
				return d, errors.Errorf("unsupported bytea encoding: %q", b)
			}
			b = b[2:] // trim off leading "\\x"
			result := make([]byte, hex.DecodedLen(len(b)))
			if _, err := hex.Decode(result, b); err != nil {
				return d, err
			}
			d = parser.NewDBytes(parser.DBytes(result))
		case formatBinary:
			d = parser.NewDBytes(parser.DBytes(b))
		default:
			return d, errors.Errorf("unsupported bytea format code: %d", code)
		}
	case oid.T_timestamp:
		switch code {
		case formatText:
			ts, err := parseTs(string(b))
			if err != nil {
				return d, errors.Errorf("could not parse string %q as timestamp", b)
			}
			d = parser.MakeDTimestamp(ts, time.Microsecond)
		case formatBinary:
			if len(b) < 8 {
				return d, errors.Errorf("timestamp requires 8 bytes for binary format")
			}
			i := int64(binary.BigEndian.Uint64(b))
			d = parser.MakeDTimestamp(pgBinaryToTime(i), time.Microsecond)
		default:
			return d, errors.Errorf("unsupported timestamp format code: %d", code)
		}
	case oid.T_timestamptz:
		switch code {
		case formatText:
			ts, err := parseTs(string(b))
			if err != nil {
				return d, errors.Errorf("could not parse string %q as timestamp", b)
			}
			d = parser.MakeDTimestampTZ(ts, time.Microsecond)
		case formatBinary:
			if len(b) < 8 {
				return d, errors.Errorf("timestamptz requires 8 bytes for binary format")
			}
			i := int64(binary.BigEndian.Uint64(b))
			d = parser.MakeDTimestampTZ(pgBinaryToTime(i), time.Microsecond)
		default:
			return d, errors.Errorf("unsupported timestamptz format code: %d", code)
		}
	case oid.T_date:
		switch code {
		case formatText:
			// Try the timestamp syntax first and fall back to the date parser.
			ts, err := parseTs(string(b))
			if err != nil {
				res, err := parser.ParseDDate(string(b), time.UTC)
				if err != nil {
					return d, errors.Errorf("could not parse string %q as date", b)
				}
				d = res
			} else {
				daysSinceEpoch := ts.Unix() / secondsInDay
				d = parser.NewDDate(parser.DDate(daysSinceEpoch))
			}
		case formatBinary:
			if len(b) < 4 {
				return d, errors.Errorf("date requires 4 bytes for binary format")
			}
			i := int32(binary.BigEndian.Uint32(b))
			d = pgBinaryToDate(i)
		default:
			return d, errors.Errorf("unsupported date format code: %d", code)
		}
	case oid.T_interval:
		switch code {
		case formatText:
			// Use a separate variable rather than shadowing d, so a failed
			// parse returns a nil Datum instead of the parser's result, which
			// could be a typed nil wrapped in a non-nil interface.
			interval, err := parser.ParseDInterval(string(b))
			if err != nil {
				return nil, errors.Errorf("could not parse string %q as interval", b)
			}
			d = interval
		default:
			return d, errors.Errorf("unsupported interval format code: %d", code)
		}
	default:
		return d, errors.Errorf("unsupported OID: %v", id)
	}
	return d, nil
}