Ejemplo n.º 1
0
// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
// slices, and arrays. Every element of a slice must have the same length so
// that the resulting Tensor has a valid shape.
//
// The value is first serialized into a Go-side bytes.Buffer (which also fills
// in the shape), then a C tensor of exactly the encoded byte size is allocated
// and the encoded bytes are copied into it.
func NewTensor(value interface{}) (*Tensor, error) {
	v := reflect.ValueOf(value)
	ndims, dt, err := dimsAndDataTypeOf(v.Type())
	if err != nil {
		return nil, err
	}
	// TODO(ashankar): Remove the bytes.Buffer and encode directly into
	// C-memory, avoiding the memcpy and cutting memory usage in half.
	dims := make([]int64, ndims)
	encoded := new(bytes.Buffer)
	if err = encodeTensor(encoded, dims, v); err != nil {
		return nil, err
	}
	size := encoded.Len()

	var dimsPtr *C.int64_t
	if len(dims) > 0 {
		dimsPtr = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	ct := C.TF_AllocateTensor(C.TF_DataType(dt), dimsPtr, C.int(len(dims)), C.size_t(size))
	t := &Tensor{c: ct, shape: dims}
	runtime.SetFinalizer(t, (*Tensor).finalize)
	if size == 0 {
		return t, nil
	}
	// The intermediate slice variable is deliberate; see
	// https://github.com/golang/go/issues/14210.
	b := encoded.Bytes()
	C.memcpy(C.TF_TensorData(t.c), unsafe.Pointer(&b[0]), C.size_t(size))
	return t, nil
}
Ejemplo n.º 2
0
// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
// slices, and arrays. Every element of a slice must have the same length so
// that the resulting Tensor has a valid shape. String values are rejected
// with an error.
//
// The C tensor is allocated up front using the size computed by byteSizeOf,
// and the value is then encoded directly into the tensor's memory. A mismatch
// between the computed and the actually-encoded size is reported as a bug.
func NewTensor(value interface{}) (*Tensor, error) {
	v := reflect.ValueOf(value)
	dims, dt, err := shapeAndDataTypeOf(v)
	switch {
	case err != nil:
		return nil, err
	case dt == String:
		// TODO(ashankar): Handle this
		return nil, fmt.Errorf("String Tensors are not currently supported")
	}
	want := byteSizeOf(dt, dims)

	var dimsPtr *C.int64_t
	if len(dims) > 0 {
		dimsPtr = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	t := &Tensor{shape: dims}
	t.c = C.TF_AllocateTensor(C.TF_DataType(dt), dimsPtr, C.int(len(dims)), C.size_t(want))
	runtime.SetFinalizer(t, (*Tensor).finalize)

	// Wrap the tensor's memory with zero length but full capacity so that
	// encodeTensor writes in place. If the encoding does not fit, the buffer
	// reallocates and the length check below reports the mismatch.
	mem := tensorData(t.c)
	w := bytes.NewBuffer(mem[:0:len(mem)])
	if err := encodeTensor(w, v); err != nil {
		return nil, err
	}
	if got := w.Len(); uintptr(got) != want {
		return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dt, dims, want, got, Version())
	}
	return t, nil
}
Ejemplo n.º 3
0
// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of
// the *C.TF_Tensor, either by passing ownership to the C API or explicitly
// calling C.TF_DeleteTensor() on it.
//
// A new C tensor is allocated with t's data type and shape, and the bytes
// currently held in t.buf are copied into it.
func (t *Tensor) c() *C.TF_Tensor {
	n := t.buf.Len()
	ndims := len(t.shape)
	var dims *C.int64_t
	if ndims != 0 {
		dims = (*C.int64_t)(unsafe.Pointer(&t.shape[0]))
	}
	out := C.TF_AllocateTensor(C.TF_DataType(t.dt), dims, C.int(ndims), C.size_t(n))
	if n == 0 {
		return out
	}
	data := t.buf.Bytes() // https://github.com/golang/go/issues/14210
	C.memcpy(C.TF_TensorData(out), unsafe.Pointer(&data[0]), C.size_t(n))
	return out
}
Ejemplo n.º 4
0
// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
// slices, and arrays. Every element of a slice must have the same length so
// that the resulting Tensor has a valid shape.
//
// The C tensor is allocated with a precomputed byte size and the value is
// encoded directly into its memory. Errors from either the numeric encoder or
// the TF_STRING encoder are returned to the caller.
func NewTensor(value interface{}) (*Tensor, error) {
	val := reflect.ValueOf(value)
	shape, dataType, err := shapeAndDataTypeOf(val)
	if err != nil {
		return nil, err
	}
	nflattened := numElements(shape)
	nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
	if dataType == String {
		// TF_STRING tensors are encoded as an array of 8-byte offsets
		// followed by string data. See c_api.h.
		nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
	}
	var shapePtr *C.int64_t
	if len(shape) > 0 {
		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	t := &Tensor{
		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
		shape: shape,
	}
	runtime.SetFinalizer(t, (*Tensor).finalize)
	raw := tensorData(t.c)
	// Zero length, full capacity: writes land directly in the tensor memory.
	buf := bytes.NewBuffer(raw[:0:len(raw)])
	if dataType != String {
		if err := encodeTensor(buf, val); err != nil {
			return nil, err
		}
		if uintptr(buf.Len()) != nbytes {
			return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
		}
		return t, nil
	}
	// Offsets are written through buf (the first nflattened*8 bytes); the
	// string payload goes into the remainder of raw.
	e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
	// Bind encode's result: the previous form tested the stale outer err
	// (always nil here) and silently dropped encoding failures.
	if err := e.encode(val); err != nil {
		return nil, err
	}
	if int64(buf.Len()) != nflattened*8 {
		return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
	}
	return t, nil
}
Ejemplo n.º 5
0
// ReadTensor constructs a Tensor with the provided type and shape from the
// serialized tensor contents in r.
//
// Exactly the number of bytes implied by dataType and shape is read from r.
// io.ReadFull is used, so readers that deliver data in several chunks are
// handled correctly. If r ends before that many bytes are available, an error
// reporting the expected and actual byte counts is returned. Any other read
// error is returned as-is.
//
// See also WriteContentsTo.
func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
	if err := isTensorSerializable(dataType); err != nil {
		return nil, err
	}
	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
	var shapePtr *C.int64_t
	if len(shape) > 0 {
		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	t := &Tensor{
		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
		shape: shape,
	}
	runtime.SetFinalizer(t, (*Tensor).finalize)
	raw := tensorData(t.c)
	// A single r.Read may legitimately return fewer bytes than requested, so
	// keep reading until raw is full or r is exhausted.
	n, err := io.ReadFull(r, raw)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, err
	}
	if uintptr(n) != nbytes {
		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
	}
	return t, nil
}
Ejemplo n.º 6
0
// setAttr sets the attribute called name on the operation description cdesc
// to value. It dispatches on value's dynamic Go type to the matching
// C.TF_SetAttr* function. It returns an error if value's type is not a
// supported attribute type. For *Tensor and []*Tensor values, it also returns
// an error if the C API reports a failure through status.
//
// List-valued attributes may be empty. In that case nil array pointers and a
// count of zero are passed to the C API, rather than taking the address of
// element 0 of an empty slice, which would panic.
func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
	cAttrName := C.CString(name)
	defer C.free(unsafe.Pointer(cAttrName))
	switch value := value.(type) {
	case string:
		cstr := C.CString(value)
		C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value)))
		C.free(unsafe.Pointer(cstr))
	case []string:
		size := len(value)
		list := make([]unsafe.Pointer, size)
		lens := make([]C.size_t, size)
		for i, s := range value {
			list[i] = unsafe.Pointer(C.CString(s))
			lens[i] = C.size_t(len(s))
		}
		var listp *unsafe.Pointer
		var lensp *C.size_t
		if size > 0 {
			listp, lensp = &list[0], &lens[0]
		}
		C.TF_SetAttrStringList(cdesc, cAttrName, listp, lensp, C.int(size))
		for _, s := range list {
			C.free(s)
		}
	case int64:
		C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
	case []int64:
		size := len(value)
		list := make([]C.int64_t, size)
		for i, v := range value {
			list[i] = C.int64_t(v)
		}
		var listp *C.int64_t
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrIntList(cdesc, cAttrName, listp, C.int(size))
	case float32:
		C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
	case []float32:
		size := len(value)
		list := make([]C.float, size)
		for i, v := range value {
			list[i] = C.float(v)
		}
		var listp *C.float
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrFloatList(cdesc, cAttrName, listp, C.int(size))
	case bool:
		v := C.uchar(0)
		if value {
			v = 1
		}
		C.TF_SetAttrBool(cdesc, cAttrName, v)
	case []bool:
		size := len(value)
		list := make([]C.uchar, size)
		for i, v := range value {
			if v {
				list[i] = 1
			}
		}
		var listp *C.uchar
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrBoolList(cdesc, cAttrName, listp, C.int(size))
	case DataType:
		C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
	case []DataType:
		var list *C.TF_DataType
		if len(value) > 0 {
			list = (*C.TF_DataType)(&value[0])
		}
		C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
	case *Tensor:
		C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c)
		if err := status.Err(); err != nil {
			return fmt.Errorf("bad value for attribute %q: %v", name, err)
		}
	case []*Tensor:
		size := len(value)
		list := make([]*C.TF_Tensor, size)
		for i, v := range value {
			list[i] = v.c
		}
		var listp **C.TF_Tensor
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrTensorList(cdesc, cAttrName, listp, C.int(size), status.c)
		if err := status.Err(); err != nil {
			return fmt.Errorf("bad value for attribute %q: %v", name, err)
		}
	case Shape:
		ndims, dims := cshape(value)
		var dimsp *C.int64_t
		if ndims > 0 {
			dimsp = &dims[0]
		}
		C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
	case []Shape:
		size := len(value)
		ndims := make([]C.int, size)
		dimsp := make([]*C.int64_t, size)
		// C receives an array of pointers. cgo forbids passing C a pointer to
		// Go memory that itself holds Go pointers, so each shape's dimensions
		// are copied into C-allocated memory and only C pointers are stored
		// in dimsp. They are freed once TF_SetAttrShapeList has returned.
		for i, s := range value {
			nd, dims := cshape(s)
			ndims[i] = nd
			if nd > 0 {
				p := (*C.int64_t)(C.malloc(C.size_t(nd) * C.size_t(unsafe.Sizeof(C.int64_t(0)))))
				copy((*[1 << 27]C.int64_t)(unsafe.Pointer(p))[:nd:nd], dims)
				dimsp[i] = p
			}
		}
		var ndimsPtr *C.int
		var dimsPtr **C.int64_t
		if size > 0 {
			ndimsPtr, dimsPtr = &ndims[0], &dimsp[0]
		}
		C.TF_SetAttrShapeList(cdesc, cAttrName, dimsPtr, ndimsPtr, C.int(size))
		for _, p := range dimsp {
			if p != nil {
				C.free(unsafe.Pointer(p))
			}
		}
	default:
		return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
	}
	return nil
}
Ejemplo n.º 7
0
// SetAttrType sets the attribute called name on the operation description
// held by b to the data type typ. The temporary C copy of name is released
// before returning.
func (b *opBuilder) SetAttrType(name string, typ DataType) {
	cname := C.CString(name)
	defer C.free(unsafe.Pointer(cname))
	C.TF_SetAttrType(b.c, cname, C.TF_DataType(typ))
}
Ejemplo n.º 8
0
// setAttr sets the attribute called name on the operation description cdesc
// to value. It dispatches on value's dynamic Go type to the matching
// C.TF_SetAttr* function. It returns an error if value's type is not a
// supported attribute type. For *Tensor and []*Tensor values, it also returns
// an error if the C API reports a failure through status.
//
// List-valued attributes may be empty. In that case nil array pointers and a
// count of zero are passed to the C API, rather than taking the address of
// element 0 of an empty slice, which would panic.
func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
	cAttrName := C.CString(name)
	defer C.free(unsafe.Pointer(cAttrName))
	switch value := value.(type) {
	case string:
		cstr := C.CString(value)
		C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value)))
		C.free(unsafe.Pointer(cstr))
	case []string:
		size := len(value)
		list := make([]unsafe.Pointer, size)
		lens := make([]C.size_t, size)
		for i, s := range value {
			list[i] = unsafe.Pointer(C.CString(s))
			lens[i] = C.size_t(len(s))
		}
		var listp *unsafe.Pointer
		var lensp *C.size_t
		if size > 0 {
			listp, lensp = &list[0], &lens[0]
		}
		C.TF_SetAttrStringList(cdesc, cAttrName, listp, lensp, C.int(size))
		for _, s := range list {
			C.free(s)
		}
	case int64:
		C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
	case []int64:
		size := len(value)
		list := make([]C.int64_t, size)
		for i, v := range value {
			list[i] = C.int64_t(v)
		}
		var listp *C.int64_t
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrIntList(cdesc, cAttrName, listp, C.int(size))
	case float32:
		C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
	case []float32:
		size := len(value)
		list := make([]C.float, size)
		for i, v := range value {
			list[i] = C.float(v)
		}
		var listp *C.float
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrFloatList(cdesc, cAttrName, listp, C.int(size))
	case bool:
		v := C.uchar(0)
		if value {
			v = 1
		}
		C.TF_SetAttrBool(cdesc, cAttrName, v)
	case []bool:
		size := len(value)
		list := make([]C.uchar, size)
		for i, v := range value {
			if v {
				list[i] = 1
			}
		}
		var listp *C.uchar
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrBoolList(cdesc, cAttrName, listp, C.int(size))
	case DataType:
		C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
	case []DataType:
		var list *C.TF_DataType
		if len(value) > 0 {
			list = (*C.TF_DataType)(&value[0])
		}
		C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
	case *Tensor:
		C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c)
		if err := status.Err(); err != nil {
			return fmt.Errorf("bad value for attribute %q: %v", name, err)
		}
	case []*Tensor:
		size := len(value)
		list := make([]*C.TF_Tensor, size)
		for i, v := range value {
			list[i] = v.c
		}
		var listp **C.TF_Tensor
		if size > 0 {
			listp = &list[0]
		}
		C.TF_SetAttrTensorList(cdesc, cAttrName, listp, C.int(size), status.c)
		if err := status.Err(); err != nil {
			return fmt.Errorf("bad value for attribute %q: %v", name, err)
		}
	default:
		// Shape attributes are not supported yet. They would need a type that
		// is distinguishable from []int64; a dedicated Shape type could also
		// represent partially known shapes and hide the special meaning of -1.
		return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
	}
	return nil
}