function particleclaw(par)
%PARTICLECLAW Solve conservation law using particles.
%   A numerical scheme for scalar conservation and balance laws
%       u_t+(f(u))_x = g(x,u)
%   in one space dimension.
%   It is a characteristic particle method that is exactly
%   conservative, total variation diminishing, entropy decreasing,
%   and has no numerical dissipation away from shocks.
%   PARTICLECLAW(PAR) needs to be called with a set of parameters
%   PAR, provided by a pclaw_ex file.
%
%   Version 1.0
%   Copyright (c) 2008 Benjamin Seibold and Yossi Farjoun
%   http://math.mit.edu/~seibold/research/particleclaw
%   http://arxiv.org/abs/0809.0726

%===============================================================================
% Copyright (c) 2008 Benjamin Seibold and Yossi Farjoun
%
% Permission is hereby granted, free of charge, to any person obtaining a copy
% of this software and associated documentation files (the "Software"), to deal
% in the Software without restriction for non-commercial purposes, including
% without limitation the rights to use, copy, modify, merge, publish, and/or
% distribute copies of the Software, and to permit persons to whom the
% Software is furnished to do so, subject to the following conditions:
%
% The above copyright notice and this permission notice shall be included in
% all copies or substantial portions of the Software, and credit has to be
% given to the authors in publications that are in any form based on this
% Software.
%
% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
% OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
% THE SOFTWARE.
%===============================================================================

%-------------------------------------------------------------------------------
% initialize parameters
% Fields of PAR used without a default: xbox, d, ic, f, f1, f2, tfinal, steps.
% Optional fields (defaults below): name, u_ip, g, tstart, dtmax, odestep,
% flag_plot, flag_sharpen, flag_save, and ubox (plot range; defaults to the
% range of the initial data).
try, par.name; catch, par.name = ''; end
try, par.u_ip; catch, par.u_ip = []; end
try, par.g; catch, par.g = 'none'; end
par.flag_source = isa(par.g,'function_handle');
if length(par.d)==2, par.d = par.d([1 2 2]); end
try, par.tstart; catch, par.tstart = 0; end
try, par.dtmax; catch, par.dtmax = inf; end
% a user-supplied ODE stepper always enables the source-term code path
try, par.odestep; par.flag_source = 1; catch, par.odestep = @odestep; end
try, par.flag_plot; catch, par.flag_plot = 0; end
try, par.flag_sharpen; catch, par.flag_sharpen = 0; end
try, par.flag_save; catch, par.flag_save = 0; end
%-------------------------------------------------------------------------------
% initialize function and particles
[x,u] = sample_ic(par);
[x,u] = make_correct_data(par,x,u);
s = x*0;
x0 = x; u0 = u; [x0c,u0c] = curve(par,x0,u0);
% default vertical plot range: range of the initial data (widened when the
% data is constant, since axis limits must be increasing)
try
    par.ubox;
catch
    par.ubox = [min(u0c) max(u0c)];
    if ~(par.ubox(2)>par.ubox(1)), par.ubox = par.ubox+[-.5 .5]; end
