#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <limits.h>
#include <unistd.h>

#ifdef USE_NLLIBC
#include <nllibc.h>
#endif

#include "llllib.h"
#include "readline.h"
#include "string.h"
#include "lll.h"

/*
 * Widening helpers, called by load() right after a narrow load into the
 * low part of %rax.  Each emits a single instruction that leaves the
 * full 64-bit value in %rax.  Despite the names, the unsigned variants
 * zero-extend: writing a 32-bit register also clears the upper half of
 * the corresponding 64-bit register.
 */

static void
sign_extension_char(FILE *out)
{
  /* signed byte -> quad */
  fprintf(out, "\tmovsbq\t%%al, %%rax\n");
}

static void
sign_extension_uchar(FILE *out)
{
  /* unsigned byte -> long, upper half cleared implicitly */
  fprintf(out, "\tmovzbl\t%%al, %%eax\n");
}

static void
sign_extension_short(FILE *out)
{
  /* signed word -> quad */
  fprintf(out, "\tmovswq\t%%ax, %%rax\n");
}

static void
sign_extension_ushort(FILE *out)
{
  /* unsigned word -> long, upper half cleared implicitly */
  fprintf(out, "\tmovzwl\t%%ax, %%eax\n");
}

static void
sign_extension_int(FILE *out)
{
  /* signed long -> quad */
  fprintf(out, "\tmovslq\t%%eax, %%rax\n");
}

static void
sign_extension_uint(FILE *out)
{
  /* 32-bit self-move: clears bits 32..63 of %rax */
  fprintf(out, "\tmov\t%%eax, %%eax\n");
}

/*
 * Code-generation parameters per element type, looked up by name with
 * regtype_get() (LOAD and STORE use the type named in their argv[1]).
 *
 *   type       type keyword; NULL marks the last entry
 *   size       element size in bytes; scales indices and word offsets
 *   opcode     move mnemonic used to load/store one element
 *   reg        part of %rax that holds an element of this size
 *   variable   assembler data directive for an object of this size
 *   extension  emits the widening of a loaded value to 64 bits, or NULL
 *
 * The NULL-typed entry is what regtype_get() returns for a NULL or
 * unknown name; lll_main() also installs it as deftype.
 */
static struct regtype {
  char *type;
  int size;
  char *opcode;
  char *reg;
  char *variable;
  void (*extension)(FILE *out);
} regtypes[] = {
  /* type       size  opcode  reg     variable   extension             */
  { "CHAR"   ,  1,    "mov",  "al" ,  ".byte" ,  sign_extension_char   },
  { "UCHAR"  ,  1,    "mov",  "al" ,  ".byte" ,  sign_extension_uchar  },
  { "SHORT"  ,  2,    "mov",  "ax" ,  ".value",  sign_extension_short  },
  { "USHORT" ,  2,    "mov",  "ax" ,  ".value",  sign_extension_ushort },
  { "INT"    ,  4,    "mov",  "eax",  ".long" ,  sign_extension_int    },
  { "UINT"   ,  4,    "mov",  "eax",  ".long" ,  sign_extension_uint   },
  { "LONG"   ,  8,    "mov",  "rax",  ".quad" ,  NULL                  },
  { "ULONG"  ,  8,    "mov",  "rax",  ".quad" ,  NULL                  },
  { NULL     ,  8,    "mov",  "rax",  ".quad" ,  NULL                  },
};

/* Word type used for scaling offsets and sizing data; set in lll_main(). */
static struct regtype *deftype = NULL;

/*
 * Return the regtypes[] entry whose name equals TYPE (case-sensitive).
 * When TYPE is NULL or matches nothing, the NULL-named sentinel at the
 * end of the table is returned, so the result is never NULL.
 */
static struct regtype *regtype_get(const char *type)
{
  struct regtype *entry = regtypes;

  while (entry->type != NULL) {
    if (type != NULL && strcmp(entry->type, type) == 0)
      return entry;
    entry++;
  }
  return entry;
}

