#ifdef __linux__
#define _BSD_SOURCE
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#ifndef USE_NETLIB
#include <net/ethernet.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#else
#include <netlib.h>
#endif

#include "pktlib.h"
#include "pktbuf.h"

/*
 * Compute the 16-bit one's-complement sum (RFC 1071) of `size` bytes
 * starting at `buffer`.
 *
 * The value returned is the folded sum, NOT its complement: callers store
 * ~result into the checksum field.  Words are read in host byte order;
 * the one's-complement sum is byte-order independent, so writing the
 * result back into packet memory yields the correct network-order value
 * without htons()/ntohs().
 *
 * An odd trailing byte is padded with a zero byte, as RFC 1071 requires.
 * A size of zero or less yields 0.
 *
 * Returns a value in the range [0, 0xffff].
 */
int pktlib_ip_checksum(void *buffer, int size)
{
  union {
    unsigned char c[2];
    unsigned short s;
  } w;
  const unsigned char *p = buffer;
  unsigned long sum = 0;
  int i;

  for (i = 0; i + 1 < size; i += 2) {
    w.c[0] = p[i];
    w.c[1] = p[i + 1];
    sum += w.s;
    /* Fold carries early so the accumulator cannot overflow for any size. */
    if (sum & 0x80000000UL)
      sum = (sum & 0xffff) + (sum >> 16);
  }
  if (i < size) {
    /* Odd length: pair the last byte with a zero pad byte. */
    w.c[0] = p[i];
    w.c[1] = 0;
    sum += w.s;
  }

  /* End-around carry until the sum fits in 16 bits. */
  while (sum >> 16)
    sum = (sum & 0xffff) + (sum >> 16);

  return (int)sum;
}

/*
 * IPv4 pseudo-header that is prepended (logically) to TCP and UDP
 * segments when computing their checksums.  All multi-byte fields hold
 * network byte order.  The field sizes (4+4+1+1+2) give a 12-byte layout
 * with no padding on common ABIs.
 */
struct pseudo_header {
  in_addr_t saddr, daddr;          /* source / destination IPv4 address */
  unsigned char zero, protocol;    /* always 0; IP protocol number */
  unsigned short len;              /* transport segment length */
};

/*
 * Recompute the ICMP checksum of the message at the current head of
 * `pktbuf`.  `size` is the number of bytes covered by the checksum
 * (the ICMP header plus its data).  ICMP uses no pseudo-header, so only
 * the message itself is summed.  Always returns `pktbuf`.
 */
static pktbuf_t correct_icmp(pktbuf_t pktbuf, int size)
{
  struct icmp *msg = (struct icmp *)pktbuf_get_header(pktbuf);

  /* The checksum field must read as zero while the sum is taken. */
  msg->icmp_cksum = 0;
  /* The one's-complement sum is byte-order neutral: no htons() needed. */
  msg->icmp_cksum = ~pktlib_ip_checksum(msg, size);
  return pktbuf;
}

/*
 * Recompute the TCP checksum of the segment at the current head of
 * `pktbuf`.  `size` is the segment length (header + data) in bytes.
 *
 * `pchksum` is the folded sum of the IPv4 pseudo-header.  It is seeded
 * into th_sum before summing the segment.  Adding it as one extra 16-bit
 * word is equivalent to summing the pseudo-header itself, so a single
 * pass covers pseudo-header + segment.  Always returns `pktbuf`.
 */
static pktbuf_t correct_tcp(pktbuf_t pktbuf, int size, int pchksum)
{
  struct tcphdr *seg = (struct tcphdr *)pktbuf_get_header(pktbuf);

  seg->th_sum = pchksum;                          /* byte-order neutral */
  seg->th_sum = ~pktlib_ip_checksum(seg, size);   /* byte-order neutral */
  return pktbuf;
}

