/* bwbench v0.1 Copyright (C) 2004 Thomas "tom" S. <tom@eggdrop.ch>
 * 
 * a tool for measuring network bandwidth
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/select.h>
#include <sys/time.h>

/* sockets */
#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* Tunables, as integer constants:
 *   INTERVAL - reporting period in milliseconds (gettime() units)
 *   BUFSIZE  - size in bytes of the buffer handed to each read() */
enum
{
	INTERVAL = 1000,
	BUFSIZE = 32768
};

/* Return a millisecond timestamp relative to the second of the first call.
 *
 * The first call records the current second as the base, so it returns a
 * value in [0, 999]; later calls return the milliseconds elapsed since that
 * base second.  The source is gettimeofday(), i.e. the wall clock, so the
 * result can jump (even backwards) if the system time is stepped.
 * NOTE: with a 32-bit long the value overflows after roughly 24 days.
 */
long gettime(void)
{
	static int have_base = 0;
	static time_t secbase;
	struct timeval tp;

	/* POSIX leaves the behaviour unspecified for a non-NULL timezone
	   argument, and the timezone value was never used anyway. */
	gettimeofday(&tp, NULL);

	if (!have_base)
	{
		/* Keep the base as a full time_t (no truncation into int) and
		   track initialisation explicitly instead of using 0 as a
		   sentinel. */
		secbase = tp.tv_sec;
		have_base = 1;
	}

	return (long)(tp.tv_sec - secbase) * 1000L + (long)(tp.tv_usec / 1000);
}

/* Print the command-line synopsis to stdout and exit with status 1.
 *
 * prog - program name to display (normally argv[0]).  argv[0] may be NULL
 *        when the program is started with an empty argv; passing NULL to
 *        "%s" is undefined behaviour, so a fixed name is shown instead.
 */
void usage(const char *prog)
{
	printf("Usage: %s <host> <port>\n", prog ? prog : "bwbench");
	exit(1);
}

/* Print one measurement line: the receive rate in bits/s and in bytes/s,
 * each scaled to the largest 1024-based unit it reaches (-, K, M).
 *
 * bytes      - number of bytes received during the interval
 * elapsed_ms - length of that interval in milliseconds
 */
static void print_rate(unsigned long long bytes, long elapsed_ms)
{
	double byte_rate;
	double bit_rate;

	if (elapsed_ms < 1)	/* avoid dividing by zero */
		elapsed_ms = 1;

	/* Normalise to a per-second figure: the interval is only checked once
	   per loop pass, so it is never exactly INTERVAL ms long. */
	byte_rate = (double)bytes * 1000.0 / (double)elapsed_ms;
	bit_rate = byte_rate * 8.0;

	if (bit_rate < 1024)
		printf("%lf bps ", bit_rate);
	else if (bit_rate < 1024*1024)
		printf("%lf Kbps ", bit_rate/1024.0);
	else
		printf("%lf Mbps ", bit_rate/(1024.0*1024.0));

	if (byte_rate < 1024)
		printf("(%lf B/s)\n", byte_rate);
	else if (byte_rate < 1024*1024)
		printf("(%lf KB/s)\n", byte_rate/1024.0);
	else
		printf("(%lf MB/s)\n", byte_rate/(1024.0*1024.0));

	/* Keep the output timely when stdout is a pipe or a file. */
	fflush(stdout);
}

/* Read from a connected socket until the peer closes it, printing the
 * observed receive rate about every INTERVAL milliseconds.
 *
 * sock - connected stream socket; it is passed to FD_SET(), so it is
 *        assumed to be below FD_SETSIZE.
 *
 * Never returns: exits with status 5 when the connection is closed or a
 * read()/select() error occurs.
 */