/* Handler for the "#" command: the line is ignored and nothing is emitted. */
static int none(FILE *out, int argc, const char *argv[])
{
  (void)out;
  (void)argc;
  (void)argv;
  return 0;
}

/*
 * REM command: REM "text
 * Copies TEXT into the output as an assembler block comment.  argv[1]
 * must begin with '"' (the opening quote is dropped), otherwise
 * LLL_ERRCODE_INVALID_FORMAT is returned.  Any star-slash pair inside
 * TEXT is split with a space so it cannot end the comment early and
 * leave the remainder of the text to be parsed as assembly.
 */
static int comment(FILE *out, int argc, const char *argv[])
{
  const char *p;

  if (argv[1][0] != '"')
    return LLL_ERRCODE_INVALID_FORMAT;

  fputs("/* ", out);
  for (p = argv[1] + 1; *p; p++) {
    fputc(*p, out);
    if (p[0] == '*' && p[1] == '/')
      fputc(' ', out);
  }
  fputs(" */\n", out);
  return 0;
}

/*
 * ASM command: ASM "text
 * Emits TEXT verbatim as one tab-indented assembler line.  argv[1] must
 * start with '"'; that opening quote is not copied.
 */
static int direct(FILE *out, int argc, const char *argv[])
{
  const char *text = argv[1];

  if (*text != '"')
    return LLL_ERRCODE_INVALID_FORMAT;

  fprintf(out, "\t%s\n", &text[1]);
  return 0;
}

/* LABEL command: define argv[1] as a label at the current position. */
static int label(FILE *out, int argc, const char *argv[])
{
  const char *name = argv[1];

  (void)argc;
  fprintf(out, "%s:\n", name);
  return 0;
}

/*
 * Emit code that loads the operand VAL into register REG.
 *
 * Operand syntax:
 *   "."            pop the top of the machine stack
 *   "_" "$" "%"    the value of %rbx, %rsp or %rbp itself
 *   "_N" "$N" "%N" memory N words (N * deftype->size bytes) from %rbx,
 *                  %rsp or %rbp
 *   'c             immediate, passed to the assembler exactly as written
 *   "text          immediate address of the label string_alloc() returns
 *                  for TEXT
 *   &sym           immediate address of symbol SYM
 *   digits         immediate integer (strtoul, base 0)
 *   anything else  memory at symbol VAL, addressed %rip-relative
 *
 * Returns 0, or LLL_ERRCODE_LESS_MEMORY if string_alloc() fails.
 */
static int rval(FILE *out, const char *val, const char *reg)
{
  const char *label;

  if (!strcmp(val, ".")) {
    fprintf(out, "\tpop\t%%%s\n", reg);
  } else if (!strcmp(val, "_")) {
    fprintf(out, "\tmov\t%%rbx, %%%s\n", reg);
  } else if (!strcmp(val, "$")) {
    fprintf(out, "\tmov\t%%rsp, %%%s\n", reg);
  } else if (!strcmp(val, "%")) {
    fprintf(out, "\tmov\t%%rbp, %%%s\n", reg);
  } else if (val[0] == '_') {
    fprintf(out, "\tmov\t0x%lx(%%rbx), %%%s\n", strtol(&val[1], NULL, 0) * deftype->size, reg);
  } else if (val[0] == '$') {
    fprintf(out, "\tmov\t0x%lx(%%rsp), %%%s\n", strtol(&val[1], NULL, 0) * deftype->size, reg);
  } else if (val[0] == '%') {
    fprintf(out, "\tmov\t0x%lx(%%rbp), %%%s\n", strtol(&val[1], NULL, 0) * deftype->size, reg);
  } else if (val[0] == '\'') {
    fprintf(out, "\tmov\t$%s, %%%s\n", val, reg);
  } else if (val[0] == '"') {
    label = string_alloc(val + 1);
    if (!label)
      return LLL_ERRCODE_LESS_MEMORY;
    fprintf(out, "\tmov\t$%s, %%%s\n", label, reg);
  } else if (val[0] == '&') {
    fprintf(out, "\tmov\t$%s, %%%s\n", val + 1, reg);
  } else if (isdigit((unsigned char)val[0])) {
    /* ctype functions need an unsigned char value.  strtoul keeps 64-bit
       literals above LONG_MAX intact; strtol would clamp them. */
    fprintf(out, "\tmov\t$0x%lx, %%%s\n", strtoul(val, NULL, 0), reg);
  } else {
    fprintf(out, "\tmov\t%s(%%rip), %%%s\n", val, reg);
  }

  return 0;
}