end
dt = par.tfinal/par.steps;
%-------------------------------------------------------------------------------
% loop over number of steps
for k = 0:par.steps
    if k>0, [x,u,s] = evolve(par,x,u,s,dt); end
    if par.flag_sharpen
        [xp,up] = sharpen(par,x,u,s);
    else
        xp = x; up = u;
    end
    t = par.tstart+k*dt;
    if par.flag_plot % plot results
        clf
        set(gcf,'DoubleBuffer','on');
        if ~isempty(par.u_ip) % horizontal lines at the inflection values
            plot(par.xbox(1:2),[1;1]*par.u_ip(:)','k:')
        end
        hold on
        plot(x0,u0,'b.',x0c,u0c,'b:')
        [xpc,upc] = curve(par,xp,up); plot(xp,up,'r.',xpc,upc,'r-')
        if not(par.flag_sharpen), plot(x(s~=0),u(s~=0),'ro'), end
        hold off
        title(sprintf('%s  t=%0.2f',par.name,t))
        axis([par.xbox(1:2) par.ubox(1:2)])
        drawnow
    end
    if par.flag_save % save results
        filename = sprintf('pclaw.t%04d',k);
        fid = fopen(filename,'w');
        if fid<0, error('particleclaw:fopen','Cannot open file %s.',filename), end
        fprintf(fid,'%16.8E    time\n',t);
        fclose(fid);
        filename = sprintf('pclaw.q%04d',k);
        fid = fopen(filename,'w');
        if fid<0, error('particleclaw:fopen','Cannot open file %s.',filename), end
        % one line "position value" per particle (fprintf cycles columnwise)
        fprintf(fid,'%16.8E%18.8E\n',[xp(:) up(:)]');
        fclose(fid);
    end
end
%-------------------------------------------------------------------------------
%-------------------------------------------------------------------------------

function [x,u,s] = evolve(par,x,u,s,t)
% Forward time evolution of the numerical solution.
% input:  x,u,s  positions, values and merge-labels of points
%         t      time span to evolve
% output: x,u,s  positions, values and merge-labels after time span t
% Each substep is bounded by the first collision of approaching neighbors,
% by the time until an outermost particle has travelled the distance
% par.d(2), by par.dtmax and by the remaining time. Particle management is
% performed only after substeps whose length equals that collision bound.
while t>1e-15
    f1u = feval(par.f1,u); % characteristic speeds = particle velocities
    df1 = diff(f1u);
    ti = -diff(x)./df1; % collision time for neighboring particles
    % boundary exposure time: particles move with speed f'(u) (see the
    % update x+dt*f1u below), so the speeds, not the values u, matter here
    tb = par.d(2)./[f1u(1);-f1u(end)];
    dt0 = min([ti(df1<0);tb(tb>0);Inf]); % shortest particle collision time
    dt0 = max(dt0,1e-12); % minimum time step for unprecise ODE solvers
    dt = min([dt0,par.dtmax,t]);
    if par.flag_source % a source term exists
        while 1 % if particles overtake too much, redo step with smaller dt
            [xn,un] = feval(par.odestep,par,x,u,dt);
            if all(xn(3:end)-x(1:end-2)>0), break, end
            dt = dt*.5;
        end
        x = xn; u = un;
    else % no source term exists
        x = x+dt*f1u;
    end
    t = t-dt;
    % insert points at boundaries if necessary
    while x(1)>par.xbox(1)
        x = [x(1)-par.d(2);x]; u = u([1 1:end]); s = [0;s]; % Neumann b.c.
    end
    while x(end)<par.xbox(2)
        x = [x;x(end)+par.d(2)]; u = u([1:end end]); s = [s;0]; % Neumann b.c.
    end
    % remove points outside boundaries
    i_in = x>=par.xbox(1)-par.d(2)&x<=par.xbox(2)+par.d(2);
    x = x(i_in); u = u(i_in); s = s(i_in);
    % perform particle management, if necessary
    if dt==dt0, [x,u,s] = particle_management(par,x,u,s); end
end

function [xn,un] = odestep(par,x,u,dt)
% ODESTEP  One explicit trapezoidal (Heun) step of the characteristic ODEs
%   dx/dt = f'(u),   du/dt = g(x,u).
% input:  x,u    positions and values of points
%         dt     time step
% output: xn,un  positions and values after time step
% slopes at the beginning of the step
vx_a = feval(par.f1,u);
vu_a = feval(par.g,x,u);
% forward Euler predictor
x_pred = x+dt*vx_a;
u_pred = u+dt*vu_a;
% slopes at the predicted end point
vx_b = feval(par.f1,u_pred);
vu_b = feval(par.g,x_pred,u_pred);
% corrector: advance with the average of both slopes
xn = x+dt*(vx_a+vx_b)/2;
un = u+dt*(vu_a+vu_b)/2;

%-------------------------------------------------------------------------------
% Particle management
%-------------------------------------------------------------------------------
function [x,u,s] = particle_management(par,x,u,s)
% Particle management at frozen time: inserting and merging.
% input:  x,u,s  positions, values and merge-labels of points
% output: x,u,s  positions, values and merge-labels after management
% First a point is inserted into every gap larger than par.d(3). Then, as
% long as two approaching neighbors are closer than par.d(1), they are
% merged (at the ends of the particle list a point is appended instead).
% par.u_ip may be empty, a single inflection value, or several values.
uip = par.u_ip(:);
while 1 % insert points
    dist = diff(x);
    [mv,mi] = max(dist);
    if isempty(mv) || abs(mv)<=par.d(3), break, end
    [xn,un] = insert(par,x(mi:mi+1),u(mi:mi+1));
    u = [u(1:mi);un;u(mi+1:end)];
    x = [x(1:mi);xn;x(mi+1:end)];
    s = [s(1:mi);0;s(mi+1:end)];
end
while 1 % merge points
    [mv,mi] = min_dist(par,x,u);
    if isempty(mv) || mv>par.d(1), break, end
    if mi+2>length(x) % merging at right boundary (just add a point)
        x = [x;x(end)+par.d(3)];
        u = u([1:end end]);
        s = [s;0];
    elseif mi-1<1 % merging at left boundary (just add a point)
        x = [x(1)-par.d(3);x];
        u = u([1 1:end]);
        s = [0;s];
    else % actual merging in center
        if any((u(mi)-uip).*(u(mi+2)-uip)<0)
            % an inflection value lies strictly between u(mi) and u(mi+2),
            % i.e. the second shock point mi+1 is expected to be the
            % inflection point (middle of the five points passed on)
            % NOTE(review): reads x(mi+3), which is not checked against
            % length(x) here - confirm this cannot be exceeded
            [xn,un] = merge_inflection(par,x(mi-1:mi+3),u(mi-1:mi+3));
            x = [x(1:mi-1);xn;x(mi+3:end)];
            u = [u(1:mi-1);un;u(mi+3:end)];
            s = [s(1:mi-1);1;0;s(mi+3:end)];
        elseif any((u(mi-1)-uip).*(u(mi+1)-uip)<0)
            % first shock point mi is expected to be the inflection point;
            % the five points are passed in reversed order
            % NOTE(review): reads x(mi-2), which is not checked against 1
            [xn,un] = merge_inflection(par,x(mi+2:-1:mi-2),u(mi+2:-1:mi-2));
            x = [x(1:mi-2);xn([2 1]);x(mi+2:end)];
            u = [u(1:mi-2);un([2 1]);u(mi+2:end)];
            s = [s(1:mi-2);0;1;s(mi+2:end)];
        else % no inflection point on shock
            [xn,un,isp] = merge(par,x(mi-1:mi+2),u(mi-1:mi+2));
            sn = (1:length(xn))'==isp; % label the merged (shock) particle
            x = [x(1:mi-1);xn;x(mi+2:end)];
            u = [u(1:mi-1);un;u(mi+2:end)];
            s = [s(1:mi-1);sn;s(mi+2:end)];
        end
    end
end

function [mv,mi] = min_dist(par,x,u)
% MIN_DIST  Smallest gap between neighbors that move towards each other.
% A pair (k,k+1) approaches when f'(u(k+1)) < f'(u(k)); all other gaps
% are treated as infinite.
% input:  x,u    positions and values of all points
% output: mv,mi  smallest such gap and the index of its left point
%                (mv is inf when no pair approaches)
gap = diff(x);
speed_jump = diff(feval(par.f1,u));
gap(speed_jump>=0) = inf;
[mv,mi] = min(gap);

function [xn,un] = insert(par,x,u)
% INSERT  Create one new particle between two neighbors.
% The new value is the mean of the two values. Its position is taken from
% the conservative interpolant between the two points (interp_loc_inv).
% input:  x,u    positions and values of two points
% output: xn,un  position and value of inserted point
un = 0.5*(u(1)+u(2));
xn = interp_loc_inv(par,x,u,un);

function [xn,un,isp] = merge(par,x,u)
% Merge two particles (none an inflection point) to one,
% using the points before and behind for preserving area.
% input:  x,u    positions and values of four points
% output: xn,un  position and value of one middle point
%                or more if required by entropy fix
%         isp    index of merged (shock) particle
% Points 2 and 3 are the ones merged; particle_management replaces them by
% the returned points and keeps points 1 and 4.
% area under the interpolant over the three gaps (flux-weighted averages)
Aold = diff(x)'*a(par,u(1:3),u(2:4));
xn = (x(2)+x(3))/2; % merged point is placed midway between points 2 and 3
% value of the merged point such that the area over [x(1),x(4)] stays Aold;
% initial guess: mean of u(2),u(3); bracket: range of all four values
un = middle_point(par,[x(1);xn;x(4)],u([1 4]),Aold,...
     (u(2)+u(3))/2,[min(u) max(u)]);
isp = 1; % index of shock point
% flag if the merged value lies strictly between u(1) and u(2) (left) or
% strictly between u(3) and u(4) (right)
flag_fix = [(un-u(1))*(un-u(2))<0,(un-u(3))*(un-u(4))<0]; % check entropy
if any(flag_fix) % entropy fix
    if flag_fix(1) % insert point left of shock
        % move outer point 1 onto the interpolant at the mean of u(1),u(2)
        uadd = (u(1)+u(2))/2;
        x(1) = interp_loc_inv(par,x(1:2),u(1:2),uadd);
        u(1) = uadd;
    end
    if flag_fix(2) % insert point right of shock
        % move outer point 4 onto the interpolant at the mean of u(3),u(4)
        uadd = (u(3)+u(4))/2;
        x(4) = interp_loc_inv(par,x(3:4),u(3:4),uadd);
        u(4) = uadd;
    end
    % merge again (recursively) with the modified outer point(s)
    [xn_e,un_e,isp] = merge(par,x,u);
    % surround the result by the modified outer point(s); x(1,false) is
    % empty, so only the flagged points are included
    xn = [x(1,flag_fix(1));xn_e;x(4,flag_fix(2))];
    un = [u(1,flag_fix(1));un_e;u(4,flag_fix(2))];
    % shift the shock index by one if a point was prepended
    isp = isp+length(x(1,flag_fix(1)));
end

function [xn,un] = merge_inflection(par,x,u)
% Merge three particles (middle one inflection point) to two,
% using the points before and behind for preserving area.
% input:  x,u    positions and values of five points
%                (may be passed in decreasing x order, see 'order')
% output: xn,un  positions and values of two middle points
% Three cases are tried in turn: (1) remove point 2 and move point 3;
% (2) remove point 2 and move points 3 and 4 together; (3) remove point 4,
% move point 3 onto x(5) and lower point 2.
d = diff(x);
order = sign(x(end)-x(1)); % +1 for increasing x, -1 for reversed input
av = a(par,u(1:4),u(2:5)); % flux-weighted averages on the four gaps
Aold = d(1:3)'*av(1:3); % area over points 1..4
a13 = a(par,u(1),u(3)); % average on the gap after removing point 2
% position of point 3 that keeps the area over points 1..4 when point 2
% is removed: (x3bar-x(1))*a13+(x(4)-x3bar)*av(3) = Aold
x3bar = (Aold+x(1)*a13-x(4)*av(3))/(a13-av(3));
if x3bar*order<x(4)*order % removing point 2, moving point 3
    xn = [x3bar;x(4)]; un = [u(3);u(4)];
    return
end
% point 3 would pass point 4: include gap 4 and solve over points 1..5
% with points 3 and 4 both placed at x3bar
Aold = Aold+d(4)*av(4);
x3bar = (Aold+x(1)*a13-x(5)*av(4))/(a13-av(4));
if x3bar*order<x(5)*order % removing point 2, moving points 3 and 4
    xn = [x3bar;x3bar]; un = [u(3);u(4)];
    return
end
% removing point 4, moving point 3, lowering point 2
% point 2 goes midway between old points 2 and 3; its value is chosen so
% that the area over [x(1),x(5)] stays Aold, with value u(3) at x(5)
x23 = (x(2)+x(3))/2;
u2bar = middle_point(par,[x(1);x23;x(5)],u([1 3]),Aold,...
        (u(2)+u(3))/2,[min(u) max(u)]);
xn = [x23;x(5)]; un = [u2bar;u(3)];

function ubar = middle_point(par,x,u,Aold,ubar,u_range)
% For three positions and left and right value given, find middle
% value to generate required area.
% input:  x        positions of three points
%         u        values of left and right point
%         Aold     area to be achieved
%         ubar     initial guess
%         u_range  two values with the correct value in between
% output: ubar  value of middle point
% Newton's method on F(ubar) = area-Aold, safeguarded by bisection on
% u_range. Once the bisection bracket has shrunk to rounding level and
% Newton still misses the tolerance, the bracket midpoint is returned, so
% the iteration always terminates.
d = diff(x);
while 1
    for k = 1:20 % Newton iteration
        F = d'*a(par,ubar*[1;1],u)-Aold;
        if abs(F)<1e-10, return, end
        F1 = d'*a1(par,ubar*[1;1],u);
        du = F./F1;
        ubar = ubar-du;
    end
    % Newton failed; give up refining if the bracket cannot shrink anymore
    if abs(u_range(2)-u_range(1))<=4*eps(max(abs(u_range)))
        ubar = (u_range(1)+u_range(2))/2;
        return
    end
    % if Newton iteration failed, do some bisection steps
    % a_range holds F at both bracket ends
    a_range = d'*a(par,[1;1]*u_range,u*[1 1])-Aold;
    for k = 1:5
        ubar = (u_range(1)+u_range(2))/2;
        abar = d'*a(par,ubar*[1;1],u)-Aold;
        % replace the end whose F has the same sign as F(ubar)
        inew = 1+(abar*a_range(1)<0);
        u_range(inew) = ubar; a_range(inew) = abar;
    end
    ubar = (u_range(1)+u_range(2))/2;
end

%-------------------------------------------------------------------------------
% Average function
%-------------------------------------------------------------------------------
function y = a(par,u,v)
% A  Flux weighted average of two values, evaluated elementwise:
%   y = ([u*f'(u)-f(u)]-[v*f'(v)-f(v)]) / (f'(u)-f'(v)).
% Where |f'(u)-f'(v)| < 1e-12 the value u is returned instead.
% input:  u,v  two values (elementwise; same-size arrays assumed)
% output:      average value
flux_u = feval(par.f,u);
flux_v = feval(par.f,v);
speed_u = feval(par.f1,u);
speed_v = feval(par.f1,v);
numer = (speed_u.*u-speed_v.*v)-(flux_u-flux_v);
denom = speed_u-speed_v;
tiny = abs(denom)<1e-12;
denom(tiny) = eps; % avoid division by zero; overwritten below anyway
y = numer./denom;
y(tiny) = u(tiny);

function y = a1(par,u,v)
% A1  Derivative of the flux weighted average a(u,v) w.r.t. its first
% argument, evaluated elementwise:
%   da/du = f''(u)*([f(u)-f(v)]-f'(v)*(u-v)) / (f'(u)-f'(v))^2.
% Where |f'(u)-f'(v)| < 1e-12 the value 1/2 is returned instead.
% input:  u,v  two values
% output:      da/du
flux_u = feval(par.f,u);
flux_v = feval(par.f,v);
speed_u = feval(par.f1,u);
speed_v = feval(par.f1,v);
curv_u = feval(par.f2,u);
denom = speed_u-speed_v;
tiny = abs(denom)<1e-12;
denom(tiny) = eps; % avoid division by zero; overwritten below anyway
y = curv_u.*((flux_u-flux_v)-speed_v.*(u-v))./denom.^2;
y(tiny) = .5;

%-------------------------------------------------------------------------------
% Interpolation
%-------------------------------------------------------------------------------
function [xc,uc] = curve(par,x,u)
% CURVE  Sample the conservative interpolant through all points.
% Each gap between consecutive points is split into m pieces with values
% equispaced in u. The corresponding positions come from interp_loc_inv.
% The points themselves sit at indices 1, m+1, 2m+1, ... of the output.
% input:  x,u    positions and values of points (column vectors expected)
% output: xc,uc  interpolation function (1+m*(length(x)-1) samples)
m = 10; % refinement factor
% prefill every slot of gap k with the right endpoint x(k+1) / u(k+1);
% the m-1 interior slots are overwritten in the loop
xc = [x(1);reshape(ones(m,1)*x(2:end)',[],1)];
uc = [u(1);reshape(ones(m,1)*u(2:end)',[],1)];
for k = 1:length(x)-1
    uk = linspace(u(k),u(k+1),m+1)';
    uk = uk(2:m); % interior values only
    slots = m*(k-1)+(2:m);
    xc(slots) = interp_loc_inv(par,x([k k+1]),u([k k+1]),uk);
    uc(slots) = uk;
end

function xi = interp_loc_inv(par,x,u,ui)
% INTERP_LOC_INV  Positions at which the local interpolant between two
% points attains the given values.
% Normally the position is linear in the characteristic speed f'(u). If
% the two speeds (nearly) coincide, or f'' changes sign between the two
% values (inflection point in between), equispaced positions are returned.
% input:  x,u  positions and values of two points
%         ui   vector of values in between
% output: xi   vector of positions in between
speed_l = feval(par.f1,u(1));
speed_r = feval(par.f1,u(2));
den = speed_r-speed_l;
equispaced = abs(den)<1e-12;
if ~equispaced
    equispaced = feval(par.f2,u(1))*feval(par.f2,u(2))<0;
end
if equispaced
    xs = linspace(x(1),x(2),length(ui)+2)';
    xi = xs(2:end-1);
else
    xi = x(1)+(x(2)-x(1))*(feval(par.f1,ui)-speed_l)/den;
end

%-------------------------------------------------------------------------------
% Preprocessing and postprocessing
%-------------------------------------------------------------------------------
function [x,u] = sample_ic(par)
% Generate initial positions and values, such that points are very close
% near discontinuities.
% input:       only par
% output: x,u  positions and values of initial condition
% Uses par.xbox, par.d(2) and the initial-condition function par.ic.
% Outline: (1) sample uniformly on [xbox(1)-d(2), xbox(2)+d(2)];
% (2) repeatedly bisect the interval with the largest jump in u until the
%     recent maximal slopes stop roughly doubling (intervals narrower than
%     1e-10 are collapsed into a double point);
% (3) if that stopped before length(slope) iterations, the data is treated
%     as continuous and a uniform sampling is returned;
% (4) otherwise the extra points created in (2) are thinned out again.
x = linspace(par.xbox(1)-par.d(2),par.xbox(2)+par.d(2),...
    ceil(diff(par.xbox)/par.d(2)+4))'; % sample slightly finer than dmax
u = feval(par.ic,x);
% refine points into discontinuities
mv = inf; % NOTE(review): never read; mv is reassigned in the loop below
slope = 2.^(1:20)'; % history of recent maximal slopes (initial ratios 2)
for k = 1:1e5
    dx = diff(x);
    du = abs(diff(u)).*(dx>0); % ignore jumps across double points (dx==0)
    [mv,mi] = max(du);
    sl = mv./dx(mi);
    if isnan(sl)|sl==0, sl = eps; end
    slope = [slope(2:end);sl];
    % no more shocks, if .7 of latest slopes have increased by less than 1.8
    if sum(slope(2:end)./slope(1:end-1)<1.8)>(length(slope)-1)*.7, break, end
    if dx(mi)<1e-10
        % interval too small to bisect: move both endpoints onto its
        % midpoint, keeping both values, so that a double point forms
        x(mi:mi+1) = [1;1]*(x(mi)+x(mi+1))/2;
    else
        % bisect the interval with the largest jump
        xm = (x(mi)+x(mi+1))/2;
        x = [x(1:mi);xm;x(mi+1:end)];
        u = [u(1:mi);feval(par.ic,xm);u(mi+1:end)];
    end
end
if k<length(slope) % there were no discontinuities
    x = linspace(par.xbox(1)-par.d(2),par.xbox(2)+par.d(2),...
        ceil(diff(par.xbox)/par.d(2)+3))';
    u = feval(par.ic,x);
    return
end
% remove points around discontinuities
% Gaps with slope above 1e3 (the jumps) are set to inf, so points next to
% them are never removed. Among the remaining points, the one whose
% removal creates the smallest new gap is dropped repeatedly.
% NOTE(review): mv is tested only after the removal, so the last pass
% drops a point whose new gap is already >= .9*d(2) - confirm intended.
mv = 0;
while mv<par.d(2)*.9
    dx = diff(x);
    s = abs(diff(u)./dx);
    dx(s>1e3) = inf;
    d2x = dx(1:end-1)+dx(2:end); % gap left after removing interior point
    [mv,mi] = min(d2x);
    x = [x(1:mi);x(mi+2:end)];
    u = [u(1:mi);u(mi+2:end)];
end

function [x,u] = make_correct_data(par,x,u)
% Insert points on and between inflection points.
% input:  x,u  positions and values of points
% output: x,u  positions and values of points after correction
% Steps: sort the points by position; snap values within 1e-8 of an
% inflection value (par.u_ip, scalar, vector or empty) onto it; insert a
% point wherever the data strictly crosses an inflection value between two
% neighbors; finally add a point between any two neighbors that both lie
% on inflection values.
[x,i] = sort(x); u = u(i);
uip = sort(par.u_ip); uip = uip(:);
% move points that are very close to an inflection point onto it
[i,j] = find(abs(u*ones(size(uip'))-ones(size(u))*uip')<1e-8);
u(i) = uip(j);
% insert additional points on inflection points
U = u*ones(size(uip')); Uip = ones(size(u))*uip';
[i,j] = find(diff(U>Uip)&diff(U<Uip)); % gap i strictly crosses uip(j)
t = (uip(j)-u(i))./(u(i+1)-u(i)); % fractional position inside gap i
x_sp = x(i)+t.*(x(i+1)-x(i));
% Insert from left to right (by gap, then by position inside the gap).
% All k-1 earlier insertions then lie left of the k-th one, so the left
% neighbor of the k-th new point has moved from index i to index i+k-1.
% (Using the unshifted index i misplaces points and corrupts the
% neighbor detection below.)
[dummy,o] = sortrows([i t]);
for k = 1:length(o)
    p = i(o(k))+k-1;
    x = [x(1:p);x_sp(o(k));x(p+1:end)];
    u = [u(1:p);uip(j(o(k)));u(p+1:end)];
end
% insert points between inflection points
[i,j] = find(u*ones(size(uip'))==ones(size(u))*uip');
[i,s] = sort(i); j = j(s);
in = find(diff(i)==1); % neighbors that both lie on inflection values
x = [x;(x(i(in))+x(i(in+1)))/2]; % need to place on curve for sure
u = [u;(uip(j(in))+uip(j(in+1)))/2];
[x,i] = sort(x); u = u(i);

function [x,u,s] = sharpen(par,x,u,s)
% SHARPEN  Postprocessing: turn each merged (shock) particle into a sharp
% jump.
% For every particle k with nonzero label s(k), the area under the
% interpolant between its neighbors k-1 and k+1 is reproduced by a
% piecewise constant profile: value u(k-1) left of xs, value u(k+1) right
% of xs. Particle k moves to xs with the left value, and a second particle
% at xs carries the right value.
% input:  x,u,s  positions, values and merge-labels of points
% output: x,u,s  positions sorted ascending, values, merge-labels (s = x*0)
k = find(s);
kl = k-1;
kr = k+1;
area = (x(k)-x(kl)).*a(par,u(kl),u(k))+(x(kr)-x(k)).*a(par,u(k),u(kr));
xs = (x(kl).*u(kl)-x(kr).*u(kr)+area)./(u(kl)-u(kr));
x(k) = xs;
u(k) = u(kl);
% the right values are read after the update of u(k) above
x = [x;xs];
u = [u;u(kr)];
[x,perm] = sort(x);
u = u(perm);
s = x*0;