/*
 * Recompute the UDP checksum of the datagram at the current head of
 * `pktbuf`.  `size` is the datagram length (header + data) in bytes.
 *
 * `pchksum` is the folded sum of the IPv4 pseudo-header.  It is seeded
 * into uh_sum so that a single pass sums pseudo-header + datagram.
 *
 * RFC 768: a transmitted checksum of zero means "no checksum", so a
 * computed value of zero is sent as all ones (0xffff) instead.
 * Always returns `pktbuf`.
 */
static pktbuf_t correct_udp(pktbuf_t pktbuf, int size, int pchksum)
{
  char *p;
  struct udphdr *udphdr;
  unsigned short cksum;

  p = pktbuf_get_header(pktbuf);

  udphdr = (struct udphdr *)p;
  udphdr->uh_sum = pchksum; /* byte-order neutral: no htons() needed */
  cksum = (unsigned short)~pktlib_ip_checksum(p, size);
  if (cksum == 0)
    cksum = 0xffff; /* 0 is reserved for "checksum not computed" */
  udphdr->uh_sum = cksum;

  return pktbuf;
}

pktbuf_t pktbuf_checksum_correct_ip(pktbuf_t pktbuf)
{
  char *p;
  struct ip *iphdr;
  int hdrsize, paysize;
  struct pseudo_header phdr;
  int pchksum;

  p = pktbuf_get_header(pktbuf);

  iphdr = (struct ip *)p;
  hdrsize = iphdr->ip_hl << 2;
  paysize = ntohs(iphdr->ip_len) - hdrsize;

  memset(&phdr, 0, sizeof(phdr));
  phdr.saddr = iphdr->ip_src.s_addr;
  phdr.daddr = iphdr->ip_dst.s_addr;
  phdr.protocol = iphdr->ip_p;
  phdr.len = htons(paysize);
  pchksum = pktlib_ip_checksum(&phdr, sizeof(phdr));

  pktbuf_delete_header(pktbuf, hdrsize);
  switch (iphdr->ip_p) {
  case IPPROTO_ICMP: pktbuf = correct_icmp(pktbuf, paysize); break;
  case IPPROTO_TCP:  pktbuf = correct_tcp(pktbuf, paysize, pchksum); break;
  case IPPROTO_UDP:  pktbuf = correct_udp(pktbuf, paysize, pchksum); break;
  default: break;
  }
  if (pktbuf == NULL)
    return NULL;
  pktbuf_add_header(pktbuf, hdrsize);

  p = pktbuf_get_header(pktbuf);

  iphdr = (struct ip *)p;
  iphdr->ip_sum = 0;
  iphdr->ip_sum = ~pktlib_ip_checksum(p, hdrsize);

  return pktbuf;
}

/*
 * Recompute the checksums of the Ethernet frame at the head of `pktbuf`.
 * Only IPv4 payloads (ETHERTYPE_IP) are handled; other EtherTypes are
 * passed through unchanged.
 *
 * Returns `pktbuf` with the Ethernet header restored at its head, or NULL
 * if pktbuf_checksum_correct_ip() returned NULL.
 */
pktbuf_t pktbuf_checksum_correct(pktbuf_t pktbuf)
{
  struct ether_header *ehdr;
  unsigned short ether_type;

  ehdr = (struct ether_header *)pktbuf_get_header(pktbuf);
  /*
   * Read the EtherType before stripping the header, so that ehdr is never
   * dereferenced after pktbuf_delete_header() has run.
   */
  ether_type = ntohs(ehdr->ether_type);

  pktbuf_delete_header(pktbuf, ETHER_HDR_LEN);
  switch (ether_type) {
  case ETHERTYPE_IP:
    pktbuf = pktbuf_checksum_correct_ip(pktbuf);
    break;
  default:
    break;
  }
  if (pktbuf == NULL)
    return NULL;
  pktbuf_add_header(pktbuf, ETHER_HDR_LEN);

  return pktbuf;
}