/*
 * Emit code that stores register REG into the destination VAL.
 *
 * Accepts the same operand syntax as rval():
 *   "." pushes REG.
 *   "_", "$" and "%" overwrite %rbx, %rsp or %rbp.
 *   "_N", "$N" and "%N" store to the word at offset N from that register.
 *   A plain name stores to that symbol, addressed %rip-relative.
 * Literal operands ('c, "text, &sym, digits) are not writable.  For them
 * nothing is emitted, so the value is dropped.
 *
 * Always returns 0.
 */
static int wval(FILE *out, const char *val, const char *reg)
{
  if (!strcmp(val, ".")) {
    fprintf(out, "\tpush\t%%%s\n", reg);
  } else if (!strcmp(val, "_")) {
    fprintf(out, "\tmov\t%%%s, %%rbx\n", reg);
  } else if (!strcmp(val, "$")) {
    fprintf(out, "\tmov\t%%%s, %%rsp\n", reg);
  } else if (!strcmp(val, "%")) {
    fprintf(out, "\tmov\t%%%s, %%rbp\n", reg);
  } else if (val[0] == '_') {
    fprintf(out, "\tmov\t%%%s, 0x%lx(%%rbx)\n", reg, strtol(&val[1], NULL, 0) * deftype->size);
  } else if (val[0] == '$') {
    fprintf(out, "\tmov\t%%%s, 0x%lx(%%rsp)\n", reg, strtol(&val[1], NULL, 0) * deftype->size);
  } else if (val[0] == '%') {
    fprintf(out, "\tmov\t%%%s, 0x%lx(%%rbp)\n", reg, strtol(&val[1], NULL, 0) * deftype->size);
  } else if (val[0] == '\'' || val[0] == '"' || val[0] == '&' ||
	     isdigit((unsigned char)val[0])) {
    /* Literal destination: nothing to store into, value dropped. */
    ;
  } else {
    fprintf(out, "\tmov\t%%%s, %s(%%rip)\n", reg, val);
  }

  return 0;
}

/* PUSH command: evaluate argv[1] into %rax and push it on the machine stack. */
static int push(FILE *out, int argc, const char *argv[])
{
  int err = rval(out, argv[1], "rax");

  if (err < 0)
    return err;
  fprintf(out, "\tpush\t%%rax\n");
  return 0;
}

/* POP command: pop the machine stack into %rax, then store it to argv[1]. */
static int pop(FILE *out, int argc, const char *argv[])
{
  int err;

  fprintf(out, "\tpop\t%%rax\n");
  err = wval(out, argv[1], "rax");
  return err < 0 ? err : 0;
}

/* Operator token and the mnemonic emitted for it (NULL: emit nothing). */
struct ops {
  char *name;
  char *opcode;
};

/*
 * Linear search of a NULL-name-terminated operator table for NAME.
 * Returns the matching entry, or NULL when there is none.
 */
static struct ops *opsearch(struct ops *ops, const char *name)
{
  struct ops *hit = NULL;

  while (hit == NULL && ops->name != NULL) {
    if (strcmp(ops->name, name) == 0)
      hit = ops;
    else
      ops++;
  }
  return hit;
}

/*
 * Evaluate one CALC operand into REG.
 *
 * The operand may start with a run of unary prefix operators:
 *   "=" and "+"  no-ops
 *   "-"          negates
 *   "~"          complements
 * The bare operand is loaded with rval() first.  The prefixes are then
 * applied from the innermost (rightmost) one outwards.  Returns
 * LLL_ERRCODE_INVALID_FORMAT if nothing follows the prefixes.
 */
