// Released under the GNU General Public License (see license.txt for details).
//
// Copyright (c) 2011 Chuan-Sheng Foo.
// All Rights Reserved.
//
#include <algorithm>
#include <cstdlib>
#include <deque>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// Global integer index buffer. Nothing in the visible part of this file reads
// or writes it. NOTE(review): presumably a leftover from related training
// code -- confirm it is unused elsewhere before removing.
vector<int> randindex;

// One (zero-based feature index, feature value) entry of a sparse vector.
typedef std::pair<int, double> VecPair;

// Sparse feature vector stored as a list of (index, value) pairs.
// Adapted from Pegasos.
class SparseVector {
public:
  std::vector<VecPair> elems; // should make it private someday

  // Creates an empty vector.
  SparseVector() {}

  // Builds a vector from up to n whitespace-separated "index value" pairs
  // read from is (see the definition for details).
  SparseVector(std::istringstream& is, int n);

  // Multiplies every stored value by s.
  void scale(double s);

  // Returns the sum of the squares of the stored values.
  double normsq() const;

  // Returns the dot product with the dense vector y.
  double dot_product(const std::vector<double>& y) const;

  // Returns the largest stored index, or 0 if the vector is empty.
  int max_index() const;

  // Writes the entries as "(index,value) " followed by a newline.
  void print(std::ostream& os = std::cerr) const;
};

// n specifies how many non-zero components there are.
//
// Indices in the input are one-based; they are stored zero-based (index - 1).
// Reading stops at the first pair that cannot be parsed completely, so a
// malformed or truncated line yields only the pairs read so far instead of
// garbage entries. The stream is left in its failed state in that case, which
// lets the caller detect the problem via is.fail().
SparseVector::SparseVector(std::istringstream& is, int n) {
  if (n > 0) {
    elems.reserve(static_cast<std::size_t>(n));
  }
  for (int i = 0; i < n; i++) {
    int index = 0;
    double feature = 0.0;
    if (!(is >> index >> feature)) {
      break;
    }
    elems.push_back(std::make_pair(index - 1, feature));
  }
}

void SparseVector::scale(double s) {
  for (std::vector<VecPair>::iterator it = elems.begin(); it != elems.end(); ++it) {
    it->second *= s;
  }
}

double SparseVector::normsq() const {
  double ans = 0;
  for (std::vector<VecPair>::const_iterator it = elems.begin(); it != elems.end(); ++it) {
    ans += it->second * it->second;
  }
  return ans;
}

// Entries whose index is negative or not smaller than y.size() are skipped,
// i.e. y is treated as zero outside its range. Without this check such an
// entry would read outside y's storage.
double SparseVector::dot_product(const std::vector<double>& y) const {
  double ans = 0;
  for (std::vector<VecPair>::const_iterator it = elems.begin(); it != elems.end(); ++it) {
    const int idx = it->first;
    if (idx < 0 || static_cast<std::size_t>(idx) >= y.size()) {
      continue;
    }
    ans += it->second * y[static_cast<std::size_t>(idx)];
  }
  return ans;
}

// Scans every entry, so the result is correct even when the entries are not
// sorted by index. For entries sorted by increasing index this equals the
// index of the last entry.
int SparseVector::max_index() const {
  if (elems.empty()) {
    return 0;
  }
  int best = elems.front().first;
  for (std::vector<VecPair>::const_iterator it = elems.begin(); it != elems.end(); ++it) {
    best = std::max(best, it->first);
  }
  return best;
}

void SparseVector::print(std::ostream& os) const {
  for (std::vector<VecPair>::const_iterator it = elems.begin(); it != elems.end(); ++it) {
    os << "(" << it->first << "," << it->second << ") ";
  }
  os << std::endl;
}

