/* Forward declaration of print_final_values(); the definition appears further
 * down in this file. Takes the two problem dimensions (n_u, n_v), the vectors
 * u, v, umin, umax and the matrix B (indexed B[row][col] in the definition). */
48 void print_final_values(
int n_u,
int n_v,
float* u,
float**
B,
float* v,
float* umin,
float* umax);
/* Verbosity switch, set to FALSE here. Presumably gates the debug printing
 * further down; the place where it is tested is not visible in this excerpt. */
52 #define WLS_VERBOSE FALSE
/* Build-time checks that the problem dimensions are configured.
 * NOTE(review): the surrounding #ifndef/#endif lines are not visible here;
 * presumably each #error only fires when the macro is missing -- confirm. */
56 #error CA_N_V needs to be defined!
60 #error CA_N_U needs to be defined!
/* CA_N_C is the sum of the two dimensions; wls_alloc uses it as the row count
 * of A_free and A_free_ptr. */
63 #define CA_N_C (CA_N_U+CA_N_V)
/* Nested iteration: the outer loop walks j over [0, n), the inner loop walks i
 * over [0, m). The enclosing function header and the loop body are not visible
 * in this excerpt, so what is done per (i, j) pair cannot be stated here. */
78 for (
int j = 0; j < n; j++) {
79 for (
int i = 0; i < m; i++) {
/*
 * wls_alloc -- iterative, bound-constrained least-squares allocation.
 *
 * Only part of the function body is visible in this excerpt: several
 * declarations (n_c, n_u, n_v, A, b, d, W, p, u_opt, iter, free_chk, n_free,
 * alpha_tmp, id_alpha), the solver call, the else-branches and the return
 * statements are not shown. The notes below describe the visible lines only.
 *
 * Parameters, as used in the visible code:
 *   u        in/out vector of length n_u; set to the midpoint of the bounds
 *            in one visible loop and overwritten with u_opt when no element
 *            is flagged infeasible.
 *   v        vector of length n_v; scaled into the first n_v entries of b.
 *   umin     lower bounds, length n_u.
 *   umax     upper bounds, length n_u.
 *   B        matrix indexed B[i][j], i < n_v, j < n_u.
 *   u_guess  not referenced in the visible lines.
 *   W_init   optional; when non-NULL it is copied into W, otherwise W is
 *            zeroed.
 *   Wv       optional per-row weights for the first n_v rows (NULL = none).
 *   Wu       optional per-element weights for the last n_u rows (NULL = 1.0).
 *   up       optional vector used for the last n_u entries of b (NULL = 0).
 *   gamma_sq scale factor for the first n_v rows; 0 selects 100000.
 *   imax     iteration limit for the main loop; 0 selects 100.
 *
 * Return value: int; the return statements are not visible in this excerpt.
 */
112 int wls_alloc(
float* u,
float* v,
float* umin,
float* umax,
float**
B,
113 float* u_guess,
float* W_init,
float*
Wv,
float* Wu,
float*
up,
114 float gamma_sq,
int imax) {
/* A zero argument means "use the default". */
116 if(!gamma_sq) gamma_sq = 100000;
117 if(!imax) imax = 100;
/* A_free holds the columns of A that belong to the current free set;
 * A_free_ptr is a row-pointer view of it (float** style access). */
124 float A_free[
CA_N_C][CA_N_U];
128 float * A_free_ptr[
CA_N_C];
129 for(
int i = 0; i < n_c; i++)
130 A_free_ptr[i] = A_free[i];
/* free_index[k] is the k-th free element; free_index_lookup[i] is the
 * position of element i inside free_index, or negative when i is not free. */
135 int free_index[CA_N_U];
136 int free_index_lookup[CA_N_U];
/* Step for the free elements only (scattered into p further down). */
141 float p_free[CA_N_U];
/* Elements that violate the bounds check in the main loop, and their count. */
144 int infeasible_index[CA_N_U]
UNUSED;
145 int n_infeasible = 0;
146 float lambda[CA_N_U];
/* Start u at the midpoint of its bounds. NOTE(review): any condition guarding
 * this loop (e.g. on u_guess) is not visible in this excerpt. */
151 for (
int i = 0; i < n_u; i++) {
152 u[i] = (umax[i] + umin[i]) * 0.5;
/* Second pass over all n_u elements; its body is not visible here. */
155 for (
int i = 0; i < n_u; i++) {
/* Seed W from W_init when given, otherwise clear it. */
159 W_init ? memcpy(W, W_init, n_u *
sizeof(
float))
160 : memset(W, 0, n_u *
sizeof(
float));
/* Byte-fill the lookup table so every int reads as -1 ("not free").
 * NOTE(review): free_index_lookup is int[], but the byte count is computed
 * with sizeof(float); this is only right where sizeof(int) == sizeof(float).
 * sizeof(int) (or sizeof free_index_lookup[0]) would state the intent. */
162 memset(free_index_lookup, -1, n_u *
sizeof(
float));
/* Build the initial free set: record the position of i, then append i.
 * The test that decides whether i is added is not visible in this excerpt. */
166 for (
int i = 0; i < n_u; i++) {
168 free_index_lookup[i] = n_free;
169 free_index[n_free++] = i;
/* First n_v rows of the system: b = gamma_sq * (Wv .* v) and
 * A = gamma_sq * (Wv .* B), with the Wv factor dropped when Wv is NULL. */
174 for (
int i = 0; i < n_v; i++) {
176 b[i] =
Wv ? gamma_sq *
Wv[i] * v[i] : gamma_sq * v[i];
178 for (
int j = 0; j < n_u; j++) {
180 A[i][j] =
Wv ? gamma_sq *
Wv[i] *
B[i][j] : gamma_sq *
B[i][j];
/* Subtract A*u from d row by row; the initial value of d[i] is assigned on a
 * line that is not visible here. */
181 d[i] -=
A[i][j] * u[j];
/* Remaining rows n_v..n_c-1: a diagonal block. Row i has a single non-zero
 * entry in column i - n_v (Wu or 1.0), b[i] is the (optionally weighted) up
 * entry or 0, and d[i] = b[i] - A[i][i-n_v] * u[i-n_v]. */
184 for (
int i = n_v; i < n_c; i++) {
185 memset(
A[i], 0, n_u *
sizeof(
float));
186 A[i][i - n_v] = Wu ? Wu[i - n_v] : 1.0;
187 b[i] =
up ? (Wu ? Wu[i-n_v] *
up[i-n_v] :
up[i-n_v]) : 0;
188 d[i] =
b[i] -
A[i][i - n_v] * u[i - n_v];
/* Main loop, bounded by imax iterations. */
192 while (iter++ < imax) {
/* Reset the full-length step and start the candidate from the current u. */
194 memset(
p, 0, n_u *
sizeof(
float));
195 memcpy(u_opt, u, n_u *
sizeof(
float));
/* Rebuild A_free only when the free-set size differs from free_chk: copy the
 * columns of A selected by free_index into the first n_free columns. */
198 if (free_chk != n_free) {
199 for (
int i = 0; i < n_c; i++) {
200 for (
int j = 0; j < n_free; j++) {
201 A_free[i][j] =
A[i][free_index[j]];
/* (The lines that compute p_free are not visible in this excerpt.)
 * Scatter the free-set step into p and apply it to the candidate u_opt. */
222 for (
int i = 0; i < n_free; i++) {
223 p[free_index[i]] = p_free[i];
224 u_opt[free_index[i]] += p_free[i];
/* Bounds check with a margin of 1.0: an element is recorded as infeasible
 * only when it is at least 1.0 beyond umax or at least 1.0 below umin. */
228 for (
int i = 0; i < n_u; i++) {
229 if (u_opt[i] >= (umax[i] + 1.0) || u_opt[i] <= (umin[i] - 1.0)) {
230 infeasible_index[n_infeasible++] = i;
/* Nothing infeasible: accept the candidate as the new u. */
235 if (n_infeasible == 0) {
237 memcpy(u, u_opt, n_u *
sizeof(
float));
238 memset(lambda, 0, n_u *
sizeof(
float));
/* Update d by the step just taken (d -= A_free * p_free), then accumulate
 * lambda[k] += A[i][k] * d[i] over all rows, i.e. lambda = A^T * d. */
241 for (
int i = 0; i < n_c; i++) {
242 for (
int k = 0; k < n_free; k++) {
243 d[i] -= A_free[i][k] * p_free[k];
245 for (
int k = 0; k < n_u; k++) {
246 lambda[k] +=
A[i][k] * d[i];
/* Assume the loop can stop; the lines that clear this flag are not visible. */
249 bool break_flag =
true;
/* Scan lambda; an element with lambda below -FLT_EPSILON that is currently
 * not free (lookup < 0) is appended to the free set. Lines between these
 * statements are missing from this excerpt. */
252 for (
int i = 0; i < n_u; i++) {
255 if (lambda[i] < -FLT_EPSILON) {
259 if (free_index_lookup[i] < 0) {
260 free_index_lookup[i] = n_free;
261 free_index[n_free++] = i;
/* Step-length search. NOTE(review): presumably the branch taken when some
 * element is infeasible; the else line is not visible -- confirm. */
275 float alpha = INFINITY;
/* For each free element with a non-negligible step, alpha_tmp is the fraction
 * of the step that reaches the bound in the direction of travel (umin when
 * p < 0, umax otherwise); negligible steps get INFINITY. */
280 for (
int i = 0; i < n_free; i++) {
281 int id = free_index[i];
282 if(fabs(
p[
id]) > FLT_EPSILON) {
283 alpha_tmp = (
p[id] < 0) ? (umin[
id] - u[
id]) /
p[id]
284 : (umax[id] - u[id]) /
p[
id];
286 alpha_tmp = INFINITY;
/* Smaller candidate found; the body that records it is not visible here. */
288 if (alpha_tmp <
alpha) {
/* Loop over all n_u elements; its body is not visible in this excerpt. */
295 for (
int i = 0; i < n_u; i++) {
/* Update d for the shortened step: d -= A_free * (alpha * p_free). */
299 for (
int i = 0; i < n_c; i++) {
300 for (
int k = 0; k < n_free; k++) {
301 d[i] -= A_free[i][k] *
alpha * p_free[k];
/* Record the sign of the step for element id_alpha in W:
 * +1.0 when p is positive, -1.0 otherwise. */
305 W[id_alpha] = (
p[id_alpha] > 0) ? 1.0 : -1.0;
/* Remove id_alpha from the free set: move the last free entry into its slot,
 * point that entry's lookup at the slot, then mark id_alpha as not free. */
307 free_index[free_index_lookup[id_alpha]] = free_index[--n_free];
308 free_index_lookup[free_index[free_index_lookup[id_alpha]]] =
309 free_index_lookup[id_alpha];
310 free_index_lookup[id_alpha] = -1;
/* Debug dump of the sub-problem. The header of the enclosing definition is not
 * visible in this excerpt, so only the visible statements are described. */
/* Row count and current free-set size. */
320 printf(
"n_c = %d n_free = %d\n", n_c, n_free);
/* A_free, one entry at a time: n_c rows by n_free columns. */
322 printf(
"A_free =\n");
323 for(
int i = 0; i < n_c; i++) {
324 for (
int j = 0; j < n_free; j++) {
325 printf(
"%f ", A_free_ptr[i][j]);
/* Loop over the n_c rows; its body is not visible in this excerpt. */
331 for (
int j = 0; j < n_c; j++) {
/* The n_free entries of p_free, labelled "output". */
335 printf(
"\noutput = ");
336 for (
int j = 0; j < n_free; j++) {
337 printf(
"%f ", p_free[j]);
/*
 * print_final_values -- write the problem data to stdout via printf.
 *
 *   n_u, n_v    dimensions; printed first and used as loop bounds.
 *   B           matrix indexed B[i][j] with i < n_v and j < n_u.
 *   umin, umax  vectors of length n_u, printed element by element.
 *   u, v        not referenced in the visible lines; the bodies of two loops
 *               below are missing from this excerpt and presumably print
 *               them -- confirm against the full file.
 */
342 void print_final_values(
int n_u,
int n_v,
float* u,
float**
B,
float* v,
float* umin,
float* umax) {
343 printf(
"n_u = %d n_v = %d\n", n_u, n_v);
/* B: n_v rows by n_u columns, one "%f " per entry. */
346 for(
int i = 0; i < n_v; i++) {
347 for (
int j = 0; j < n_u; j++) {
348 printf(
"%f ",
B[i][j]);
/* Loop over n_v entries; body not visible in this excerpt. */
354 for (
int j = 0; j < n_v; j++) {
/* Loop over n_u entries; body not visible in this excerpt. */
359 for (
int j = 0; j < n_u; j++) {
/* Lower bounds. */
365 for (
int j = 0; j < n_u; j++) {
366 printf(
"%f ", umin[j]);
/* Upper bounds. */
371 for (
int j = 0; j < n_u; j++) {
372 printf(
"%f ", umax[j]);