static int calc1(FILE *out, const char *arg, const char *reg)
{
  struct ops ops[] = {
    { "=", NULL }, { "+", NULL }, { "-", "neg" }, { "~", "not" },
    { NULL, NULL }
  };
  const char *operand = arg;
  const char *p;
  struct ops *op;
  char name[2];
  int r;

  name[1] = '\0';

  /* Step over the prefix operators to reach the operand proper. */
  for (;;) {
    if (*operand == '\0')
      return LLL_ERRCODE_INVALID_FORMAT;
    name[0] = *operand;
    if (opsearch(ops, name) == NULL)
      break;
    operand++;
  }

  r = rval(out, operand, reg);
  if (r < 0)
    return r;

  /* Apply the prefixes right to left, matching their nesting. */
  p = operand;
  while (p != arg) {
    name[0] = *--p;
    op = opsearch(ops, name);
    if (op->opcode)
      fprintf(out, "\t%s\t%%%s\n", op->opcode, reg);
  }

  return 0;
}

/*
 * Fold a list of operator/operand word pairs into REG1.  Each operand is
 * evaluated into REG2 with calc1() and then added to or subtracted from
 * REG1.
 *
 * Returns:
 *   LLL_ERRCODE_INVALID_FORMAT    when the word count is odd
 *   LLL_ERRCODE_UNKNOWN_OPERATOR  for an operator other than "+" or "-"
 */
static int calc2(FILE *out, int argc, const char *argv[],
		 const char *reg1, const char *reg2)
{
  struct ops ops[] = {
    { "+", "add" }, { "-", "sub" },
    { NULL, NULL }
  };
  int i, r;

  for (i = 0; i < argc; i += 2) {
    struct ops *op;

    if (argc - i < 2)
      return LLL_ERRCODE_INVALID_FORMAT;

    op = opsearch(ops, argv[i]);
    if (op == NULL)
      return LLL_ERRCODE_UNKNOWN_OPERATOR;

    r = calc1(out, argv[i + 1], reg2);
    if (r < 0)
      return r;
    if (op->opcode != NULL)
      fprintf(out, "\t%s\t%%%s, %%%s\n", op->opcode, reg2, reg1);
  }

  return 0;
}

/*
 * CALC command: CALC dst operand [op operand]...
 * The first operand (with optional unary prefixes) is evaluated into
 * %rax.  The remaining pairs are folded in left to right, using %rdx as
 * scratch.  The result is then stored to DST.
 */
static int calc(FILE *out, int argc, const char *argv[])
{
  int r = calc1(out, argv[2], "rax");

  if (r >= 0)
    r = calc2(out, argc - 3, argv + 3, "rax", "rdx");
  if (r >= 0)
    r = wval(out, argv[1], "rax");

  return r < 0 ? r : 0;
}

/*
 * LOAD command: LOAD type index dst base
 * Reads element INDEX (scaled by the type's size) of the array at BASE.
 * The value is widened to 64 bits according to the type and stored to
 * DST.
 */
static int load(FILE *out, int argc, const char *argv[])
{
  struct regtype *rt = regtype_get(argv[1]);
  long offset;
  int r;

  r = rval(out, argv[4], "rax");
  if (r < 0)
    return r;

  offset = strtol(argv[2], NULL, 0) * rt->size;
  fprintf(out, "\t%s\t0x%lx(%%rax), %%%s\n", rt->opcode, offset, rt->reg);
  if (rt->extension != NULL)
    rt->extension(out);

  r = wval(out, argv[3], "rax");
  return r < 0 ? r : 0;
}

/*
 * STORE command: STORE type index src base
 * Writes the low bytes of SRC (as many as the type's size) to element
 * INDEX of the array at BASE.  SRC is evaluated into %rax before BASE is
 * evaluated into %rdx.
 */
static int store(FILE *out, int argc, const char *argv[])
{
  struct regtype *rt = regtype_get(argv[1]);
  int r;

  if ((r = rval(out, argv[3], "rax")) < 0 ||
      (r = rval(out, argv[4], "rdx")) < 0)
    return r;

  fprintf(out, "\t%s\t%%%s, 0x%lx(%%rdx)\n", rt->opcode,
	  rt->reg, strtol(argv[2], NULL, 0) * rt->size);
  return 0;
}

