
function err = residual_recursive_utility(a0, params, gridsetup)

% RESIDUAL_RECURSIVE_UTILITY  Residuals of the equilibrium conditions of the
% search model with recursive (Epstein-Zin-style) utility.
%
%   err = residual_recursive_utility(a0, params, gridsetup)
%
%   Inputs
%     a0        - stacked coefficient vector of length 2*nn*nx. The first
%                 nn*nx entries are the coefficients of the approximant E,
%                 the remaining nn*nx those of the value function J. Each
%                 block is reshaped to nn-by-nx, one column per x grid point.
%     params    - cell array {kappa0, kappa1, eta, b, s, nu, rhox, stdx,
%                 beta, gamma, psi}. rhox and stdx are unpacked but not
%                 used in this function.
%     gridsetup - cell array {P, x, Xm, X_p, N, Nm, fspace}:
%                   P      - nx-by-nx matrix; column ix is used as the
%                            weights over x(t+1) given x(t) = x(ix)
%                            (NOTE(review): assumes columns hold conditional
%                            probabilities - confirm against the caller).
%                   x      - grid of the exogenous state (only its length
%                            is used here).
%                   Xm, Nm - X and N at the collocation nodes, conformable
%                            with nn-by-nx arrays.
%                   X_p    - X(t+1), presumably nn-by-nx-by-nx with the
%                            third index the x(t+1) grid point.
%                   N      - collocation nodes in N (length nn).
%                   fspace - function space passed to funeval (presumably
%                            a CompEcon fundefn structure - confirm).
%
%   Output
%     err - column vector [eV(:); eJ(:)] of length 2*nn*nx, where
%           eV = E - E_t[ M(t+1) * payoff(t+1) ]
%           eJ = J^rho - ( (1-beta)*C^rho + beta*RJ^rho )


[kappa0, kappa1, eta, b, s, nu, rhox, stdx, beta, gamma, psi] = params{:};
[P, x, Xm, X_p, N, Nm, fspace] = gridsetup{:};

nn      = length(N);
nx      = length(x);
Nmin    = min(N);
Nmax    = max(N);

alpha   = 1 - gamma;       % exponent of the certainty equivalent
rho     = 1 - 1/psi;       % exponent of the time aggregator

% Unstack the coefficients: one column per x grid point.
aE      = reshape(a0(1 : nn*nx), [nn nx]);
aJ      = reshape(a0(nn*nx + 1 : end), [nn nx]);

% Current-period values at the collocation nodes.
E       = funeval(aE, fspace, N);
J       = funeval(aJ, fspace, N);

% Fill rate from E, then tightness by inverting q = (1 + theta^nu)^(-1/nu).
q       = vacancy_fill_rate(E, kappa0, kappa1);
theta   = ( q.^(-nu) - 1 ).^(1/nu);

V       = theta.*(1 - Nm);
kappat  = kappa0 + kappa1*q;
C       = Xm.*Nm - kappat.*V;

% Next-period N, kept within the range of the nodes N (presumably so the
% approximant is never extrapolated).
N_p     = (1 - s)*Nm + q.*V;
N_p     = min( Nmax, max(Nmin, N_p) );
N_pcol  = reshape(N_p, [nn*nx 1]);    % evaluation points, loop-invariant

% E and J at (N_p, x(t+1) = x(ip)); the third index is ip.
E_p     = zeros(nn, nx, nx);
J_p     = zeros(nn, nx, nx);
for ip = 1 : nx
    E_p(:, :, ip) = reshape(funeval(aE(:, ip), fspace, N_pcol), [nn nx]);
    J_p(:, :, ip) = reshape(funeval(aJ(:, ip), fspace, N_pcol), [nn nx]);
end

% Certainty equivalent RJ = ( E_t[J_p^alpha] )^(1/alpha). reshape (rather
% than squeeze) keeps the nn-by-nx orientation even when nn == 1.
RJ      = zeros(nn, nx);
for ix = 1 : nx
    Jnext     = reshape(J_p(:, ix, :), [nn nx]);
    RJ(:, ix) = ( (Jnext.^alpha) * P(:, ix) ).^(1/alpha);
end

% Residual of the utility recursion
% J = ( (1-beta)*C^rho + beta*RJ^rho )^(1/rho), written in J^rho form.
eJ      = J.^rho - ( (1-beta)*C.^rho + beta*RJ.^rho );

% Next-period fill rate, tightness and vacancy cost.
q_p      = vacancy_fill_rate(E_p, kappa0, kappa1);
theta_p  = ( q_p.^(-nu) - 1 ).^(1/nu);
kappat_p = kappa0 + kappa1*q_p;

% Where q_p was clipped to 1 (theta_p = 0, so no vacancies), lambda_p is
% set so that kappa0/q_p + kappa1 - lambda_p = E_p; presumably this is the
% multiplier on V >= 0. It is zero everywhere else.
lambda_p = zeros(nn, nx, nx);
corner   = (q_p == 1);
lambda_p(corner) = kappa0 + kappa1 - E_p(corner);

% N_p, RJ and C do not depend on x(t+1): replicate along the third dim.
N_p     = repmat(N_p, [1 1 nx]);
RJ      = repmat(RJ, [1 1 nx]);
C       = repmat(C, [1 1 nx]);

U_p     = 1 - N_p;
V_p     = theta_p.*U_p;        % theta_p is already 0 where q_p == 1
C_p     = X_p.*N_p - kappat_p.*V_p;

% Wage next period.
W_p      = eta*(X_p + kappat_p.*theta_p) + (1 - eta)*b;
% Stochastic discount factor implied by the recursive utility.
M_p      = beta*(C_p./C).^(-1/psi) .* (J_p./RJ).^(1/psi - gamma);
inside_p = M_p.*( X_p - W_p + (1 - s)*(kappa0./q_p + kappa1 - lambda_p) );

% Conditional expectation over x(t+1), same weighting as for RJ.
rhs     = zeros(nn, nx);
for ix = 1 : nx
    rhs(:, ix) = reshape(inside_p(:, ix, :), [nn nx]) * P(:, ix);
end

eV      = E - rhs;
err     = [eV(:); eJ(:)];



function q = vacancy_fill_rate(Ev, kappa0, kappa1)
% VACANCY_FILL_RATE  Element-wise inversion of Ev = kappa0/q + kappa1.
%   q is capped at 1. Negative values of q (which arise when Ev < kappa1
%   for kappa0 > 0) are also set to 1, so theta = 0 there.
q        = kappa0 ./ (Ev - kappa1);
q        = min(q, 1);
q(q < 0) = 1;


