Ejemplo n.º 1
0
// deleteCTensor releases ct by passing it to C.TF_DeleteTensor.
//
// It lives outside the _test.go files on purpose: go test does not support
// cgo in test files, so tests that need to free a *C.TF_Tensor call this
// helper instead of invoking the C API themselves.
func deleteCTensor(ct *C.TF_Tensor) { C.TF_DeleteTensor(ct) }
Ejemplo n.º 2
0
// finalize frees the C tensor referenced by t.c via C.TF_DeleteTensor.
// NOTE(review): presumably installed with runtime.SetFinalizer where Tensors
// are constructed; confirm at that call site.
func (t *Tensor) finalize() {
	C.TF_DeleteTensor(t.c)
}
Ejemplo n.º 3
0
// Run executes the session's graph, feeding it the supplied inputs.
//
// inputs maps graph outputs to the Tensors fed into them, outputs lists the
// graph outputs to fetch, and targets lists operations that are executed but
// whose results are not returned. Any of the three may be nil.
//
// On success the fetched Tensors are returned in the same order as outputs.
// When outputs is nil or empty, the returned slice is nil. An error is
// returned if the session is already closed or if TF_SessionRun reports a
// failure through its status.
func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) {
	// s.wg counts in-flight runs. It is incremented under s.mu only after
	// confirming s.c is still non-nil, and released when Run returns.
	s.mu.Lock()
	if s.c == nil {
		s.mu.Unlock()
		return nil, errors.New("session is closed")
	}
	s.wg.Add(1)
	s.mu.Unlock()
	defer s.wg.Done()

	// Flatten the feed map into two parallel slices. Both are appended in the
	// same iteration, so feedPorts[i] always pairs with feedVals[i]. Ranging
	// over a nil map simply yields no entries.
	feedPorts := make([]C.TF_Port, 0, len(inputs))
	feedVals := make([]*C.TF_Tensor, 0, len(inputs))
	for out, t := range inputs {
		feedPorts = append(feedPorts, out.c())
		feedVals = append(feedVals, t.c())
	}

	fetchPorts := make([]C.TF_Port, 0, len(outputs))
	for _, out := range outputs {
		fetchPorts = append(fetchPorts, out.c())
	}
	// TF_SessionRun writes one result per fetch into this slice.
	fetchVals := make([]*C.TF_Tensor, len(outputs))

	opPtrs := make([]*C.TF_Operation, 0, len(targets))
	for _, op := range targets {
		opPtrs = append(opPtrs, op.c)
	}

	status := newStatus()

	// Taking &x[0] of an empty slice would panic, so an empty array is passed
	// to C as a nil pointer.
	var (
		feedPortsPtr  *C.TF_Port
		feedValsPtr   **C.TF_Tensor
		fetchPortsPtr *C.TF_Port
		fetchValsPtr  **C.TF_Tensor
		opPtrsPtr     **C.TF_Operation
	)
	if len(feedPorts) > 0 {
		feedPortsPtr, feedValsPtr = &feedPorts[0], &feedVals[0]
	}
	if len(fetchPorts) > 0 {
		fetchPortsPtr, fetchValsPtr = &fetchPorts[0], &fetchVals[0]
	}
	if len(opPtrs) > 0 {
		opPtrsPtr = &opPtrs[0]
	}

	C.TF_SessionRun(s.c, nil,
		feedPortsPtr, feedValsPtr, C.int(len(feedPorts)),
		fetchPortsPtr, fetchValsPtr, C.int(len(fetchPorts)),
		opPtrsPtr, C.int(len(opPtrs)),
		nil, status.c)
	if err := status.Err(); err != nil {
		return nil, err
	}

	// Convert every fetched C tensor into a Go Tensor, then free the C tensor.
	// NOTE(review): this relies on newTensorFromC not retaining val; confirm
	// in its definition.
	var fetched []*Tensor
	for _, val := range fetchVals {
		fetched = append(fetched, newTensorFromC(val))
		C.TF_DeleteTensor(val)
	}
	return fetched, nil
}