/*
 * GOTO command: GOTO label | GOTO *expr
 * With a leading '*', jumps indirectly to the address EXPR evaluates to.
 * Otherwise argv[1] is used as a label name.
 */
static int jump(FILE *out, int argc, const char *argv[])
{
  const char *target = argv[1];
  int r;

  if (target[0] != '*') {
    fprintf(out, "\tjmp\t%s\n", target);
    return 0;
  }

  r = rval(out, target + 1, "rax");
  if (r < 0)
    return r;
  fprintf(out, "\tjmp\t*%%rax\n");
  return 0;
}

/*
 * IF command: IF lhs op rhs label
 * Compares LHS with RHS as signed 64-bit values (jl/jg family).  Jumps to
 * LABEL when the relation OP holds.  The operator is validated before any
 * code is emitted.
 */
static int jxx(FILE *out, int argc, const char *argv[])
{
  struct ops ops[] = {
    { "==", "je"  }, { "!=", "jne" },
    { "<" , "jl"  }, { ">" , "jg"  },
    { "<=", "jle" }, { ">=", "jge" },
    { NULL, NULL }
  };
  struct ops *cond = opsearch(ops, argv[2]);
  int r;

  if (!cond)
    return LLL_ERRCODE_UNKNOWN_OPERATOR;

  r = rval(out, argv[1], "rax");
  if (r >= 0)
    r = rval(out, argv[3], "rdx");
  if (r < 0)
    return r;

  fprintf(out, "\tcmp\t%%rdx, %%rax\n");
  if (cond->opcode)
    fprintf(out, "\t%s\t%s\n", cond->opcode, argv[4]);
  return 0;
}

/*
 * RETURN command: RETURN value
 * Loads VALUE into %rax and restores %rbx, %rsi and %rdi.  function()'s
 * prologue pushed these right below the frame pointer.  Then the frame is
 * torn down and the function returns.
 */
static int ret(FILE *out, int argc, const char *argv[])
{
  int r, word;

  r = rval(out, argv[1], "rax");
  if (r < 0)
    return r;

  word = deftype->size;
  fprintf(out, "\tmov\t-0x%x(%%rbp), %%rbx\n", word * 1);
  fprintf(out, "\tmov\t-0x%x(%%rbp), %%rsi\n", word * 2);
  fprintf(out, "\tmov\t-0x%x(%%rbp), %%rdi\n", word * 3);
  fprintf(out, "\tleave\n");
  fprintf(out, "\tret\n");
  return 0;
}

/*
 * Integer argument registers in System V AMD64 order, indexed by
 * argument position.  Only this many register arguments are supported.
 */
static const char *argregs[] = {
  "rdi",
  "rsi",
  "rdx",
  "rcx",
  "r8",
  "r9",
};

static int call(FILE *out, int argc, const char *argv[])
{
  int r, i;

  for (i = 3; i < argc; i++) {
    if ((r = rval(out, argv[i], "rax")) < 0) return r;
    fprintf(out, "\tmov\t%%rax, %%%s\n", argregs[i - 3]);
  }

  fprintf(out, "\tpush\t%%rbx\n");
  fprintf(out, "\tmov\t%%rsp, %%rbx\n");
  fprintf(out, "\tand\t$0xfffffffffffffff0, %%rsp\n");

  if (argv[2][0] == '*') {
    if ((r = rval(out, argv[2] + 1, "rax")) < 0) return r;
    fprintf(out, "\tcall\t*%%rax\n");
  } else {
    fprintf(out, "\tcall\t%s\n", argv[2]);
  }

  fprintf(out, "\tmov\t%%rbx, %%rsp\n");
  fprintf(out, "\tpop\t%%rbx\n");
  if ((r = wval(out, argv[1], "rax")) < 0) return r;

  return 0;
}

