// NewTensor converts from a Go value to a Tensor. Valid values are scalars, // slices, and arrays. Every element of a slice must have the same length so // that the resulting Tensor has a valid shape. func NewTensor(value interface{}) (*Tensor, error) { val := reflect.ValueOf(value) dims, dataType, err := dimsAndDataTypeOf(val.Type()) if err != nil { return nil, err } // TODO(ashankar): Remove the bytes.Buffer and endcode directly into // C-memory, avoiding the memcpy and cutting down memory usage in half. shape := make([]int64, dims) buf := new(bytes.Buffer) if err := encodeTensor(buf, shape, val); err != nil { return nil, err } var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) } t := &Tensor{ c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(buf.Len())), shape: shape, } runtime.SetFinalizer(t, (*Tensor).finalize) if buf.Len() > 0 { slice := buf.Bytes() // https://github.com/golang/go/issues/14210 C.memcpy(C.TF_TensorData(t.c), unsafe.Pointer(&slice[0]), C.size_t(buf.Len())) } return t, nil }
// NewTensor converts from a Go value to a Tensor. Valid values are scalars, // slices, and arrays. Every element of a slice must have the same length so // that the resulting Tensor has a valid shape. func NewTensor(value interface{}) (*Tensor, error) { val := reflect.ValueOf(value) shape, dataType, err := shapeAndDataTypeOf(val) if err != nil { return nil, err } if dataType == String { // TODO(ashankar): Handle this return nil, fmt.Errorf("String Tensors are not currently supported") } nbytes := byteSizeOf(dataType, shape) var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) } t := &Tensor{ c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), shape: shape, } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) buf := bytes.NewBuffer(raw[:0:len(raw)]) if err := encodeTensor(buf, val); err != nil { return nil, err } if uintptr(buf.Len()) != nbytes { return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dataType, shape, nbytes, buf.Len(), Version()) } return t, nil }
// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of // the *C.TF_Tensor, either by passing ownership to the C API or explicitly // calling C.TF_DeleteTensor() on it. func (t *Tensor) c() *C.TF_Tensor { var shapePtr *C.int64_t if len(t.shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&t.shape[0])) } tensor := C.TF_AllocateTensor(C.TF_DataType(t.dt), shapePtr, C.int(len(t.shape)), C.size_t(t.buf.Len())) if t.buf.Len() > 0 { slice := t.buf.Bytes() // https://github.com/golang/go/issues/14210 C.memcpy(C.TF_TensorData(tensor), unsafe.Pointer(&slice[0]), C.size_t(t.buf.Len())) } return tensor }
// NewTensor converts from a Go value to a Tensor. Valid values are scalars, // slices, and arrays. Every element of a slice must have the same length so // that the resulting Tensor has a valid shape. func NewTensor(value interface{}) (*Tensor, error) { val := reflect.ValueOf(value) shape, dataType, err := shapeAndDataTypeOf(val) if err != nil { return nil, err } nflattened := numElements(shape) nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened) if dataType == String { // TF_STRING tensors are encoded as an array of 8-byte offsets // followed by string data. See c_api.h. nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value) } var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) } t := &Tensor{ c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), shape: shape, } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) buf := bytes.NewBuffer(raw[:0:len(raw)]) if dataType != String { if err := encodeTensor(buf, val); err != nil { return nil, err } if uintptr(buf.Len()) != nbytes { return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) } } else { e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()} if e.encode(reflect.ValueOf(value)); err != nil { return nil, err } if int64(buf.Len()) != nflattened*8 { return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8) } } return t, nil }
// ReadTensor constructs a Tensor with the provided type and shape from the // serialized tensor contents in r. // // See also WriteContentsTo. func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) { if err := isTensorSerializable(dataType); err != nil { return nil, err } nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape)) var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) } t := &Tensor{ c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), shape: shape, } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) n, err := r.Read(raw) if err != nil { return nil, err } if uintptr(n) != nbytes { return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) } return t, nil }
func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { cAttrName := C.CString(name) defer C.free(unsafe.Pointer(cAttrName)) switch value := value.(type) { case string: cstr := C.CString(value) C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value))) C.free(unsafe.Pointer(cstr)) case []string: size := len(value) list := make([]unsafe.Pointer, size) lens := make([]C.size_t, size) for i, s := range value { list[i] = unsafe.Pointer(C.CString(s)) lens[i] = C.size_t(len(s)) } C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) for _, s := range list { C.free(s) } case int64: C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) case []int64: size := len(value) list := make([]C.int64_t, size) for i, v := range value { list[i] = C.int64_t(v) } C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) case float32: C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) case []float32: size := len(value) list := make([]C.float, size) for i, v := range value { list[i] = C.float(v) } C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) case bool: v := C.uchar(0) if value { v = 1 } C.TF_SetAttrBool(cdesc, cAttrName, v) case []bool: size := len(value) list := make([]C.uchar, size) for i, v := range value { if v { list[i] = 1 } } C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) case DataType: C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) case []DataType: list := (*C.TF_DataType)(&value[0]) C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) case *Tensor: C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } case []*Tensor: size := len(value) list := make([]*C.TF_Tensor, size) for i, v := range value { list[i] = v.c } C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c) if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } case Shape: ndims, dims := cshape(value) var dimsp *C.int64_t if ndims > 0 { dimsp = &dims[0] } C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) case []Shape: ndims := make([]C.int, len(value)) dims := make([][]C.int64_t, len(value)) dimsp := make([]*C.int64_t, len(value)) for i, s := range value { ndims[i], dims[i] = cshape(s) if ndims[i] > 0 { dimsp[i] = &dims[i][0] } } C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) default: return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil }
func (b *opBuilder) SetAttrType(name string, typ DataType) { attrName := C.CString(name) C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ)) C.free(unsafe.Pointer(attrName)) }
func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { cAttrName := C.CString(name) defer C.free(unsafe.Pointer(cAttrName)) switch value := value.(type) { case string: cstr := C.CString(value) C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.size_t(len(value))) C.free(unsafe.Pointer(cstr)) case []string: size := len(value) list := make([]unsafe.Pointer, size) lens := make([]C.size_t, size) for i, s := range value { list[i] = unsafe.Pointer(C.CString(s)) lens[i] = C.size_t(len(s)) } C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) for _, s := range list { C.free(s) } case int64: C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) case []int64: size := len(value) list := make([]C.int64_t, size) for i, v := range value { list[i] = C.int64_t(v) } C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) case float32: C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) case []float32: size := len(value) list := make([]C.float, size) for i, v := range value { list[i] = C.float(v) } C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) case bool: v := C.uchar(0) if value { v = 1 } C.TF_SetAttrBool(cdesc, cAttrName, v) case []bool: size := len(value) list := make([]C.uchar, size) for i, v := range value { if v { list[i] = 1 } } C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) case DataType: C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) case []DataType: list := (*C.TF_DataType)(&value[0]) C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) case *Tensor: C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } case []*Tensor: size := len(value) list := make([]*C.TF_Tensor, size) for i, v := range value { list[i] = v.c } C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c) if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } default: // Shapes can be done, but will require that it be // distinguishable from []int64. Which is fine, it // probably makes sense to define a Shape type anyway, // since that should handle partially known shapes as // well and hide the special meaning of -1? return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil }