// CorrBankFFT computes the correlation of an image with a bank of filters. // h_p[u, v] = (f corr g_p)[u, v] func CorrBankFFT(f *rimg64.Image, g *Bank) (*rimg64.Multi, error) { out := ValidSize(f.Size(), g.Size()) if out.X <= 0 || out.Y <= 0 { return nil, nil } // Determine optimal size for FFT. work, _ := FFT2Size(f.Size()) // Re-use FFT of image. fhat := fftw.NewArray2(work.X, work.Y) copyImageTo(fhat, f) fftw.FFT2To(fhat, fhat) // Transform of each filter. curr := fftw.NewArray2(work.X, work.Y) fwd := fftw.NewPlan2(curr, curr, fftw.Forward, fftw.Estimate) defer fwd.Destroy() bwd := fftw.NewPlan2(curr, curr, fftw.Backward, fftw.Estimate) defer bwd.Destroy() h := rimg64.NewMulti(out.X, out.Y, len(g.Filters)) alpha := complex(1/float64(work.X*work.Y), 0) // For each output channel. for p, gp := range g.Filters { // Take FFT. copyImageTo(curr, gp) fwd.Execute() // h_p[x] = (G_p corr F)[x] // H_p[x] = conj(G_p[x]) F[x] scaleMul(curr, alpha, curr, fhat) bwd.Execute() copyRealToChannel(h, p, curr) } return h, nil }
// CorrMultiBankFFT computes the correlation of // a multi-channel image with a multi-channel filter. // h[u, v] = sum_p (f_p corr g_p)[u, v] func CorrMultiFFT(f, g *rimg64.Multi) (*rimg64.Image, error) { if err := errIfChannelsNotEq(f, g); err != nil { panic(err) } out := ValidSize(f.Size(), g.Size()) if out.Eq(image.ZP) { return nil, nil } work, _ := FFT2Size(f.Size()) fhat := fftw.NewArray2(work.X, work.Y) ghat := fftw.NewArray2(work.X, work.Y) ffwd := fftw.NewPlan2(fhat, fhat, fftw.Forward, fftw.Estimate) defer ffwd.Destroy() gfwd := fftw.NewPlan2(ghat, ghat, fftw.Forward, fftw.Estimate) defer gfwd.Destroy() hhat := fftw.NewArray2(work.X, work.Y) for p := 0; p < f.Channels; p++ { // Take transform of each channel. copyChannelTo(fhat, f, p) ffwd.Execute() copyChannelTo(ghat, g, p) gfwd.Execute() addMul(hhat, ghat, fhat) } n := float64(work.X * work.Y) scale(complex(1/n, 0), hhat) fftw.IFFT2To(hhat, hhat) h := rimg64.New(out.X, out.Y) copyRealTo(h, hhat) return h, nil }
// CorrBankStrideFFT computes the strided correlation of // an image with a bank of filters. // h_p[u, v] = (f corr g_p)[stride*u, stride*v] func CorrBankStrideFFT(f *rimg64.Image, g *Bank, stride int) (*rimg64.Multi, error) { out := ValidSizeStride(f.Size(), g.Size(), stride) if out.X <= 0 || out.Y <= 0 { return nil, nil } // Compute strided convolution as the sum over // a stride x stride grid of small convolutions. grid := image.Pt(stride, stride) // But do not divide into a larger grid than the size of the filter. // If the filter is smaller than the stride, // then some pixels in the image will not affect the output. grid.X = min(grid.X, g.Width) grid.Y = min(grid.Y, g.Height) // Determine the size of the sub-sampled filter. gsub := image.Pt(ceilDiv(g.Width, grid.X), ceilDiv(g.Height, grid.Y)) // The sub-sampled size of the image should be such that // the output size is attained. fsub := image.Pt(out.X+gsub.X-1, out.Y+gsub.Y-1) // Determine optimal size for FFT. work, _ := FFT2Size(fsub) // Cache FFT of image for convolving with multiple filters. // Re-use plan for multiple convolutions too. fhat := fftw.NewArray2(work.X, work.Y) ffwd := fftw.NewPlan2(fhat, fhat, fftw.Forward, fftw.Estimate) defer ffwd.Destroy() // FFT for current filter. ghat := fftw.NewArray2(work.X, work.Y) gfwd := fftw.NewPlan2(ghat, ghat, fftw.Forward, fftw.Estimate) defer gfwd.Destroy() // Allocate one array per output channel. hhat := make([]*fftw.Array2, len(g.Filters)) for k := range hhat { hhat[k] = fftw.NewArray2(work.X, work.Y) } // Normalization factor. alpha := complex(1/float64(work.X*work.Y), 0) // Add the convolutions over channels and strides. for i := 0; i < grid.X; i++ { for j := 0; j < grid.Y; j++ { // Take transform of downsampled image given offset (i, j). copyStrideTo(fhat, f, stride, image.Pt(i, j)) ffwd.Execute() // Take transform of each downsampled channel given offset (i, j). for q := range hhat { copyStrideTo(ghat, g.Filters[q], stride, image.Pt(i, j)) gfwd.Execute() addMul(hhat[q], ghat, fhat) } } } // Take the inverse transform of each channel. h := rimg64.NewMulti(out.X, out.Y, len(g.Filters)) for q := range hhat { scale(alpha, hhat[q]) fftw.IFFT2To(hhat[q], hhat[q]) copyRealToChannel(h, q, hhat[q]) } return h, nil }
// CorrMultiStrideFFT computes the correlation of // a multi-channel image with a multi-channel filter. // h[u, v] = sum_q (f_q corr g_q)[u, v] func CorrMultiStrideFFT(f, g *rimg64.Multi, stride int) (*rimg64.Image, error) { if err := errIfChannelsNotEq(f, g); err != nil { panic(err) } out := ValidSizeStride(f.Size(), g.Size(), stride) if out.X <= 0 || out.Y <= 0 { return nil, nil } // Compute strided convolution as the sum over // a stride x stride grid of small convolutions. grid := image.Pt(stride, stride) // But do not divide into a larger grid than the size of the filter. // If the filter is smaller than the stride, // then some pixels in the image will not affect the output. grid.X = min(grid.X, g.Width) grid.Y = min(grid.Y, g.Height) // Determine the size of the sub-sampled filter. gsub := image.Pt(ceilDiv(g.Width, grid.X), ceilDiv(g.Height, grid.Y)) // The sub-sampled size of the image should be such that // the output size is attained. fsub := image.Pt(out.X+gsub.X-1, out.Y+gsub.Y-1) // Determine optimal size for FFT. work, _ := FFT2Size(fsub) // Cache FFT of each channel of image for convolving with multiple filters. // Re-use plan for multiple convolutions too. fhat := fftw.NewArray2(work.X, work.Y) ffwd := fftw.NewPlan2(fhat, fhat, fftw.Forward, fftw.Estimate) defer ffwd.Destroy() ghat := fftw.NewArray2(work.X, work.Y) gfwd := fftw.NewPlan2(ghat, ghat, fftw.Forward, fftw.Estimate) defer gfwd.Destroy() // Normalization factor. alpha := complex(1/float64(work.X*work.Y), 0) // Add the convolutions over channels and strides. hhat := fftw.NewArray2(work.X, work.Y) for k := 0; k < f.Channels; k++ { for i := 0; i < grid.X; i++ { for j := 0; j < grid.Y; j++ { // Copy each downsampled channel and take its transform. copyChannelStrideTo(fhat, f, k, stride, image.Pt(i, j)) ffwd.Execute() copyChannelStrideTo(ghat, g, k, stride, image.Pt(i, j)) gfwd.Execute() addMul(hhat, ghat, fhat) } } } // Take the inverse transform. h := rimg64.New(out.X, out.Y) scale(alpha, hhat) fftw.IFFT2To(hhat, hhat) copyRealTo(h, hhat) return h, nil }
// CorrMultiBankFFT computes the correlation of // a multi-channel image with a bank of multi-channel filters. // h_p[u, v] = sum_q (f_q corr g_pq)[u, v] func CorrMultiBankFFT(f *rimg64.Multi, g *MultiBank) (*rimg64.Multi, error) { out := ValidSize(f.Size(), g.Size()) if out.X <= 0 || out.Y <= 0 { return nil, nil } // Determine optimal size for FFT. work, _ := FFT2Size(f.Size()) // Cache FFT of each channel of image. fhat := make([]*fftw.Array2, f.Channels) for i := range fhat { fhat[i] = fftw.NewArray2(work.X, work.Y) copyChannelTo(fhat[i], f, i) fftw.FFT2To(fhat[i], fhat[i]) } curr := fftw.NewArray2(work.X, work.Y) fwd := fftw.NewPlan2(curr, curr, fftw.Forward, fftw.Estimate) defer fwd.Destroy() sum := fftw.NewArray2(work.X, work.Y) bwd := fftw.NewPlan2(sum, sum, fftw.Backward, fftw.Estimate) defer bwd.Destroy() h := rimg64.NewMulti(out.X, out.Y, len(g.Filters)) alpha := complex(1/float64(work.X*work.Y), 0) // For each output channel. for p, gp := range g.Filters { zero(sum) // For each input channel. for q := 0; q < f.Channels; q++ { // Take FFT of this input channel. copyChannelTo(curr, gp, q) fwd.Execute() // h_p[x] = (G_qp corr F_p)[x] // H_p[x] = conj(G_qp[x]) F_p[x] addScaleMul(sum, alpha, curr, fhat[q]) } bwd.Execute() copyRealToChannel(h, p, sum) } return h, nil }
// CorrFFT computes the correlation of an image with a filter. // h[u, v] = (f corr g)[u, v] func CorrFFT(f, g *rimg64.Image) (*rimg64.Image, error) { out := ValidSize(f.Size(), g.Size()) if out.X <= 0 || out.Y <= 0 { return nil, nil } // Determine optimal size for FFT. work, _ := FFT2Size(f.Size()) fhat := fftw.NewArray2(work.X, work.Y) ghat := fftw.NewArray2(work.X, work.Y) // Take forward transforms. copyImageTo(fhat, f) fftw.FFT2To(fhat, fhat) copyImageTo(ghat, g) fftw.FFT2To(ghat, ghat) // Scale such that convolution theorem holds. n := float64(work.X * work.Y) scaleMul(fhat, complex(1/n, 0), ghat, fhat) // Take inverse transform. h := rimg64.New(out.X, out.Y) fftw.IFFT2To(fhat, fhat) copyRealTo(h, fhat) return h, nil }