static int function(FILE *out, int argc, const char *argv[])
{
  int r, i;

  fprintf(out, "\t.text\n");
  fprintf(out, "\t.globl\t%s\n", argv[1]);
  fprintf(out, "\t.type\t%s, @function\n", argv[1]);
  fprintf(out, "%s:\n", argv[1]);
  fprintf(out, "\tpush\t%%rbp\n");
  fprintf(out, "\tmov\t%%rsp, %%rbp\n");
  fprintf(out, "\tpush\t%%rbx\n");
  fprintf(out, "\tpush\t%%rsi\n");
  fprintf(out, "\tpush\t%%rdi\n");
  fprintf(out, "\tsub\t$0x%lx, %%rsp\n", strtol(argv[2], NULL, 0) * deftype->size);

  for (i = 3; i < argc; i++) {
    fprintf(out, "\tmov\t%%%s, %%rax\n", argregs[i - 3]);
    if ((r = wval(out, argv[i], "rax")) < 0) return r;
  }

  return 0;
}

/*
 * INT command: INT name init
 * Defines a global, word-sized data object NAME in .data.  A character
 * literal INIT is passed through as written; otherwise INIT is parsed
 * with strtol (base 0) and emitted in hex.
 */
static int integer(FILE *out, int argc, const char *argv[])
{
  const char *name = argv[1];
  const char *init = argv[2];
  int size = deftype->size;

  fprintf(out, "\t.data\n");
  fprintf(out, "\t.globl\t%s\n", name);
  fprintf(out, "\t.align\t%d\n", size);
  fprintf(out, "\t.type\t%s, @object\n", name);
  fprintf(out, "\t.size\t%s, %d\n", name, size);
  fprintf(out, "%s:\n", name);

  if (init[0] != '\'')
    fprintf(out, "\t%s\t0x%lx\n\n", deftype->variable, strtol(init, NULL, 0));
  else
    fprintf(out, "\t%s\t%s\n\n", deftype->variable, init);

  return 0;
}

/*
 * STRING command: STRING name "text
 * Defines a global, word-sized data object NAME whose value is the label
 * that string_alloc() returns for TEXT.  Presumably the string body is
 * emitted later by string_flush(); TODO confirm.
 *
 * Returns:
 *   LLL_ERRCODE_INVALID_FORMAT  when argv[2] is not quoted
 *   LLL_ERRCODE_LESS_MEMORY     when string_alloc() fails
 */
static int string(FILE *out, int argc, const char *argv[])
{
  const char *name = argv[1];
  const char *text = argv[2];
  const char *label;

  if (*text != '"')
    return LLL_ERRCODE_INVALID_FORMAT;

  label = string_alloc(text + 1);
  if (label == NULL)
    return LLL_ERRCODE_LESS_MEMORY;

  fprintf(out, "\t.data\n");
  fprintf(out, "\t.globl\t%s\n", name);
  fprintf(out, "\t.align\t%d\n", deftype->size);
  fprintf(out, "\t.type\t%s, @object\n", name);
  fprintf(out, "\t.size\t%s, %d\n", name, deftype->size);
  fprintf(out, "%s:\n", name);
  fprintf(out, "\t%s\t%s\n\n", deftype->variable, label);

  return 0;
}

/*
 * EXIT command: EXIT [status]
 * Returns STATUS + 1, which is always positive.  lll_main() treats a
 * positive result as "stop with exit code result - 1", 0 as "continue"
 * and negative values as error codes.
 *
 * STATUS defaults to 0.  A value that cannot be encoded that way
 * (negative, or too large for status + 1 to fit in an int) is reduced
 * modulo 256, as a process exit status would be.  Previously "EXIT -1"
 * produced 0 and was silently ignored, and "EXIT -5" was misreported as
 * an error.
 */
static int finish(FILE *out, int argc, const char *argv[])
{
  long status = 0;

  (void)out;
  if (argc > 1)
    status = strtol(argv[1], NULL, 0);

  if (status < 0 || status >= INT_MAX)
    status = ((status % 256) + 256) % 256;

  return (int)status + 1;
}

