#include <algorithm>
#include <set>
#include <stack>
#include <vector>
#include <cstdlib>
#include <ctime>

#include "Maze.h"
#include "DisjointSets.h"

using namespace std;

//
// Maze
//

// Builds a w_ x h_ maze and seeds the C random number generator.
//
// Throws IllegalSizeException unless 0 < w_ < MaxWidth and 0 < h_ < MaxHeight.
// A seed_ of 0 means "pick one from the current time"; the value actually
// used is kept in `seed`. Only the cell storage is allocated here; no wall
// or flag values are assigned to the cells.
Maze::Maze( int w_, int h_, unsigned int seed_ )
   : maze(0), w(w_), h(h_), seed(seed_), start(0), goal(0)
{
   const bool widthOk = (w_ > 0) && (w_ < MaxWidth);
   const bool heightOk = (h_ > 0) && (h_ < MaxHeight);
   if (!(widthOk && heightOk))
      throw IllegalSizeException();

   maze = new Cell[w_*h_];

   // Zero is the "no seed given" marker: fall back to the wall clock.
   if (seed == 0)
      seed = static_cast<unsigned int>(std::time(0));

   // std::srand affects every later std::rand() call in the process.
   std::srand(seed);
}


// Copy constructor: makes an independent deep copy of another maze.
//
// The cell array is duplicated, and the start/goal pointers (when set in the
// source) are re-based so they refer to the matching cells of the new array
// rather than into m's storage.
//
// NOTE(review): the solved-path list (`path`) is left default-constructed
// here; call solve() on the copy to rebuild it. Confirm whether callers
// expect the path to be carried over.
Maze::Maze( const Maze& m )
   : maze(0), w(m.w), h(m.h), seed(m.seed), start(0), goal(0)
{
   // Nothing to duplicate if the source has no cell storage.
   if (!m.maze)
      return;

   // The array must exist before cells can be copied into it.
   maze = new Cell[w*h];
   std::copy(m.maze, m.maze + w*h, maze);

   // Translate the source's cell pointers into offsets in our own array.
   if (m.start)
      start = maze + (m.start - m.maze);
   if (m.goal)
      goal = maze + (m.goal - m.maze);
}


// Releases the cell array allocated by the constructor.
// delete[] on a null pointer is a no-op, so a maze without storage is fine.
Maze::~Maze()
{
   delete [] this->maze;
}


// Sets the start cell to (sx,sy) and the goal cell to (gx,gy), then runs the
// solver.
//
// Throws IllegalCellRefException if the maze has no cell storage or either
// coordinate pair lies outside the w x h grid.
void Maze::solve( int sx, int sy, int gx, int gy )
{
   const bool startInside = (sx >= 0) && (sx < w) && (sy >= 0) && (sy < h);
   const bool goalInside  = (gx >= 0) && (gx < w) && (gy >= 0) && (gy < h);

   if (maze == 0 || !startInside || !goalInside)
      throw IllegalCellRefException();

   // Cells are stored row by row, so (x,y) lives at index y*w + x.
   start = &maze[sy*w + sx];
   goal  = &maze[gy*w + gx];

   solve();
}


void Maze::init()
{
   Cell initial = {1,1,1,1,0,0};

   Cell *m = maze;
   for (int j=0; j<h; j++)
      for (int i=0; i<w; i++,m++)
	 *m = initial;
}


// Carves passages by growing a single tree of visited cells from a random
// starting cell.
//
// Each round picks a random cell from the frontier list (cells already in the
// tree), picks one of its not-yet-visited neighbours at random, removes the
// wall between the two, and adds the neighbour to the frontier. Cells with no
// unvisited neighbour are dropped from the frontier. Despite the function's
// name, the frontier cell is chosen at random rather than most-recent-first.
//
// Relies on every cell's `visited` flag being 0 on entry and uses std::rand(),
// so the result depends on the seed given to std::srand().
void Maze::generateAlgoDepthSearch()
{
   int current = std::rand()%(w*h);

   // The starting cell belongs to the tree from the outset. Leaving it
   // unmarked would let a neighbour "discover" it again later and open a
   // second passage into it, which creates a loop in the maze.
   maze[current].visited = 1;
   int visited = 1;

   vector<int> v;
   v.push_back(current);

   // The frontier check keeps rand()%v.size() from dividing by zero when the
   // frontier runs dry before every cell is counted (e.g. stale visited flags).
   while (visited<w*h && !v.empty())
   {
      int i = std::rand()%v.size();
      current = v[i];

      // Neighbour indices in the order north, east, south, west;
      // -1 marks a neighbour that would fall outside the grid.
      int walls[4];
      walls[0] = current-w>=0 ? current-w : -1;
      walls[1] = (current+1)/w==current/w ? current+1 : -1;
      walls[2] = current+w<w*h ? current+w : -1;
      walls[3] = current-1>=0&&(current-1)/w==current/w ? current-1 : -1;

      // Collect the directions that lead to an unvisited cell.
      int down[4];
      int count = 0;
      for (int j=0; j<4; j++)
         if (walls[j]!=-1 && !maze[walls[j]].visited)
            down[count++] = j;

      if (count==0)
      {
         // Nothing left to carve from this cell: retire it from the frontier.
         v.erase(v.begin()+i);
         continue;
      }

      int wall = down[std::rand()%count];
      int next = walls[wall];
      maze[next].visited = 1;
      visited++;

      // Remove the shared wall from both sides.
      if (wall==0)
         maze[current].north = maze[next].south = 0;
      else if (wall==1)
         maze[current].east = maze[next].west = 0;
      else if (wall==2)
         maze[current].south = maze[next].north = 0;
      else if (wall==3)
         maze[current].west = maze[next].east = 0;

      v.push_back(next);
   }
}