void mainloop(int sock)
{
	char buffer[BUFSIZE];
	/* unsigned long long: an int counter overflowed (UB) in bytes*8 once
	   more than ~256 MB arrived in one interval. */
	unsigned long long bytes = 0;
	long timelast = gettime();

	for (;;)
	{
		long timenow = gettime();
		long elapsed;
		long remaining;
		struct timeval tv;
		fd_set fds;
		int ready;
		ssize_t got;

		/* gettime() follows the wall clock; if it steps backwards, start
		   a fresh interval instead of waiting for it to catch up. */
		if (timenow < timelast)
			timelast = timenow;
		elapsed = timenow - timelast;

		/* Checked on every pass, so a continuously readable socket can
		   no longer starve the periodic report. */
		if (elapsed >= INTERVAL)
		{
			print_rate(bytes, elapsed);
			bytes = 0;
			timelast = timenow;
			continue;
		}

		/* Sleep in select() until data arrives or the interval ends,
		   rather than spinning with a zero timeout.  tv and fds are
		   rebuilt on every pass because select() may modify both. */
		remaining = INTERVAL - elapsed;
		tv.tv_sec = remaining / 1000;
		tv.tv_usec = (remaining % 1000) * 1000;

		/* Only the socket is watched.  stdin is not added: it was never
		   read, so once it became readable (a keypress, or EOF when
		   redirected) select() kept reporting it and no statistics were
		   printed. */
		FD_ZERO(&fds);
		FD_SET(sock, &fds);

		ready = select(sock+1, &fds, NULL, NULL, &tv);
		if (ready < 0)
		{
			if (errno == EINTR)
				continue;
			perror("select");
			exit(5);
		}
		if (ready == 0 || !FD_ISSET(sock, &fds))
			continue;	/* timed out: loop round and report */

		got = read(sock, buffer, sizeof(buffer));
		if (got < 0 && errno == EINTR)
			continue;
		if (got < 1)
		{
			printf("Connection closed\n");
			exit(5);
		}
		bytes += (unsigned long long)got;
	}
}

/* Resolve host to an IPv4 address, open a TCP socket and connect it.
 *
 * host - host name or dotted-quad string passed to gethostbyname()
 * port - TCP port in host byte order
 *
 * Returns the connected socket descriptor.  On failure prints a message to
 * stdout and exits: 2 = host could not be resolved to an IPv4 address,
 * 3 = socket() failed, 4 = connect() failed.
 */
int setup_connection(char *host, unsigned short port)
{
	struct hostent *haddr;
	struct sockaddr_in in;
	int sock;

	/* Only a 4-byte AF_INET address fits into sockaddr_in; check what
	   the resolver returned instead of assuming it. */
	haddr = gethostbyname(host);
	if (!haddr || haddr->h_addrtype != AF_INET
	    || haddr->h_length != (int)sizeof(in.sin_addr)
	    || !haddr->h_addr_list || !haddr->h_addr_list[0])
	{
		printf("Unable to resolve %s\n", host);
		exit(2);
	}

	/* Zero the whole address so sin_zero (and any platform-specific
	   fields) are not left as stack garbage. */
	memset(&in, 0, sizeof(in));
	in.sin_family = AF_INET;
	in.sin_port = htons(port);
	/* memcpy rather than a pointer cast: h_addr_list entries are char
	   pointers with no alignment guarantee for struct in_addr. */
	memcpy(&in.sin_addr, haddr->h_addr_list[0], sizeof(in.sin_addr));

	printf("Connecting to %s (%s) port %d ... ", host, \
			inet_ntoa(in.sin_addr), port);
	fflush(stdout);

	/* socket() reports failure with -1; 0 is a valid descriptor. */
	sock = socket(AF_INET, SOCK_STREAM, 0);
	if (sock < 0)
	{
		printf("failed\nUnable to create socket\n");
		exit(3);
	}

	if (connect(sock, (struct sockaddr*)&in, sizeof(in)) < 0)
	{
		printf("failed\nUnable to connect to remote host\n");
		exit(4);
	}

	printf("ok\n");

	return sock;
}

/* Entry point: bwbench <host> <port>.
 *
 * Validates the arguments (printing usage and exiting on bad input),
 * connects to host:port and hands the socket to mainloop(), which loops
 * until the connection closes and exits the process from there.
 */
int main(int argc, char *argv[])
{
	char *host;
	char *end;
	long port;

	if (argc != 3)
		usage(argv[0]);

	host = argv[1];

	/* strtol() instead of atoi(): reject empty or partly numeric input
	   ("80x"), negative numbers and values that do not fit a 16-bit port,
	   which the unsigned short cast would otherwise silently wrap
	   (e.g. 65616 -> 80). */
	port = strtol(argv[2], &end, 10);
	if (end == argv[2] || *end != '\0' || port < 1 || port > 65535)
		usage(argv[0]);

	mainloop(setup_connection(host, (unsigned short)port));

	return 0;
}