/*
 * Command dispatch table used by proc().
 *   name    matched case-insensitively against the first word of a line
 *   argnum  minimum number of words (including the command itself)
 *   func    called with the whole word list
 * The table ends with an all-NULL entry.
 */
struct command {
  char *name;
  int argnum;
  int (*func)(FILE *out, int argc, const char *argv[]);
} commands[] = {
  /* name       min  handler        syntax                               */
  { "#"     ,   1,   none     },  /* # ...                               */
  { "REM"   ,   2,   comment  },  /* REM "text                           */
  { "ASM"   ,   2,   direct   },  /* ASM "text                           */
  { "LABEL" ,   2,   label    },  /* LABEL name                          */
  { "PUSH"  ,   2,   push     },  /* PUSH src                            */
  { "POP"   ,   2,   pop      },  /* POP dst                             */
  { "CALC"  ,   3,   calc     },  /* CALC dst operand [op operand]...    */
  { "LOAD"  ,   5,   load     },  /* LOAD type index dst base            */
  { "STORE" ,   5,   store    },  /* STORE type index src base           */
  { "GOTO"  ,   2,   jump     },  /* GOTO label | GOTO *expr             */
  { "IF"    ,   5,   jxx      },  /* IF lhs op rhs label                 */
  { "RETURN",   2,   ret      },  /* RETURN value                        */
  { "CALL"  ,   3,   call     },  /* CALL dst func|*expr [arg]...        */
  { "FUNC"  ,   3,   function },  /* FUNC name nlocals [param]...        */
  { "INT"   ,   3,   integer  },  /* INT name init                       */
  { "STRING",   3,   string   },  /* STRING name "text                   */
  { "EXIT"  ,   1,   finish   },  /* EXIT [status]                       */
  { NULL    ,   0,   NULL     }
};

/*
 * Case-insensitive string comparison.  Returns 0 when S1 and S2 are equal
 * ignoring case.  Otherwise returns the difference of the first pair of
 * case-folded characters that differ.
 *
 * Characters are passed to toupper() as unsigned char.  Passing a
 * negative plain char (e.g. a non-ASCII byte where char is signed) is
 * undefined behavior.
 */
static int strcmp_toupper(const char *s1, const char *s2)
{
  const unsigned char *a = (const unsigned char *)s1;
  const unsigned char *b = (const unsigned char *)s2;
  int d;

  for (;; a++, b++) {
    d = toupper(*a) - toupper(*b);
    /* d == 0 with *a == '\0' means both strings ended together. */
    if (d != 0 || *a == '\0')
      return d;
  }
}

static int proc(FILE *out, int argc, const char *argv[])
{
  struct command *c;

  for (c = commands; c->name; c++) {
    if (!strcmp_toupper(c->name, argv[0])) {
      if (argc < c->argnum)
	return LLL_ERRCODE_LESS_ARGUMENT;
      return c->func(out, argc, argv);
    }
  }

  return LLL_ERRCODE_UNKNOWN_COMMAND;
}

/*
 * Split LINE in place into at most ARGNUM whitespace-separated words,
 * storing pointers to them in ARGV.  Returns the number of words found.
 *
 * A word that starts with '"' extends to the next unescaped '"' and may
 * contain whitespace.  A backslash skips the following character.  The
 * stored word keeps its opening quote.  The closing quote (or the
 * delimiter after an unquoted word) is overwritten with '\0'.
 *
 * Characters go to isspace() as unsigned char.  Passing a negative
 * plain char (non-ASCII input where char is signed) is undefined
 * behavior.
 */
static int getargs(char *line, const char *argv[], int argnum)
{
  int argc = 0, quote;
  char *p = line;

  while (1) {
    while (*p && isspace((unsigned char)*p))
      p++;
    if ((*p == '\0') || (argc == argnum))
      break;
    argv[argc++] = p;
    quote = (*p == '"') ? *(p++) : 0;
    while (*p && (quote ? (*p != quote) : !isspace((unsigned char)*p))) {
      if (quote && (*p == '\\'))
	p++;
      if (*p)
	p++;
    }
    if (*p)
      *(p++) = '\0';
  }

  return argc;
}

