#!/usr/bin/env perl
#
# This file is part of moses.  Its use is licensed under the GNU Lesser General
# Public License version 2.1 or, at your option, any later version.
use utf8;

###############################################
# Adaptation of the implementation of paired bootstrap resampling for testing the statistical
# significance of the difference between two systems from (Koehn 2004 @ EMNLP), developed
# by Mark Fishel, fishel@ut.ee 
# (https://github.com/moses-smt/mosesdecoder/blob/master/scripts/analysis/bootstrap-hypothesis-difference-significance.pl)
#
# Usage: ./compare-hypotheses-with-significance.pl hypothesis_1 hypothesis_2 reference_1 score_function
#
# score_function: 0 - MAE (for scoring), 1 - DeltaAvg (for ranking), 2 - F-measure (for word-level; not implemented in this script)
#
# input files: one prediction value per line (for word-level prediction, use the official WMT evaluation script instead)
#
# Author: Carolina Scarton, c.scarton@sheffield.ac.uk
#
###############################################

use warnings;
#use strict;

#constants
my $TIMES_TO_REPEAT_SUBSAMPLING = 1000;
my $SUBSAMPLE_SIZE = 0; # if 0 then subsample size is equal to the whole set
my $MAX_NGRAMS = 4; # NOTE(review): not referenced by the scoring code in this script
my $IO_ENCODING = "utf8"; # PerlIO layer used for input files; can be replaced with e.g. "encoding(iso-8859-13)" or alike
my $INF = 1000000;
my $DEBUG = 0; # >0: print per-quantile DeltaAvg to STDERR; >1: also per-quantile reference averages

# score_function code (4th command-line argument) => [report title, scoring sub]
my %SCORE_FUNCTIONS = (
	0 => ["MAE", \&MAE_RMSE],
	1 => ["DeltaAvg", \&avgDelta],
);

#checking cmdline argument consistency
if (@ARGV < 4) {
	print STDERR "Usage: $0 hypothesis_1 hypothesis_2 reference_1 score_function\n0 - MAE (scoring)\n1 - DeltaAvg (ranking)\n";

	# an explicit help flag exits with status 1 after the usage text; anything else is an error
	unless (@ARGV && $ARGV[0] =~ /^(--help|-help|-h|-\?|\/\?|--usage|-usage)$/) {
		die("\nERROR: not enough arguments");
	}

	exit 1;
}

# Validate the score function before loading data, so a typo fails fast
# instead of silently producing no report (or, with a non-numeric value,
# being treated as 0 by a numeric comparison).
my ($opCode) = $ARGV[3] =~ /^\s*([01])\s*$/;
unless (defined $opCode) {
	die("ERROR: unsupported score_function '$ARGV[3]' (0 - MAE, 1 - DeltaAvg)");
}
my ($scoreTitle, $scoreSub) = @{ $SCORE_FUNCTIONS{$opCode} };

print STDERR "reading data; " . `date`;

#read all data; $data is a file-scoped lexical shared with the subs below
my $data = readAllData(@ARGV);

#start comparing
print STDERR "comparing hypotheses -- this may take some time; " . `date`;

bootstrap_report($scoreTitle, $scoreSub);

#####
# Run the paired bootstrap for one score function and print, to STDOUT,
# the real scores of both hypotheses, their 95% confidence intervals
# (midpoint +- half-width) and the p-value of their difference.
#   $title - label used in the report header (e.g. "MAE")
#   $proc  - scoring sub, called as $proc->($refs, $hyp[, $indices])
#####
sub bootstrap_report {
	my $title = shift;
	my $proc = shift;

	my ($subSampleScoreDiffArr, $subSampleScore1Arr, $subSampleScore2Arr) = bootstrap_pass($proc);

	# scores on the full (non-resampled) sets
	my $realScore1 = $proc->($data->{refs}, $data->{hyp1});
	my $realScore2 = $proc->($data->{refs}, $data->{hyp2});

	my $scorePValue = bootstrap_pvalue($subSampleScoreDiffArr, $realScore1, $realScore2);

	# bootstrap_interval returns (midpoint, half-width) of the 95% interval
	my ($scoreAvg1, $scoreVar1) = bootstrap_interval($subSampleScore1Arr);
	my ($scoreAvg2, $scoreVar2) = bootstrap_interval($subSampleScore2Arr);

	print "\n---=== $title score ===---\n";

	print "actual score of hypothesis 1: $realScore1\n";
	print "95% confidence interval for hypothesis 1 score: $scoreAvg1 +- $scoreVar1\n-----\n";
	print "actual score of hypothesis 2: $realScore2\n";
	print "95% confidence interval for hypothesis 2 score: $scoreAvg2 +- $scoreVar2\n-----\n";
	print "Assuming that essentially the same system generated the two hypothesis translations (null-hypothesis),\n";
	printf ("the probability of actually getting them (p-value) is: %.4f\n", $scorePValue);
}