void Maze::generateAlgoDisjoint()
{
   DisjointSets<int> s;
   vector<int> v;

   int cell = 0;
   for (int y=0; y<h; y++)
   {
      for (int x=0; x<w; x++,cell++)
      {
	 s.insert_elem(cell);
	 v.push_back(cell);
      }
   }

   while (v.size()>0)
   {
      vector<int>::iterator i = v.begin()+std::rand()%v.size();
      int c = *i;
      v.erase(i);
      
      if ((c+1)/w==c/w && maze[c].east && s.find_elem(c)!=s.find_elem(c+1))
      {
	 maze[c].east = maze[c+1].west = 0;
	 s.union_elem(c,c+1);
      }

      if (c+w<w*h && maze[c].south && s.find_elem(c)!=s.find_elem(c+w))
      {
	 maze[c].south = maze[c+w].north = 0;
	 s.union_elem(c,c+w);
      }
   }
}


// Runs the maze generator selected by `algo`.
//
// Throws IllegalAlgoException for any value other than DepthSearch or
// Disjoint.
void Maze::generate( GenAlgo algo )
{
   if (algo == DepthSearch)
   {
      generateAlgoDepthSearch();
      return;
   }

   if (algo == Disjoint)
   {
      generateAlgoDisjoint();
      return;
   }

   throw IllegalAlgoException();
}


void Maze::solve()
{
   Cell *m = maze;
   for (int j=0; j<h; j++)
      for (int i=0; i<w; i++,m++)
	 m->path = m->visited = 0;

   stack<Cell *> s;
   Cell *mp = start;
   mp->visited = true;
   s.push(mp);
   for (;;)
   {
      if (!mp->north && !(mp-w)->visited)
	 s.push(mp-w);
      if (!mp->east && !(mp+1)->visited)
	 s.push(mp+1);
      if (!mp->south && !(mp+w)->visited)
	 s.push(mp+w);
      if (!mp->west && !(mp-1)->visited)
	 s.push(mp-1);
      if (s.top()==mp)
	 s.pop();
      mp = s.top();
      mp->visited = 1;
      if (mp==goal)
	 break;
   }

   path.clear();
   while (!s.empty())
   {
      m = s.top(); 
      s.pop();
      if ((m->path=m->visited)!=0)
	 path.push_back(m);
   }
}


//
// DrawableMaze
//

void DrawableMaze::drawMaze()
{
   for (int i=0; i<getW(); i++)
      if (cellAt(i,0).north)
	 this->drawWallEvent(i,0,i+1,0);

   for (int j=0; j<getH(); j++)
   {
      if (cellAt(0,j).west)
	 this->drawWallEvent(0,j,0,j+1);

      for (int i=0; i<getW(); i++)
      {
	 this->drawCellEvent(i,j,i+1,j+1);

	 Cell c = cellAt(i,j);
	 if (c.south)
	    this->drawWallEvent(i,j+1,i+1,j+1);
	 if (c.east)
	    this->drawWallEvent(i+1,j,i+1,j+1);
      }
   }
   
   int n = nPathCoords();
   if (n>0)
   {
      CellCoords c1 = pathCoords(0);
      for (int i=1; i<n; i++)
      {
	 CellCoords c2 = pathCoords(i);
	 this->drawPathEvent(c1.x,c1.y,c2.x,c2.y);
	 c1 = c2;
      }
   }

   if (start)
      this->drawStartEvent((start-maze)%w,(start-maze)/w);

   if (goal)
      this->drawGoalEvent((goal-maze)%w,(goal-maze)/w);
}