// Adapted from Pegasos' file input
//
// Scores every example in input_filename with the linear model w and writes
// one predicted label (1 or -1) per example to output_filename.
//
// Each input line is parsed as "<label> <index>:<value> ...", where label must
// be +1 or -1. Text from the first '#' to the end of the line is ignored, and
// lines left without any content are skipped.
//
// Written to stderr: the number of examples, dim (largest feature index seen,
// plus one), the confusion-matrix counts, accuracy, sensitivity
// (TP / (TP + FN)), specificity (TN / (TN + FP)) and the objective.
// Written to stdout: "<objective> <accuracy>".
// The objective is 0.5 * lambda * ||w||^2 plus the mean hinge loss
// max(0, 1 - label * w.x) over the examples.
//
// The ratios are undefined (typically printed as nan) when their denominator
// is zero, e.g. for an input without examples.
//
// Terminates the process with exit(-1) if either file cannot be opened or a
// line carries a label other than +1 / -1.
void predict(std::vector<double>& w, double lambda, const std::string& input_filename, const std::string& output_filename) {
  int dim = 0;
  int TP = 0;
  int TN = 0;
  int FP = 0;
  int FN = 0;
  double loss = 0;

  std::ifstream in_file(input_filename.c_str());
  if (!in_file.good()) {
    std::cerr << "Error opening example file: " << input_filename << std::endl;
    std::exit(-1);
  }

  std::ofstream out_file(output_filename.c_str());
  if (!out_file.good()) {
    std::cerr << "Error opening output file: " << output_filename << std::endl;
    std::exit(-1);
  }

  int num_examples = 0;
  std::string buf;

  while (std::getline(in_file, buf)) {
    // Strip a trailing comment; a line that starts with '#' becomes empty.
    const std::string::size_type comment_pos = buf.find('#');
    if (comment_pos != std::string::npos) buf.erase(comment_pos);

    // Blank lines (including comment-only lines) carry no example.
    if (buf.find_first_not_of(" \t\r\n\v\f") == std::string::npos) continue;

    // Every ':' separates one index from its value; turning it into a space
    // lets the pairs be read with plain stream extraction.
    int num_nonzero = 0;
    for (std::string::size_type i = 0; i < buf.size(); i++) {
      if (buf[i] == ':') {
        num_nonzero++;
        buf[i] = ' ';
      }
    }

    std::istringstream iss(buf);
    int label = 0; // stays 0, and is rejected below, if no integer can be read
    iss >> label;
    if (label != 1 && label != -1) {
      std::cerr << "Invalid Class Label: Class label must be +1 or -1" << std::endl;
      std::exit(-1);
    }

    SparseVector example(iss, num_nonzero);

    const double wTx = example.dot_product(w);

    loss += std::max(0.0, 1 - label * wTx);

    const int predlabel = (wTx >= 0) ? 1 : -1;

    // endl flushes on purpose: exit(-1) above does not run the destructor of
    // out_file, so unflushed predictions would be lost on a later bad line.
    out_file << predlabel << std::endl;

    if (predlabel == label) {
      if (label == 1)
        TP++;
      else
        TN++;
    } else {
      if (label == 1)
        FN++;
      else
        FP++;
    }

    dim = std::max(dim, example.max_index());
    num_examples++;
  }

  dim++; // largest zero-based index -> dimension

  std::cerr << num_examples << " examples read, dim = " << dim << std::endl;

  out_file.close();
  in_file.close();

  loss /= num_examples;
  double reg = 0;
  for (std::vector<double>::size_type i = 0; i < w.size(); i++)
    reg += w[i] * w[i];
  reg *= 0.5 * lambda;

  const double objective = reg + loss;
  const double accuracy = static_cast<double>(TP + TN) / num_examples;

  std::cerr << "Classification results:" << std::endl;
  std::cerr << "TP = " << TP << std::endl;
  std::cerr << "FP = " << FP << std::endl;
  std::cerr << "TN = " << TN << std::endl;
  std::cerr << "FN = " << FN << std::endl;
  std::cerr << "Accuracy = " << accuracy << std::endl;
  std::cerr << "Sensitivity = " << static_cast<double>(TP) / (TP + FN) << std::endl;
  // Specificity is the true-negative rate; TP / (TP + FP) would be precision.
  std::cerr << "Specificity = " << static_cast<double>(TN) / (TN + FP) << std::endl;
  std::cerr << "Loss = " << objective << std::endl;
  std::cout << objective << ' ' << accuracy << std::endl;
}



// Prints the program banner and the expected command line to stderr.
void print_usage() {
  static const char* const kLines[] = {
    "SVM Predict",
    "Written by Chuan-Sheng Foo",
    "--------------------------",
    "Usage:",
    "  predict <model_file> <input_file> <output_file>"
  };
  const std::size_t line_count = sizeof(kLines) / sizeof(kLines[0]);

  for (std::size_t i = 0; i < line_count; ++i) {
    std::cerr << kLines[i] << std::endl;
  }
  // Trailing blank line after the usage text.
  std::cerr << std::endl;
}

// Loads a linear model from model_filename.
//
// The file is read as whitespace-separated numbers: lambda, then the number
// of weights dim, then dim weights. On return, lambda holds the first number
// and w holds exactly dim entries.
//
// Terminates the process with exit(-1) if the file cannot be opened, or if
// lambda / dim cannot be read or dim is negative (previously this resized w
// to an indeterminate size). If fewer than dim weights are present, a warning
// is printed to stderr and the missing weights are left at 0.
void read_model_file(const std::string& model_filename, std::vector<double>& w, double& lambda) {
  std::ifstream inf(model_filename.c_str());
  if (!inf.good()) {
    std::cerr << "Error opening model file: " << model_filename << std::endl;
    std::exit(-1);
  }

  int dim = 0;
  if (!(inf >> lambda >> dim) || dim < 0) {
    std::cerr << "Error reading model file header (expected <lambda> <dim>): "
              << model_filename << std::endl;
    std::exit(-1);
  }

  // Start from all zeros so that unread weights have a defined value.
  w.assign(static_cast<std::vector<double>::size_type>(dim), 0.0);

  for (int i = 0; i < dim; i++) {
    if (!(inf >> w[i])) {
      std::cerr << "Warning: model file " << model_filename << " contains only "
                << i << " of " << dim << " weights; missing weights are set to 0"
                << std::endl;
      break;
    }
  }

  inf.close();
}

// Entry point: predict <model_file> <input_file> <output_file>.
// Loads the model, then labels the input file. Returns -1 after printing the
// usage text when the argument count is wrong, 0 otherwise.
int main(int argc, char ** argv) {
  const int kExpectedArgc = 4; // program name + three file names
  if (argc != kExpectedArgc) {
    print_usage();
    return -1;
  }

  const std::string model_filename(argv[1]);
  const std::string input_filename(argv[2]);
  const std::string output_filename(argv[3]);

  // Echo the configuration so that logs show what was run.
  std::cerr << "predict called with parameters: " << std::endl
            << "  model file: " << model_filename << std::endl
            << "  input file: " << input_filename << std::endl
            << "  output file: " << output_filename << std::endl
            << std::endl;

  std::vector<double> w;
  double lambda; // filled in by read_model_file

  std::cerr << "reading model file...";
  read_model_file(model_filename, w, lambda);
  std::cerr << " done" << std::endl;

  std::cerr << "predicting labels..." << std::endl;
  predict(w, lambda, input_filename, output_filename);

  return 0;
}