#####
# Draw $TIMES_TO_REPEAT_SUBSAMPLING resamples (with replacement) of the test
# set and score both hypotheses on the SAME indices each time (paired bootstrap).
# The resample size is $SUBSAMPLE_SIZE, or the whole set when that is 0.
# Returns three array refs: |score2 - score1| per resample, hypothesis 1
# scores, hypothesis 2 scores. Progress goes to STDERR: a dot every 10
# resamples and the running count every 100.
#####
sub bootstrap_pass {
	my $scoreFunc = shift;

	my (@diffs, @scores1, @scores2);
	my $sampleSize = $SUBSAMPLE_SIZE ? $SUBSAMPLE_SIZE : $data->{size};

	foreach my $round (1 .. $TIMES_TO_REPEAT_SUBSAMPLING) {
		my $indices = drawWithReplacement($data->{size}, $sampleSize);

		my $s1 = $scoreFunc->($data->{refs}, $data->{hyp1}, $indices);
		my $s2 = $scoreFunc->($data->{refs}, $data->{hyp2}, $indices);

		push @diffs, abs($s2 - $s1);
		push @scores1, $s1;
		push @scores2, $s2;

		print STDERR "." unless $round % 10;
		print STDERR "$round\n" unless $round % 100;
	}

	# finish the progress line when the total is not a multiple of 100
	print STDERR ".$TIMES_TO_REPEAT_SUBSAMPLING\n" if $TIMES_TO_REPEAT_SUBSAMPLING % 100;

	return (\@diffs, \@scores1, \@scores2);
}

#####
# Paired-bootstrap p-value (Koehn 2004): the fraction of resamples whose
# difference, shifted by the mean resampled difference, is at least as
# large as the difference observed on the full sets.
#   $subSampleDiffArr        - array ref of per-resample |score2 - score1|
#   $realScore1, $realScore2 - scores of the two hypotheses on the full sets
# Dies if no resampled differences are given.
#####
sub bootstrap_pvalue {
	my ($subSampleDiffArr, $realScore1, $realScore2) = @_;

	# normalise by the actual number of resamples rather than the global constant
	my $n = scalar @$subSampleDiffArr;
	die("ERROR: bootstrap_pvalue needs at least one resampled difference") unless $n;

	my $realDiff = abs($realScore2 - $realScore1);

	#get subsample difference mean
	my $averageSubSampleDiff = 0;
	$averageSubSampleDiff += $_ for @$subSampleDiffArr;
	$averageSubSampleDiff /= $n;

	#calculating p-value: count centred resampled differences reaching the real one
	my $count = grep { $_ - $averageSubSampleDiff >= $realDiff } @$subSampleDiffArr;

	return $count / $n;
}

#####
# 95% bootstrap confidence interval from a list of resampled scores.
# Takes the values at the 2.5% and 97.5% positions of the numerically
# sorted list and returns (midpoint, half-width) of that range.
#   $subSampleArr - array ref of resampled scores (must be non-empty)
#####
sub bootstrap_interval {
	my $subSampleArr = shift;

	# numeric sort: the default lexical sort misorders e.g. 10 before 9
	my @sorted = sort { $a <=> $b } @$subSampleArr;
	my $n = scalar @sorted;
	die("ERROR: bootstrap_interval needs at least one score") unless $n;

	my $lowerIdx = int($n / 40);           # 2.5% of the resamples
	my $higherIdx = $n - $lowerIdx - 1;    # symmetric upper bound

	my $lower = $sorted[$lowerIdx];
	my $higher = $sorted[$higherIdx];
	my $diff = $higher - $lower;

	return ($lower + 0.5 * $diff, 0.5 * $diff);
}

#####
# Read the two hypothesis files and the reference file.
# Returns a hash ref with keys hyp1, hyp2 and refs (array refs of lines as
# returned by readData) and size (number of lines in hypothesis 1).
# Dies unless all three files have the same number of lines.
#####
sub readAllData {
	my ($hypFile1, $hypFile2, $refFile) = @_;

	my %result;

	#reading hypotheses and checking for matching sizes
	$result{hyp1} = readData($hypFile1);
	$result{size} = scalar @{$result{hyp1}};

	$result{hyp2} = readData($hypFile2);
	unless (scalar @{$result{hyp2}} == $result{size}) {
		die ("ERROR: sizes of hypothesis sets 1 and 2 don't match");
	}

	#reading the reference and checking that it is aligned with the hypotheses
	$result{refs} = readData($refFile);
	unless (scalar @{$result{refs}} == $result{size}) {
		die ("ERROR: sizes of the reference set and the hypothesis sets don't match");
	}

	return \%result;
}


#####
# Read a file into an array ref, one element per line. Line terminators are
# kept. The file is read through the $IO_ENCODING PerlIO layer.
# Dies (with the OS error) if the file cannot be opened.
#####
sub readData {
	my $file = shift;

	# 3-arg open with a lexical handle: the file name cannot inject an open
	# mode (e.g. a leading '>' or a trailing '|'), and no bareword handle is
	# left in the package namespace.
	open(my $fh, "<:$IO_ENCODING", $file)
		or die("Failed to open `$file' for reading: $!");

	my @result = <$fh>;    # list context: every line, in file order
	close($fh);

	return \@result;
}

