예제 #1
0
// NewTensor builds a Tensor from a Go value. Scalars, slices, and arrays are
// accepted; all elements of a slice must share the same length so that the
// resulting Tensor has a well-defined shape.
//
// An error is returned when the shape or data type of value cannot be
// determined, when value holds strings (not supported yet), or when encoding
// the value fails.
func NewTensor(value interface{}) (*Tensor, error) {
	rv := reflect.ValueOf(value)
	dims, dt, err := shapeAndDataTypeOf(rv)
	switch {
	case err != nil:
		return nil, err
	case dt == String:
		// TODO(ashankar): Handle this
		return nil, fmt.Errorf("String Tensors are not currently supported")
	}
	want := byteSizeOf(dt, dims)

	// With an empty shape there is no first element to point at, so C
	// receives a nil pointer.
	var cDims *C.int64_t
	if len(dims) != 0 {
		cDims = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	handle := C.TF_AllocateTensor(C.TF_DataType(dt), cDims, C.int(len(dims)), C.size_t(want))
	t := &Tensor{c: handle, shape: dims}
	runtime.SetFinalizer(t, (*Tensor).finalize)

	// The buffer starts out empty and appends into the slice returned by
	// tensorData (presumably the tensor's own memory), capped at its length.
	mem := tensorData(t.c)
	w := bytes.NewBuffer(mem[:0:len(mem)])
	if err := encodeTensor(w, rv); err != nil {
		return nil, err
	}
	// Sanity check: the encoder must have produced exactly the number of
	// bytes that was allocated up front.
	if got := w.Len(); uintptr(got) != want {
		return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dt, dims, want, got, Version())
	}
	return t, nil
}
예제 #2
0
// NewTensor builds a Tensor from a Go value. Scalars, slices, and arrays are
// accepted; all elements of a slice must share the same length so that the
// resulting Tensor has a well-defined shape.
//
// An error is returned when the number of dimensions or the data type cannot
// be derived from the value's type, or when encoding the value fails.
func NewTensor(value interface{}) (*Tensor, error) {
	rv := reflect.ValueOf(value)
	rank, dt, err := dimsAndDataTypeOf(rv.Type())
	if err != nil {
		return nil, err
	}

	// TODO(ashankar): Remove the bytes.Buffer and encode directly into
	// C-memory, avoiding the memcpy and cutting down memory usage in half.
	var encoded bytes.Buffer
	// dims is handed to encodeTensor zero-filled; presumably the encoder
	// records the size of each dimension in it (confirm against encodeTensor).
	dims := make([]int64, rank)
	if err := encodeTensor(&encoded, dims, rv); err != nil {
		return nil, err
	}
	size := encoded.Len()

	// With an empty shape there is no first element to point at, so C
	// receives a nil pointer.
	var cDims *C.int64_t
	if len(dims) != 0 {
		cDims = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	t := &Tensor{shape: dims}
	t.c = C.TF_AllocateTensor(C.TF_DataType(dt), cDims, C.int(len(dims)), C.size_t(size))
	runtime.SetFinalizer(t, (*Tensor).finalize)

	// Nothing to copy for an empty encoding (and no first byte to address).
	if size == 0 {
		return t, nil
	}
	payload := encoded.Bytes() // https://github.com/golang/go/issues/14210
	C.memcpy(C.TF_TensorData(t.c), unsafe.Pointer(&payload[0]), C.size_t(size))
	return t, nil
}
예제 #3
0
// c allocates a new *C.TF_Tensor with the Tensor's data type and shape and
// copies the Tensor's buffered bytes into it. The caller owns the returned
// *C.TF_Tensor and must either hand that ownership to the C API or release it
// explicitly with C.TF_DeleteTensor().
func (t *Tensor) c() *C.TF_Tensor {
	size := t.buf.Len()
	rank := len(t.shape)

	// With an empty shape there is no first element to point at, so C
	// receives a nil pointer.
	var dims *C.int64_t
	if rank != 0 {
		dims = (*C.int64_t)(unsafe.Pointer(&t.shape[0]))
	}
	out := C.TF_AllocateTensor(C.TF_DataType(t.dt), dims, C.int(rank), C.size_t(size))

	// Nothing to copy for an empty buffer (and no first byte to address).
	if size == 0 {
		return out
	}
	src := t.buf.Bytes() // https://github.com/golang/go/issues/14210
	C.memcpy(C.TF_TensorData(out), unsafe.Pointer(&src[0]), C.size_t(size))
	return out
}
예제 #4
0
// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
// slices, and arrays. Every element of a slice must have the same length so
// that the resulting Tensor has a valid shape.
//
// An error is returned if the shape or data type of value cannot be
// determined, if encoding value into the tensor fails, or if the number of
// bytes produced by the encoder disagrees with the precomputed size (reported
// via bug).
func NewTensor(value interface{}) (*Tensor, error) {
	val := reflect.ValueOf(value)
	shape, dataType, err := shapeAndDataTypeOf(val)
	if err != nil {
		return nil, err
	}
	nflattened := numElements(shape)
	nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
	if dataType == String {
		// TF_STRING tensors are encoded as an array of 8-byte offsets
		// followed by string data. See c_api.h.
		nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
	}
	// With an empty shape there is no first element to point at, so C
	// receives a nil pointer.
	var shapePtr *C.int64_t
	if len(shape) > 0 {
		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	t := &Tensor{
		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
		shape: shape,
	}
	runtime.SetFinalizer(t, (*Tensor).finalize)
	// The buffer starts out empty and appends into the slice returned by
	// tensorData, capped at its length.
	raw := tensorData(t.c)
	buf := bytes.NewBuffer(raw[:0:len(raw)])
	if dataType != String {
		if err := encodeTensor(buf, val); err != nil {
			return nil, err
		}
		if uintptr(buf.Len()) != nbytes {
			return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
		}
	} else {
		// Offsets go through buf (the first nflattened*8 bytes); the string
		// data region is everything after the offset table.
		e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
		// Bind the encoder's own result. Testing the outer err here (which is
		// already known to be nil) would silently discard encoding failures.
		if err := e.encode(val); err != nil {
			return nil, err
		}
		if int64(buf.Len()) != nflattened*8 {
			return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
		}
	}
	return t, nil
}
예제 #5
0
// ReadTensor constructs a Tensor with the provided type and shape from the
// serialized tensor contents in r.
//
// Exactly as many bytes as the tensor holds (element size times the number of
// elements in shape) are consumed from r. The read is repeated until the
// tensor is full, so readers that deliver data in several chunks are handled.
// An error is returned if dataType is not serializable, if r fails, or if r
// ends before the tensor has been filled.
//
// See also WriteContentsTo.
func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
	if err := isTensorSerializable(dataType); err != nil {
		return nil, err
	}
	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
	// With an empty shape there is no first element to point at, so C
	// receives a nil pointer.
	var shapePtr *C.int64_t
	if len(shape) > 0 {
		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
	}
	t := &Tensor{
		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
		shape: shape,
	}
	runtime.SetFinalizer(t, (*Tensor).finalize)
	raw := tensorData(t.c)
	// A single Read call may legitimately return fewer bytes than requested
	// (or data together with io.EOF), so keep reading until raw is full.
	n, err := io.ReadFull(r, raw)
	if err == io.ErrUnexpectedEOF {
		// The stream ended part-way through the tensor contents.
		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
	}
	if err != nil {
		return nil, err
	}
	if uintptr(n) != nbytes {
		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
	}
	return t, nil
}