/* Initialisation hook for the translator; currently does nothing, returns 0. */
int lll_init(void) { return 0; }

/* Teardown hook matching lll_init(); currently does nothing, returns 0. */
int lll_done(void) { return 0; }

/*
 * Strip trailing whitespace (including the newline kept by fgets) from
 * BUFFER in place.  A NULL BUFFER is ignored.
 *
 * The length is kept in a size_t to match strlen().  Characters go to
 * isspace() as unsigned char; a negative plain char is undefined
 * behavior.
 */
static void truncate_line(char *buffer)
{
  size_t len;

  if (!buffer)
    return;

  len = strlen(buffer);
  while (len > 0 && isspace((unsigned char)buffer[len - 1]))
    len--;
  buffer[len] = '\0';
}

/*
 * Translator driver.  Processes each file named in ARGV line by line
 * ("-" means standard input) and writes the generated assembly to OUT.
 * A file that cannot be opened is reported and skipped.  With no file
 * arguments, lines are read interactively with lll_readline().
 *
 * Processing stops at:
 *   - EXIT: returns its status
 *   - the first error in file/stdin input: returns the negated error code
 *   - end of input: returns 0
 * Errors on interactively typed lines are reported and processing goes
 * on.  string_flush(out) is called before returning.
 *
 * Both early exits release the pending readline buffer and close the
 * current input file.  Previously they were leaked.
 */
int lll_main(int argc, char *argv[], FILE *out)
{
  char *linebuf, *readlinebuf = NULL, *head, *filename = "(none)";
  int finished = 0, quit = 0, argn, n = 0, r, retcode = 0;
  FILE *fp = NULL;
  char buffer[LINE_MAXLEN + 1];
  const char *args[ARGS_MAXNUM];

  deftype = regtype_get(NULL);

  while (!finished) {
    linebuf = NULL;
    readlinebuf = NULL;

    if (fp) {
      linebuf = buffer;
      if (!fgets(linebuf, LINE_MAXLEN + 1, fp)) {
	if (fp != stdin)
	  fclose(fp);
	fp = NULL;
	filename = "(null)";
	n = 0;
	continue;
      }
      truncate_line(linebuf);
    } else if (argc > 0) {
      if (!strcmp(argv[0], "-")) {
	fp = stdin;
	filename = "(stdin)";
	n = 0;
      } else {
	fp = fopen(argv[0], "r");
	if (!fp) {
	  LLL_ERRPRINT(LLL_ERRCODE_FILE_NOT_FOUND, argv[0], 0);
	} else {
	  filename = argv[0];
	  n = 0;
	}
      }
      argc--;
      argv++;
      quit = 1;
      continue;
    } else if (quit) {
      finished = 1;
      continue;
    } else {
      lll_finished_clear();
      readlinebuf = lll_readline("lll> ");
      if (lll_is_finished()) { /* break by Ctrl+C */
	if (!readlinebuf)
	  write(1, "\n", 1); /* for libedit */
	linebuf = "";
      } else if (!readlinebuf) { /* exit by Ctrl+D */
	finished = 1;
	continue;
      } else { /* normal input */
	linebuf = readlinebuf;
	truncate_line(linebuf);
      }
      filename = "(command-line)";
      n = 0;
    }

    head = linebuf;
    while (isspace((unsigned char)*head))
      head++;

    if (*head && readlinebuf)
      lll_add_history(head);

    n++;
    argn = getargs(head, args, ARGS_MAXNUM);
    if (argn > 0) {
      r = proc(out, argn, args);
      if (r < 0) {
	LLL_ERRPRINT(r, filename, n);
	if (!readlinebuf) {
	  retcode = -r;
	  break;
	}
      } else if (r > 0) {
	retcode = r - 1;
	break;
      }
    }

    free(readlinebuf);
    readlinebuf = NULL;
  }

  /* A break above can leave a readline buffer and an input file open. */
  free(readlinebuf);
  if (fp && fp != stdin)
    fclose(fp);

  string_flush(out);

  return retcode;
}
