func floatOrDecimalBuiltin1(f func(float64) (Datum, error)) []builtin { return []builtin{ { types: argTypes{floatType}, returnType: typeFloat, fn: func(_ EvalContext, args DTuple) (Datum, error) { return f(float64(args[0].(DFloat))) }, }, { types: argTypes{decimalType}, returnType: typeDecimal, fn: func(_ EvalContext, args DTuple) (Datum, error) { dec := args[0].(*DDecimal) v, err := decimal.Float64FromDec(&dec.Dec) if err != nil { return nil, err } r, err := f(v) if err != nil { return r, err } rf := float64(r.(DFloat)) if math.IsNaN(rf) || math.IsInf(rf, 0) { // TODO(nvanbenschoten) NaN semmantics should be introduced // into the decimal library to support it here. return nil, fmt.Errorf("decimal does not support NaN") } dd := &DDecimal{} decimal.SetFromFloat(&dd.Dec, rf) return dd, nil }, }, } }
// Eval implements the Expr interface. func (expr *CastExpr) Eval(ctx EvalContext) (Datum, error) { d, err := expr.Expr.Eval(ctx) if err != nil { return nil, err } // NULL cast to anything is NULL. if d == DNull { return d, nil } switch expr.Type.(type) { case *BoolType: switch v := d.(type) { case DBool: return d, nil case DInt: return DBool(v != 0), nil case DFloat: return DBool(v != 0), nil case *DDecimal: return DBool(v.Sign() != 0), nil case DString: // TODO(pmattis): strconv.ParseBool is more permissive than the SQL // spec. Is that ok? b, err := strconv.ParseBool(string(v)) if err != nil { return nil, err } return DBool(b), nil } case *IntType: switch v := d.(type) { case DBool: if v { return DInt(1), nil } return DInt(0), nil case DInt: return d, nil case DFloat: f, err := round(float64(v), 0) if err != nil { panic(fmt.Sprintf("round should never fail with digits hardcoded to 0: %s", err)) } return DInt(f.(DFloat)), nil case *DDecimal: dec := new(inf.Dec) dec.Round(&v.Dec, 0, inf.RoundHalfUp) i, ok := dec.Unscaled() if !ok { return nil, errIntOutOfRange } return DInt(i), nil case DString: i, err := strconv.ParseInt(string(v), 0, 64) if err != nil { return nil, err } return DInt(i), nil } case *FloatType: switch v := d.(type) { case DBool: if v { return DFloat(1), nil } return DFloat(0), nil case DInt: return DFloat(v), nil case DFloat: return d, nil case *DDecimal: f, err := decimal.Float64FromDec(&v.Dec) if err != nil { return nil, errFloatOutOfRange } return DFloat(f), nil case DString: f, err := strconv.ParseFloat(string(v), 64) if err != nil { return nil, err } return DFloat(f), nil } case *DecimalType: dd := &DDecimal{} switch v := d.(type) { case DBool: if v { dd.SetUnscaled(1) } return dd, nil case DInt: dd.SetUnscaled(int64(v)) return dd, nil case DFloat: decimal.SetFromFloat(&dd.Dec, float64(v)) return dd, nil case *DDecimal: return d, nil case DString: if _, ok := dd.SetString(string(v)); !ok { return nil, fmt.Errorf("could not parse string %q as decimal", v) } return dd, nil } case *StringType: var s DString switch t := d.(type) { case DBool, DInt, DFloat, *DDecimal, dNull: s = DString(d.String()) case DString: s = t case DBytes: if !utf8.ValidString(string(t)) { return nil, fmt.Errorf("invalid utf8: %q", string(t)) } s = DString(t) } if c, ok := expr.Type.(*StringType); ok { // If the CHAR type specifies a limit we truncate to that limit: // 'hello'::CHAR(2) -> 'he' if c.N > 0 && c.N < len(s) { s = s[:c.N] } } return s, nil case *BytesType: switch t := d.(type) { case DString: return DBytes(t), nil case DBytes: return d, nil } case *DateType: switch d := d.(type) { case DString: return ParseDate(d) case DTimestamp: return ctx.makeDDate(d.Time) } case *TimestampType: switch d := d.(type) { case DString: return ctx.ParseTimestamp(d) case DDate: loc, err := ctx.GetLocation() if err != nil { return nil, err } year, month, day := time.Unix(int64(d)*secondsInDay, 0).UTC().Date() return DTimestamp{Time: time.Date(year, month, day, 0, 0, 0, 0, loc)}, nil } case *IntervalType: switch d.(type) { case DString: // We use the Golang format for specifying duration. // TODO(vivek): we might consider using the postgres format as well. d, err := time.ParseDuration(string(d.(DString))) return DInterval{Duration: d}, err case DInt: // An integer duration represents a duration in nanoseconds. return DInterval{Duration: time.Duration(d.(DInt))}, nil } } return nil, fmt.Errorf("invalid cast: %s -> %s", d.Type(), expr.Type) }