// solve_lasso_CD minimizes the lasso objective reported by get_obj(p),
//
//	sum_i (Labels[i] - A_rows[i]*X)^2 / (2*L) + Lambda * sum_n |X[n]|,
//
// by cyclic coordinate descent: every sweep visits each feature n in turn and
// re-optimizes p.X[n] while all other coordinates stay fixed. The solution is
// written into p.X.
//
// The residual is initialised to the label vector, so p.X is expected to be
// the zero vector on entry. The objective is printed before the first sweep
// and after every sweep. Sweeps stop as soon as the objective changes by less
// than p.Epsilon, or after p.Max_iter sweeps.
func solve_lasso_CD(p *readData.Problem) {
	// With x == 0 the prediction is zero, hence the residual is just the
	// label vector.
	residual := make([]float32, p.L)
	z := make([]float32, p.L)
	for i := 0; i < p.L; i++ {
		residual[i] = float32(p.Labels[i])
	}
	copy(z, residual)

	// Product of every feature column with itself (presumably its squared
	// Euclidean norm); it is the denominator of each coordinate update.
	fea_square := make([]float32, p.N)
	for i := 0; i < p.N; i++ {
		fea_square[i] = p.A_cols[i].Multiply_sparse_vector(&(p.A_cols[i]))
	}

	obj_old := get_obj(p)
	fmt.Printf("obj: %f\n", obj_old)

	sweeps := 0
	converged := false
	for sweeps < p.Max_iter && !converged {
		for n := 0; n < p.N; n++ {
			update_z(z, residual, p, n)
			if fea_square[n] != 0 {
				temp := p.A_cols[n].Multiply_dense_array(z) / fea_square[n]
				p.X[n] = soft_threshold(temp, float32(p.L)*p.Lambda/fea_square[n])
			} else {
				// A feature with zero norm does not influence the loss, so
				// only the L1 penalty remains and its minimizer is 0. This
				// also avoids the 0/0 division of the general formula.
				p.X[n] = 0
			}
			update_residual(residual, z, p, n)
		}
		sweeps++

		obj_new := get_obj(p)
		fmt.Printf("obj: %f\n", obj_new)
		converged = mathOp.Abs(obj_new-obj_old) < p.Epsilon
		obj_old = obj_new
	}

	if converged {
		fmt.Printf("converged in %d iterations\n", sweeps)
	} else {
		fmt.Printf("stopped after %d iterations without converging\n", sweeps)
	}
}
// get_upper_bound returns a cheap upper bound on the change of the
// L1-regularized logistic-regression objective (see get_obj_lr) caused by
// moving coordinate j from p.X[j] to p.X[j]+dr. It only touches the stored
// entries of column j and the cached products p.Ax.
//
// Two bounds are computed and the smaller one is returned. Each combines
//   - Jensen's inequality applied to the log terms of the column's nnz
//     samples, each weighted 1/L (total weight nnz/L),
//   - the linear term of the samples with label != 1 (first bound) or
//     label == 1 (second bound),
//   - the exact change of the L1 penalty.
//
// The derivation assumes the feature values of column j lie in
// [0, p.Xj_max[j]] - TODO confirm against how Xj_max is filled.
//
// ref: LIBLINEAR -- A Library for Large Linear Classification (CDN line
// search for L1-regularized logistic regression).
func get_upper_bound(p *readData.Problem, dr float32, j int) float32 {
	l1_change := p.Lambda * (mathOp.Abs(p.X[j]+dr) - mathOp.Abs(p.X[j]))

	nnz := len(p.A_cols[j].Idxs)
	if nnz == 0 {
		// No sample uses this feature: the loss is unaffected and the
		// objective changes by the penalty term only.
		return l1_change
	}

	var s1 float32    // sum of x_ij / (1 + exp(Ax_i))
	var s2 float32    // sum of x_ij / (1 + exp(-Ax_i))
	var s_y_1 float32 // sum of x_ij over samples with label 1
	var s_y_0 float32 // sum of x_ij over the remaining samples (label -1)
	for i := 0; i < nnz; i++ {
		sample_index := p.A_cols[j].Idxs[i]
		value := p.A_cols[j].Values[i]
		s1 = s1 + value/(1.0+(mathOp.Exp(p.Ax[sample_index])))
		s2 = s2 + value/(1.0+(mathOp.Exp(-p.Ax[sample_index])))
		if p.Labels[sample_index] == 1 {
			s_y_1 = s_y_1 + value
		} else {
			s_y_0 = s_y_0 + value
		}
	}

	num_samples := float32(p.L)
	// Total weight of the samples in this column. Jensen's inequality
	// averages over these samples, so the sum inside the log has to be
	// divided by this weight (not multiplied by it).
	c_sum := float32(nnz) / num_samples
	x_max := p.Xj_max[j]

	s1 = s1 / num_samples
	s1 = s1 * (mathOp.Exp(-dr*x_max) - 1.0)
	s1 = s1 / x_max / c_sum
	upper_1 := c_sum*mathOp.Log(1.0+s1) + dr*s_y_0/num_samples + l1_change

	s2 = s2 / num_samples
	s2 = s2 * (mathOp.Exp(dr*x_max) - 1.0)
	s2 = s2 / x_max / c_sum
	upper_2 := c_sum*mathOp.Log(1.0+s2) - dr*s_y_1/num_samples + l1_change

	if upper_1 < upper_2 {
		return upper_1
	}
	return upper_2
}
// get_obj_lr evaluates the L1-regularized logistic-regression objective at
// the current p.X, using the cached products p.Ax:
//
//	1/L * sum_i log(1 + exp(-Labels[i]*Ax[i])) + Lambda * sum_n |X[n]|
//
// The log term is computed as max(m, 0) + log1p(exp(-|m|)) with
// m = -Labels[i]*Ax[i]. This is algebraically the same as log(1 + exp(m)) but
// neither overflows to +Inf for large positive m nor rounds to 0 for very
// negative m.
func get_obj_lr(p *readData.Problem) float32 {
	var loss float32
	for i := 0; i < p.L; i++ {
		margin := float64(p.Ax[i] * float32(-p.Labels[i]))
		ti := math.Max(margin, 0) + math.Log1p(math.Exp(-math.Abs(margin)))
		loss = loss + float32(ti)
	}
	loss = loss / float32(p.L)

	var l1 float32
	for i := 0; i < p.N; i++ {
		l1 = l1 + mathOp.Abs(p.X[i])
	}
	return loss + l1*p.Lambda
}
// get_obj evaluates the lasso objective at the current p.X:
//
//	sum_i (Labels[i] - A_rows[i].Multiply_dense_array(X))^2 / L / 2
//	  + Lambda * sum_n |X[n]|
//
// The squared-error term runs over the first p.L rows and the penalty over
// the first p.N coordinates of p.X.
func get_obj(p *readData.Problem) float32 {
	// Accumulate the squared residual of every sample.
	var sqErr float32
	for row := 0; row < p.L; row++ {
		diff := float32(p.Labels[row]) - p.A_rows[row].Multiply_dense_array(p.X)
		sqErr += diff * diff
	}
	sqErr = sqErr / float32(p.L) / 2.0

	// L1 norm of the coefficient vector.
	var absSum float32
	for col := 0; col < p.N; col++ {
		absSum += mathOp.Abs(p.X[col])
	}

	return sqErr + absSum*p.Lambda
}
// get_obj evaluates the weighted lasso objective at the current p.X, using
// the cached products p.Ax instead of recomputing each row product:
//
//	sum_i w[i]*(u[i] - Ax[i])^2 + Lambda * sum_n |X[n]|
//
// u and w are indexed by sample and must hold at least p.L entries. Unlike
// the unweighted objective, the loss term is not divided by the sample count.
func get_obj(p *readData.Problem, u []float32, w []float32) float32 {
	// Weighted squared distance between the targets u and the predictions.
	var weightedErr float32
	for row := 0; row < p.L; row++ {
		diff := u[row] - p.Ax[row]
		weightedErr += diff * diff * w[row]
	}

	// L1 norm of the coefficient vector.
	var absSum float32
	for col := 0; col < p.N; col++ {
		absSum += mathOp.Abs(p.X[col])
	}

	return weightedErr + absSum*p.Lambda
}
// Solve_lr_CD fits L1-regularized logistic regression by repeatedly solving a
// weighted lasso problem. Each outer round obtains the vectors u and w from
// get_u_and_w (presumably the working response and weights of the current
// quadratic approximation), hands them to solve_weighted_lasso_CD, and then
// re-evaluates the objective with get_obj_lr.
//
// The objective is printed once before the first round and then every tenth
// round. The loop ends after p.Max_iter rounds, or earlier when the objective
// changes by less than p.Epsilon relative to its previous value.
//
// TODO: speed it up (liblinear)
func Solve_lr_CD(p *readData.Problem) {
	prevObj := get_obj_lr(p)
	fmt.Printf("iter: 0, obj: %f\n", prevObj)

	for round := 1; round <= p.Max_iter; round++ {
		targets, weights := get_u_and_w(p)
		solve_weighted_lasso_CD(p, targets, weights)

		curObj := get_obj_lr(p)
		if round%10 == 0 {
			fmt.Printf("iter: %d, obj: %f\n", round, curObj)
		}

		// Relative stopping rule: compare the change against the previous
		// objective value scaled by Epsilon.
		change := mathOp.Abs(curObj - prevObj)
		if change < p.Epsilon*prevObj {
			return
		}
		prevObj = curObj
	}
}
// solve_weighted_lasso_CD uses cyclic coordinate descent to reduce the
// weighted lasso objective evaluated by get_obj(p, u, w),
//
//	sum_i w[i]*(u[i] - Ax[i])^2 + Lambda * sum_n |X[n]|,
//
// re-optimizing one coordinate p.X[n] at a time while the others stay fixed.
// p.X is updated in place and p.Ax is kept in sync through update_Ax.
//
// u and w are indexed by sample and must hold at least p.L entries. At most
// 99 sweeps are run; sweeping stops earlier once the objective changes by
// less than 10% of its previous value. "    wrong" is printed whenever a sweep
// increases the objective.
func solve_weighted_lasso_CD(p *readData.Problem, u []float32, w []float32) {
	// Residual of the current iterate with respect to the targets u.
	residual := make([]float32, p.L)
	z := make([]float32, p.L)
	for i := 0; i < p.L; i++ {
		residual[i] = u[i] - p.Ax[i]
	}
	copy(z, residual) // can also be deleted

	// Weighted norm of every feature column: the column scaled by w,
	// multiplied with the column itself. Time consuming.
	fea_weighted_norm := make([]float32, p.N)
	for i := 0; i < p.N; i++ {
		weighted_col := p.A_cols[i].Dot_product(w)
		fea_weighted_norm[i] = p.A_cols[i].Multiply_sparse_vector(weighted_col)
	}

	obj_old := get_obj(p, u, w)
	for iter := 1; iter < 100; iter++ {
		for n := 0; n < p.N; n++ {
			weighted_lasso_update_z(z, residual, p, n)

			// Only the nonzero entries of z are used.
			var x_new float32
			if norm := fea_weighted_norm[n]; norm != 0 {
				temp := p.A_cols[n].Multiply_dense_array_weithted(z, w) / norm
				x_new = soft_threshold(temp, p.Lambda/norm/2.0)
			}
			// Otherwise the weighted loss does not depend on X[n] (empty
			// column or all of its weights are zero): only the L1 penalty
			// remains, whose minimizer is 0. Handling this explicitly avoids
			// the 0/0 division that would push NaN into p.X and p.Ax.

			update_Ax(p, p.X[n], x_new, n)
			p.X[n] = x_new
			weighted_lasso_update_residual(residual, z, p, n)
		}

		obj_new := get_obj(p, u, w)
		if obj_new > obj_old {
			// Coordinate descent should never increase the objective.
			fmt.Printf("    wrong\n")
		}
		// Deliberately loose inner tolerance: the outer solver re-linearizes
		// after every call anyway.
		if mathOp.Abs(obj_new-obj_old) < 0.1*obj_old {
			break
		}
		obj_old = obj_new
	}
}
// Solve_lr_new_glmnet_cdn fits L1-regularized logistic regression (the
// objective reported by get_obj_lr) by coordinate descent with a
// backtracking line search per coordinate. p.X is updated in place and p.Ax
// is kept in sync through update_Ax.
//
// For every feature j with at least one stored entry:
//  1. get_g_j_and_H_jj supplies g_j and H_jj (presumably the first and second
//     derivative of the loss along coordinate j) and update_d turns them into
//     a step d. Coordinates with a negligible g_j, H_jj or d are skipped.
//  2. The step is scaled by r = 1, lambda, lambda^2, ... until a cheap upper
//     bound on the objective change (the same bound as get_upper_bound) falls
//     below sigma*r*(g_j*d + p.Lambda*(|X[j]+d| - |X[j]|)).
//  3. X[j] is moved by r*d.
//
// Parameters:
//   - sigma: sufficient-decrease constant of the line search, 0 < sigma < 1.
//   - lambda: shrink factor applied to r after a rejected step,
//     0 < lambda < 1. It is unrelated to the regularization weight p.Lambda.
//
// The objective is printed before the first pass, every 10000 features within
// a pass, and after every pass. Passes stop after p.Max_iter rounds or once
// the objective changes by less than p.Epsilon relative to its previous
// value.
func Solve_lr_new_glmnet_cdn(p *readData.Problem, sigma float32, lambda float32) {
	// Upper limit on backtracking steps for one coordinate. Without it the
	// search never ends once r underflows to 0 or a NaN enters the bound,
	// because the acceptance test can then no longer succeed.
	const max_line_search_steps = 100

	obj_old := get_obj_lr(p)
	fmt.Printf("iter: 0, obj: %f\n", obj_old)

	num_samples := float32(p.L)
	for i := 1; i <= p.Max_iter; i++ {
		// shuffle should be added here
		for j := 0; j < p.N; j++ {
			if j%10000 == 0 {
				obj_new := get_obj_lr(p)
				fmt.Printf("    iter: %d, obj: %f\n", j, obj_new)
			}

			// Skip features that no sample uses. The check has to look at
			// column j, not at the outer iteration counter.
			nnz := len(p.A_cols[j].Idxs)
			if nnz == 0 {
				continue
			}

			g_j, H_jj := get_g_j_and_H_jj(p, j)
			if H_jj < 1e-8 || mathOp.Abs(g_j) < 1e-8 {
				continue
			}

			d := update_d(g_j, H_jj, p.X[j], p.Lambda)
			if mathOp.Abs(d) < 1e-8 {
				continue
			}

			right_hand_size := g_j*d + p.Lambda*(mathOp.Abs(p.X[j]+d)-mathOp.Abs(p.X[j]))

			// Column sums needed by the bound; they do not depend on r and
			// are therefore computed once, outside the line search.
			var s1 float32    // sum of x_kj / (1 + exp(Ax_k))
			var s2 float32    // sum of x_kj / (1 + exp(-Ax_k))
			var s_y_1 float32 // sum of x_kj over samples with label 1
			var s_y_0 float32 // sum of x_kj over the remaining samples (label -1)
			for k := 0; k < nnz; k++ {
				sample_index := p.A_cols[j].Idxs[k]
				value := p.A_cols[j].Values[k]
				s1 = s1 + value/(1.0+(mathOp.Exp(p.Ax[sample_index])))
				s2 = s2 + value/(1.0+(mathOp.Exp(-p.Ax[sample_index])))
				if p.Labels[sample_index] == 1 {
					s_y_1 = s_y_1 + value
				} else {
					s_y_0 = s_y_0 + value
				}
			}
			s1 = s1 / num_samples
			s2 = s2 / num_samples

			// Total weight (1/L each) of the samples in this column. Jensen's
			// inequality averages over these samples, so the sum inside the
			// log is divided by this weight.
			c_sum := float32(nnz) / num_samples
			x_max := p.Xj_max[j]

			// Backtracking: compute an upper bound of the left hand side of
			// the sufficient-decrease condition and shrink r until it holds.
			var r float32 = 1.0
			accepted := false
			for step := 0; step < max_line_search_steps; step++ {
				l1_change := p.Lambda * (mathOp.Abs(p.X[j]+d*r) - mathOp.Abs(p.X[j]))

				ss1 := s1 * (mathOp.Exp(-d*r*x_max) - 1.0)
				ss1 = ss1 / x_max / c_sum
				upper_1 := c_sum*mathOp.Log(1.0+ss1) + d*r*s_y_0/num_samples + l1_change

				ss2 := s2 * (mathOp.Exp(d*r*x_max) - 1.0)
				ss2 = ss2 / x_max / c_sum
				upper_2 := c_sum*mathOp.Log(1.0+ss2) - d*r*s_y_1/num_samples + l1_change

				var upper float32
				if upper_1 < upper_2 {
					upper = upper_1
				} else {
					upper = upper_2
				}

				if upper < sigma*r*right_hand_size {
					accepted = true
					break
				}
				r = r * lambda
			}
			if !accepted {
				// No step length met the decrease condition; leave X[j]
				// unchanged rather than taking an unverified step.
				continue
			}

			x_new := p.X[j] + r*d
			update_Ax(p, p.X[j], x_new, j)
			p.X[j] = x_new
		}

		obj_new := get_obj_lr(p)
		fmt.Printf("iter: %d, obj: %f\n", i, obj_new)
		if mathOp.Abs(obj_new-obj_old) < p.Epsilon*obj_old {
			break
		}
		obj_old = obj_new
	}
}