Ejemplo n.º 1
0
// floatOrDecimalBuiltin1 builds the float and decimal overloads of a
// one-argument builtin from a single float64 implementation f.
//
// The float overload passes its argument to f and returns f's result as is.
// The decimal overload converts its argument to float64, applies f, and
// converts the float result back into a decimal. On success f must return a
// DFloat; the decimal overload type-asserts the result to DFloat.
//
// The decimal overload fails if the argument cannot be converted to float64,
// if f fails, or if f produces NaN or an infinity, neither of which can be
// stored in a decimal.
func floatOrDecimalBuiltin1(f func(float64) (Datum, error)) []builtin {
	return []builtin{
		{
			types:      argTypes{floatType},
			returnType: typeFloat,
			fn: func(_ EvalContext, args DTuple) (Datum, error) {
				return f(float64(args[0].(DFloat)))
			},
		}, {
			types:      argTypes{decimalType},
			returnType: typeDecimal,
			fn: func(_ EvalContext, args DTuple) (Datum, error) {
				dec := args[0].(*DDecimal)
				v, err := decimal.Float64FromDec(&dec.Dec)
				if err != nil {
					return nil, err
				}
				r, err := f(v)
				if err != nil {
					// Never hand back a datum together with an error.
					return nil, err
				}
				rf := float64(r.(DFloat))
				if math.IsNaN(rf) {
					// TODO(nvanbenschoten) NaN semantics should be introduced
					// into the decimal library to support it here.
					return nil, fmt.Errorf("decimal does not support NaN")
				}
				if math.IsInf(rf, 0) {
					// Infinite results are rejected for the same reason as
					// NaN, but reported accurately.
					return nil, fmt.Errorf("decimal does not support infinity")
				}
				dd := &DDecimal{}
				decimal.SetFromFloat(&dd.Dec, rf)
				return dd, nil
			},
		},
	}
}
Ejemplo n.º 2
0
// Eval implements the Expr interface.
//
// It evaluates the operand and converts the resulting datum to the target
// type expr.Type. A NULL operand is returned unchanged for every target type.
// Any source/target combination without a conversion below yields an
// "invalid cast" error.
//
// Conversion notes:
//   - floats and decimals cast to INT are rounded to zero fractional digits;
//     results outside the 64-bit integer range (including NaN and infinities)
//     yield errIntOutOfRange;
//   - strings are parsed with strconv (ParseBool, ParseInt with base 0,
//     ParseFloat), the decimal parser, or the date/timestamp parsers;
//   - a string target with a positive length limit N keeps at most N
//     characters of the result;
//   - intervals are parsed from Go duration strings or taken from an integer
//     number of nanoseconds.
func (expr *CastExpr) Eval(ctx EvalContext) (Datum, error) {
	d, err := expr.Expr.Eval(ctx)
	if err != nil {
		return nil, err
	}

	// NULL cast to anything is NULL.
	if d == DNull {
		return d, nil
	}

	switch typ := expr.Type.(type) {
	case *BoolType:
		switch v := d.(type) {
		case DBool:
			return d, nil
		case DInt:
			return DBool(v != 0), nil
		case DFloat:
			return DBool(v != 0), nil
		case *DDecimal:
			return DBool(v.Sign() != 0), nil
		case DString:
			// TODO(pmattis): strconv.ParseBool is more permissive than the SQL
			// spec. Is that ok?
			b, err := strconv.ParseBool(string(v))
			if err != nil {
				return nil, err
			}
			return DBool(b), nil
		}

	case *IntType:
		switch v := d.(type) {
		case DBool:
			if v {
				return DInt(1), nil
			}
			return DInt(0), nil
		case DInt:
			return d, nil
		case DFloat:
			f, err := round(float64(v), 0)
			if err != nil {
				panic(fmt.Sprintf("round should never fail with digits hardcoded to 0: %s", err))
			}
			rounded := float64(f.(DFloat))
			// Converting a float that does not fit the integer type gives an
			// implementation-dependent result in Go, so reject it explicitly.
			// The negated form also rejects NaN, for which every comparison
			// is false.
			if !(rounded >= -(1<<63) && rounded < 1<<63) {
				return nil, errIntOutOfRange
			}
			return DInt(rounded), nil
		case *DDecimal:
			dec := new(inf.Dec)
			dec.Round(&v.Dec, 0, inf.RoundHalfUp)
			i, ok := dec.Unscaled()
			if !ok {
				return nil, errIntOutOfRange
			}
			return DInt(i), nil
		case DString:
			i, err := strconv.ParseInt(string(v), 0, 64)
			if err != nil {
				return nil, err
			}
			return DInt(i), nil
		}

	case *FloatType:
		switch v := d.(type) {
		case DBool:
			if v {
				return DFloat(1), nil
			}
			return DFloat(0), nil
		case DInt:
			return DFloat(v), nil
		case DFloat:
			return d, nil
		case *DDecimal:
			f, err := decimal.Float64FromDec(&v.Dec)
			if err != nil {
				return nil, errFloatOutOfRange
			}
			return DFloat(f), nil
		case DString:
			f, err := strconv.ParseFloat(string(v), 64)
			if err != nil {
				return nil, err
			}
			return DFloat(f), nil
		}

	case *DecimalType:
		dd := &DDecimal{}
		switch v := d.(type) {
		case DBool:
			// The zero DDecimal already represents false.
			if v {
				dd.SetUnscaled(1)
			}
			return dd, nil
		case DInt:
			dd.SetUnscaled(int64(v))
			return dd, nil
		case DFloat:
			decimal.SetFromFloat(&dd.Dec, float64(v))
			return dd, nil
		case *DDecimal:
			return d, nil
		case DString:
			if _, ok := dd.SetString(string(v)); !ok {
				return nil, fmt.Errorf("could not parse string %q as decimal", v)
			}
			return dd, nil
		}

	case *StringType:
		var s DString
		switch t := d.(type) {
		case DBool, DInt, DFloat, *DDecimal, dNull:
			s = DString(d.String())
		case DString:
			s = t
		case DBytes:
			if !utf8.ValidString(string(t)) {
				return nil, fmt.Errorf("invalid utf8: %q", string(t))
			}
			s = DString(t)
		default:
			// No string conversion is defined for this datum; report it
			// rather than silently producing an empty string.
			return nil, fmt.Errorf("invalid cast: %s -> %s", d.Type(), expr.Type)
		}
		// If the CHAR type specifies a limit we truncate to that limit:
		//   'hello'::CHAR(2) -> 'he'
		// The cut is made on a rune boundary so that a multi-byte UTF-8
		// character is never split. A string of at most N bytes cannot hold
		// more than N characters, so the byte-length test is a cheap
		// pre-check.
		if typ.N > 0 && typ.N < len(s) {
			kept := 0
			for i := range s {
				if kept == typ.N {
					s = s[:i]
					break
				}
				kept++
			}
		}
		return s, nil

	case *BytesType:
		switch v := d.(type) {
		case DString:
			return DBytes(v), nil
		case DBytes:
			return d, nil
		}

	case *DateType:
		switch v := d.(type) {
		case DString:
			return ParseDate(v)
		case DTimestamp:
			return ctx.makeDDate(v.Time)
		}

	case *TimestampType:
		switch v := d.(type) {
		case DString:
			return ctx.ParseTimestamp(v)
		case DDate:
			loc, err := ctx.GetLocation()
			if err != nil {
				return nil, err
			}
			// Take the calendar date from the day count (read in UTC) and
			// place midnight of that date in the location from ctx.
			year, month, day := time.Unix(int64(v)*secondsInDay, 0).UTC().Date()
			return DTimestamp{Time: time.Date(year, month, day, 0, 0, 0, 0, loc)}, nil
		}

	case *IntervalType:
		switch v := d.(type) {
		case DString:
			// We use the Golang format for specifying duration.
			// TODO(vivek): we might consider using the postgres format as well.
			dur, err := time.ParseDuration(string(v))
			if err != nil {
				return nil, err
			}
			return DInterval{Duration: dur}, nil

		case DInt:
			// An integer duration represents a duration in nanoseconds.
			return DInterval{Duration: time.Duration(v)}, nil
		}
	}

	return nil, fmt.Errorf("invalid cast: %s -> %s", d.Type(), expr.Type)
}