예제 #1
0
// Shape returns the (possibly incomplete) shape of the tensor produced by p.
//
// The result is:
//   - the zero value Shape{} if the number of dimensions is not known
//     (TF_GraphGetTensorNumDims reports a negative count), or if the
//     TensorFlow C API reports an error while querying the graph;
//   - ScalarShape() if the tensor has zero dimensions;
//   - otherwise a Shape whose i-th entry is the size of dimension i as
//     reported by TF_GraphGetTensorShape (presumably -1 for a dimension
//     whose size is not known — confirm against the C API).
func (p Output) Shape() Shape {
	status := newStatus()
	port := p.c()
	ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
	if err := status.Err(); err != nil {
		// This should not be possible since an error only occurs if
		// the operation does not belong to the graph. It should not
		// be possible to construct such an Operation object.
		return Shape{}
	}
	switch {
	case ndims < 0:
		// Rank is unknown: report a fully unknown shape.
		return Shape{}
	case ndims == 0:
		return ScalarShape()
	}
	// ndims > 0 here, so &dims[0] is a valid pointer into the buffer the
	// C API fills in.
	dims := make([]C.int64_t, ndims)
	C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
	if err := status.Err(); err != nil {
		// Same as above, should not be possible.
		return Shape{}
	}
	ret := Shape{dims: make([]int64, ndims)}
	for i, d := range dims {
		ret.dims[i] = int64(d)
	}
	return ret
}
예제 #2
0
// Shape returns the (possibly incomplete) shape of the tensor produced by p.
//
// Returns a slice of length 0 (nil) if the tensor is a scalar. Otherwise
// returns a slice where shape[i] is the size of the i-th dimension of the
// tensor, or -1 if the size of that dimension is not known.
//
// Returns an error if the number of dimensions of the tensor is not known,
// or if the TensorFlow C API reports an error while querying the graph.
func (p Output) Shape() (shape []int64, err error) {
	status := newStatus()
	port := p.c()
	ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
	// Assign to the named result instead of shadowing it with :=.
	if err = status.Err(); err != nil {
		return nil, err
	}
	if ndims < 0 {
		return nil, errors.New("unknown number of dimensions")
	}
	if ndims == 0 {
		// Scalar: a nil slice has length 0.
		return nil, nil
	}
	// ndims > 0 here, so &dims[0] is a valid pointer into the buffer the
	// C API fills in.
	dims := make([]C.int64_t, ndims)
	C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
	if err = status.Err(); err != nil {
		return nil, err
	}
	shape = make([]int64, ndims)
	for i, d := range dims {
		shape[i] = int64(d)
	}
	return shape, nil
}