#####
# Draw a subsample of $subSize indices from 0..$setSize-1, uniformly and
# with replacement (the same index may appear several times).
# Returns an array ref of integer indices.
#####
sub drawWithReplacement {
	my ($setSize, $subSize) = @_;

	return [ map { int(rand($setSize)) } 1 .. $subSize ];
}

#####
# The functions below (MAE_RMSE and avgDelta) are adaptations of Radu Soricut's evaluation script: http://www.quest.dcs.shef.ac.uk/wmt13_files/evaluateWMTQP2013-Task1_1.pl
#
# Version 1.2
# For research or educational purposes only.  Do not redistribute.
#####


#####
# Adaptation of MAE implementation by Radu Soricut - WMT QE shared task official evaluation script
# http://www.quest.dcs.shef.ac.uk/wmt13_files/evaluateWMTQP2013-Task1_1.pl
#
# Mean absolute error between predictions and references.
#   $refs - array ref of reference values (one per line)
#   $hyp  - array ref of predicted values, aligned with $refs
#   $idxs - optional array ref of line indices to score; may contain repeats
#           (bootstrap resamples). Defaults to every line of $hyp.
# Returns the MAE only; the name is kept for compatibility with callers.
# Dies if there is nothing to score.
#####
sub MAE_RMSE{
	my ($refs, $hyp, $idxs) = @_;

	#default value for $idxs: score the whole set
	unless (defined($idxs)) {
		$idxs = [0..((scalar @$hyp) - 1)];
	}

	my $count = scalar @$idxs;
	die("ERROR: MAE_RMSE called with an empty set of indices") unless $count;

	my $ESUM = 0;
	for my $lineIdx (@$idxs) {
		$ESUM += abs($hyp->[$lineIdx] - $refs->[$lineIdx]);
	}

	return $ESUM / $count;
}

#####
# Adaptation of DeltaAvg implementation by Radu Soricut - WMT QE shared task official evaluation script
# http://www.quest.dcs.shef.ac.uk/wmt13_files/evaluateWMTQP2013-Task1_1.pl
#
# DeltaAvg of a predicted ranking against reference values.
#   $refs - array ref of reference values (one per line)
#   $hyp  - array ref of predicted values, aligned with $refs
#   $idxs - optional array ref of line indices to score; repeated indices
#           (bootstrap resamples drawn with replacement) each count once per
#           occurrence. Defaults to every line of $hyp.
# Items are ordered by ascending predicted value (ties broken by their
# position in $idxs, so the result is deterministic). For each number of
# quantiles cN = 2..int(n/2), the code averages the mean reference value of
# the top-1..top-(cN-1) quantiles, subtracts the overall mean, and then
# averages over cN. Returns the absolute value (0 when n < 4).
#####
sub avgDelta{
	my ($refs, $hyp, $idxs) = @_;

	#default value for $idxs: score the whole set
	unless (defined($idxs)) {
		$idxs = [0..((scalar @$hyp) - 1)];
	}

	# Work on sample positions, not line indices: keying by line index would
	# merge duplicate draws and silently shrink the bootstrap sample.
	my @order = sort {
		$hyp->[$idxs->[$a]] <=> $hyp->[$idxs->[$b]] or $a <=> $b
	} 0 .. $#$idxs;
	my @sortedRefs = map { $refs->[$idxs->[$_]] } @order;

	my $ridx = scalar(@sortedRefs);

	# prefix[k] = sum of the first k reference values in predicted order;
	# accumulated in the same order as a direct summation would use
	my @prefix = (0);
	for my $k (0 .. $ridx - 1) {
		push @prefix, $prefix[$k] + $sortedRefs[$k];
	}

	my $maxN = int($ridx/2);
	my $AvgDelta = 0;

	for my $cN (2 .. $maxN) { # current number of quantiles
		my $q = int($ridx/$cN);
		my @refValueSort;
		my $refSum = 0;
		for my $i (1 .. $cN) {
			my $head = $i*$q;
			# include the remainder, so that the average is done across the entire input
			$head = $ridx if $i == $cN && $head < $ridx;
			$refValueSort[$i] = $prefix[$head] / $head;
			$refSum += $refValueSort[$i] if $i < $cN;
			printf STDERR "Avg. RefValues-over-quantile(s) 1..$i: %.2f\n", $refValueSort[$i] if $DEBUG>1;
		}
		my $delta = $refSum/($cN-1) - $refValueSort[$cN];
		printf STDERR "AvgDelta[$cN]: %.2f\n", $delta if $DEBUG>0;
		$AvgDelta += $delta;
	}

	$AvgDelta = ($maxN > 1) ? $AvgDelta / ($maxN-1) : 0;
	return abs($AvgDelta);
}

