/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "old_ml_precomp.hpp"
#include <ctype.h>

#define MISS_VAL    FLT_MAX
#define CV_VAR_MISS    0

// Default split description: count mode, a train count of -1 and no mixing.
// CvMLData::set_train_test_split() treats a non-positive count as
// "use every sample for training".
CvTrainTestSplit::CvTrainTestSplit()
{
    mix = false;
    train_sample_part.count = -1;
    train_sample_part_mode = CV_COUNT;
}

// Split description given as an absolute number of training samples.
//   _train_sample_count - number of samples to use for training
//   _mix                - stored as the mix flag of the split
CvTrainTestSplit::CvTrainTestSplit( int _train_sample_count, bool _mix )
{
    mix = _mix;
    train_sample_part.count = _train_sample_count;
    train_sample_part_mode = CV_COUNT;
}

// Split description given as a fraction of the available samples.
//   _train_sample_portion - share of the samples to use for training
//   _mix                  - stored as the mix flag of the split
CvTrainTestSplit::CvTrainTestSplit( float _train_sample_portion, bool _mix )
{
    mix = _mix;
    train_sample_part.portion = _train_sample_portion;
    train_sample_part_mode = CV_PORTION;
}

////////////////

// Creates an empty data set: no matrices, no response column, ',' as the
// field delimiter and '?' as the missing-value marker. The random generator
// used for mixing train/test indices is the one returned by cv::theRNG().
CvMLData::CvMLData()
{
    values = missing = var_types = var_idx_mask = response_out = var_idx_out = var_types_out = 0;
    train_sample_idx = test_sample_idx = 0;
    header_lines_number = 0;
    sample_idx = 0;
    response_idx = -1;

    train_sample_count = -1;

    // These two members used to stay indeterminate until clear() or
    // set_train_test_split() happened to run; give them defined values.
    total_class_count = 0;
    mix = false;

    delimiter = ',';
    miss_ch = '?';
    //flt_separator = '.';

    rng = &cv::theRNG();
}

// Releases everything the object owns by delegating to clear().
CvMLData::~CvMLData()
{
    this->clear();
}

void CvMLData::free_train_test_idx()
{
    cvReleaseMat( &train_sample_idx );
    cvReleaseMat( &test_sample_idx );
    sample_idx = 0;
}

// Returns the object to the "no data loaded" state: all matrices are released,
// the class-label map is emptied and the bookkeeping counters are reset.
// Parsing settings (delimiter, miss_ch, header_lines_number) are left alone.
void CvMLData::clear()
{
    // bookkeeping
    train_sample_count = -1;
    response_idx = -1;
    total_class_count = 0;
    class_map.clear();

    // train / test index headers
    free_train_test_idx();

    // cached output matrices
    cvReleaseMat( &var_types_out );
    cvReleaseMat( &var_idx_out );
    cvReleaseMat( &response_out );

    // loaded data
    cvReleaseMat( &var_idx_mask );
    cvReleaseMat( &var_types );
    cvReleaseMat( &missing );
    cvReleaseMat( &values );
}


// Sets how many leading lines read_csv() skips; negative values are clamped to 0.
void CvMLData::set_header_lines_number( int idx )
{
    header_lines_number = idx > 0 ? idx : 0;
}

// Number of leading lines that read_csv() skips before the first data line.
int CvMLData::get_header_lines_number() const
{
    const int skipped_lines = header_lines_number;
    return skipped_lines;
}

// fgets() wrapper that strips every trailing '\r' and '\n' from the line read.
//   str    - destination buffer
//   n      - size of the destination buffer, as for fgets()
//   stream - stream to read from
// Returns str on success, or NULL when fgets() returns NULL.
//
// The trimming walks a length counter instead of a pointer, so an empty
// string never produces a pointer in front of the buffer.
static char *fgets_chomp(char *str, int n, FILE *stream)
{
    char *line = fgets(str, n, stream);
    if( !line )
        return NULL;

    size_t len = strlen(line);
    while( len > 0 && (line[len - 1] == '\r' || line[len - 1] == '\n') )
        line[--len] = '\0';

    return line;
}


// Loads a delimited text file into the object, replacing any previous content.
//
// The first header_lines_number lines are skipped. The number of columns is
// taken from the first data line; both ' ' and the configured delimiter
// separate fields. Every field is converted with str_to_flt_elem() and the
// type it reports is OR-ed into var_types for that column.
//
// On success the function fills values, var_types and var_idx_mask (all
// variables enabled), sets train_sample_count to the number of rows and keeps
// the missing mask only if at least one value is missing.
//
// Returns 0 on success and -1 when the file cannot be opened, ends before the
// first data line, or contains a row with fewer fields than the first one.
int CvMLData::read_csv(const char* filename)
{
    const int M = 1000000; // size of the line buffer
    const char str_delimiter[3] = { ' ', delimiter, '\0' };
    FILE* file = 0;
    CvMemStorage* storage = 0;
    CvSeq* seq = 0;
    char *ptr;
    CvSeqReader reader;
    int cols_count = 0;
    uchar *var_types_ptr = 0;

    clear();

    if( !filename )
        return -1;

    file = fopen( filename, "rt" );

    if( !file )
        return -1;

    std::vector<char> _buf(M);
    char* buf = &_buf[0];

    // skip header lines
    for( int i = 0; i < header_lines_number; i++ )
    {
        if( fgets( buf, M, file ) == 0 )
        {
            fclose(file);
            return -1;
        }
    }

    // read the first data line and determine the number of variables
    if( !fgets_chomp( buf, M, file ))
    {
        fclose(file);
        return -1;
    }

    ptr = buf;
    while( *ptr == ' ' )
        ptr++;
    for( ; *ptr != '\0'; )
    {
        if(*ptr == delimiter || *ptr == ' ')
        {
            cols_count++;
            ptr++;
            while( *ptr == ' ' ) ptr++;
        }
        else
            ptr++;
    }

    // separators + 1 fields, so cols_count is always at least 1 from here on
    cols_count++;

    // One parsed row; owned by the vector so every exit path releases it.
    std::vector<float> row( cols_count );
    float* el_ptr = &row[0];

    // create temporary memory storage to store the whole database
    storage = cvCreateMemStorage();
    seq = cvCreateSeq( 0, sizeof(*seq), cols_count*sizeof(float), storage );

    var_types = cvCreateMat( 1, cols_count, CV_8U );
    cvZero( var_types );
    var_types_ptr = var_types->data.ptr;

    for(;;)
    {
        char *token = NULL;
        int type;
        token = strtok(buf, str_delimiter);
        if (!token)
            break;
        for (int i = 0; i < cols_count-1; i++)
        {
            str_to_flt_elem( token, el_ptr[i], type);
            var_types_ptr[i] |= type;
            token = strtok(NULL, str_delimiter);
            if (!token)
            {
                // Row is shorter than the first one: give up, but do not
                // leak the temporary storage holding the rows read so far.
                fclose(file);
                cvReleaseMemStorage( &storage );
                return -1;
            }
        }
        str_to_flt_elem( token, el_ptr[cols_count-1], type);
        var_types_ptr[cols_count-1] |= type;
        cvSeqPush( seq, el_ptr );
        if( !fgets_chomp( buf, M, file ) )
            break;
    }
    fclose(file);

    values = cvCreateMat( seq->total, cols_count, CV_32FC1 );
    missing = cvCreateMat( seq->total, cols_count, CV_8U );
    var_idx_mask = cvCreateMat( 1, values->cols, CV_8UC1 );
    cvSet( var_idx_mask, cvRealScalar(1) );
    train_sample_count = seq->total;

    // copy the rows out of the sequence and build the missing-value mask
    cvStartReadSeq( seq, &reader );
    for(int i = 0; i < seq->total; i++ )
    {
        const float* sdata = (float*)reader.ptr;
        float* ddata = values->data.fl + cols_count*i;
        uchar* dm = missing->data.ptr + cols_count*i;

        for( int j = 0; j < cols_count; j++ )
        {
            ddata[j] = sdata[j];
            dm[j] = ( fabs( MISS_VAL - sdata[j] ) <= FLT_EPSILON );
        }
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    // no missing values at all: drop the mask
    if ( cvNorm( missing, 0, CV_L1 ) <= FLT_EPSILON )
        cvReleaseMat( &missing );

    cvReleaseMemStorage( &storage );
    return 0;
}

// Returns the matrix of loaded values (0 when nothing has been loaded).
const CvMat* CvMLData::get_values() const
{
    const CvMat* loaded = values;
    return loaded;
}

// Returns the missing-value mask. Raises CV_StsInternal when no data is loaded.
// The result is 0 when read_csv() found no missing values and dropped the mask.
const CvMat* CvMLData::get_missing() const
{
    CV_FUNCNAME( "CvMLData::get_missing" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    __END__;

    return missing;
}

// Read-only access to the map from class-label text to the numeric id
// assigned to it by str_to_flt_elem().
const std::map<cv::String, int>& CvMLData::get_class_labels_map() const
{
    const std::map<cv::String, int>& labels = class_map;
    return labels;
}

// Converts one text field into a float and reports what kind of value it was.
//   token    - field text
//   flt_elem - receives the numeric value
//   type     - receives CV_VAR_MISS, CV_VAR_ORDERED or CV_VAR_CATEGORICAL
//
// The field is parsed with strtod(). Depending on what is left unparsed:
//   * exactly the missing-value character -> MISS_VAL / CV_VAR_MISS
//   * nothing, "\n..." or "\r\n"          -> parsed number / CV_VAR_ORDERED
//   * anything else                       -> the whole token is a class label;
//     it gets the next free id (starting at 1) the first time it is seen.
void CvMLData::str_to_flt_elem( const char* token, float& flt_elem, int& type)
{
    char* rest = NULL;
    const double parsed = strtod( token, &rest );
    flt_elem = (float)parsed;
    assert( rest );

    const bool is_missing = ( *rest == miss_ch ) && ( strlen(rest) == 1 );
    if ( is_missing )
    {
        flt_elem = MISS_VAL;
        type = CV_VAR_MISS;
        return;
    }

    const bool is_number = ( *rest == 0 ) || ( *rest == '\n' ) || ( strcmp(rest, "\r\n") == 0 );
    if ( is_number )
    {
        type = CV_VAR_ORDERED;
        return;
    }

    // class label: operator[] creates the entry with id 0 on first sight
    int& label_id = class_map[token];
    if ( label_id == 0 )
        label_id = ++total_class_count;

    flt_elem = (float)label_id;
    type = CV_VAR_CATEGORICAL;
}

// Sets the field delimiter used by read_csv().
// Raises CV_StsBadArg when ch equals the missing-value character.
void CvMLData::set_delimiter(char ch)
{
    CV_FUNCNAME( "CvMLData::set_delimited" );
    __BEGIN__;

    const bool clashes_with_miss_ch = ( miss_ch == ch /*|| ch == flt_separator*/ );
    if ( clashes_with_miss_ch )
    {
        CV_ERROR(CV_StsBadArg, "delimited, miss_character and flt_separator must be different");
    }

    delimiter = ch;

    __END__;
}

// Returns the field delimiter used by read_csv().
char CvMLData::get_delimiter() const
{
    const char current = delimiter;
    return current;
}

// Sets the character that marks a missing value in the input file.
// Raises CV_StsBadArg when ch equals the field delimiter.
void CvMLData::set_miss_ch(char ch)
{
    CV_FUNCNAME( "CvMLData::set_miss_ch" );
    __BEGIN__;

    const bool clashes_with_delimiter = ( delimiter == ch /* || ch == flt_separator*/ );
    if ( clashes_with_delimiter )
    {
        CV_ERROR(CV_StsBadArg, "delimited, miss_character and flt_separator must be different");
    }

    miss_ch = ch;

    __END__;
}

// Returns the character that marks a missing value in the input file.
char CvMLData::get_miss_ch() const
{
    const char current = miss_ch;
    return current;
}

// Selects the column that holds the responses.
//   idx - column index; a negative value means "no response column"
// The previous response column (if any) is switched back on in the variable
// mask and the new one is switched off. Raises CV_StsInternal when no data is
// loaded and CV_StsBadArg when idx is not below the number of columns.
void CvMLData::set_response_idx( int idx )
{
    CV_FUNCNAME( "CvMLData::set_response_idx" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    if ( values->cols <= idx )
    {
        CV_ERROR( CV_StsBadArg, "idx value is not correct" );
    }

    // old response column becomes an ordinary variable again
    if ( response_idx >= 0 )
    {
        chahge_var_idx( response_idx, true );
    }

    // new response column is excluded from the variables
    if ( idx >= 0 )
    {
        chahge_var_idx( idx, false );
    }

    response_idx = idx;

    __END__;
}

// Returns the index of the response column (-1 when none is set).
// Raises CV_StsInternal when no data is loaded.
int CvMLData::get_response_idx() const
{
    CV_FUNCNAME( "CvMLData::get_response_idx" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    __END__;

    return response_idx;
}

void CvMLData::change_var_type( int var_idx, int type )
{
    CV_FUNCNAME( "CvMLData::change_var_type" );
    __BEGIN__;

    int var_count = 0;

    if ( !values )
        CV_ERROR( CV_StsInternal, "data is empty" );

     var_count = values->cols;

    if ( var_idx < 0 || var_idx >= var_count)
        CV_ERROR( CV_StsBadArg, "var_idx is not correct" );

    if ( type != CV_VAR_ORDERED && type != CV_VAR_CATEGORICAL)
         CV_ERROR( CV_StsBadArg, "type is not correct" );

    assert( var_types );
    if ( var_types->data.ptr[var_idx] == CV_VAR_CATEGORICAL && type == CV_VAR_ORDERED)
        CV_ERROR( CV_StsBadArg, "it`s impossible to assign CV_VAR_ORDERED type to categorical variable" );
    var_types->data.ptr[var_idx] = (uchar)type;

    __END__;

    return;
}

// Sets the types of all variables from a description string.
//
// Accepted forms, as implemented by the parser below:
//   "ord"                       - every variable becomes CV_VAR_ORDERED
//   "cat"                       - every variable becomes CV_VAR_CATEGORICAL
//   "ord[<list>]", "cat[<list>]" or both in one string, where <list> is a
//   ','-separated sequence of indices and 'first-last' ranges ending with ']'.
//
// The number of variables named in the lists must equal the number of
// columns. A variable that is currently categorical cannot be made ordered.
//
// Raises CV_StsInternal when no data is loaded and CV_StsBadArg for a
// malformed string, an index outside [0, number of columns), a reversed
// range, or a forbidden categorical -> ordered change. Types assigned before
// an error is detected are not rolled back.
void CvMLData::set_var_types( const char* str )
{
    CV_FUNCNAME( "CvMLData::set_var_types" );
    __BEGIN__;

    const char* ord = 0, *cat = 0;
    int var_count = 0, set_var_type_count = 0;
    if ( !values )
        CV_ERROR( CV_StsInternal, "data is empty" );

    if ( !str )
        CV_ERROR( CV_StsBadArg, "types string is not correct" );

    var_count = values->cols;

    assert( var_types );

    ord = strstr( str, "ord" );
    cat = strstr( str, "cat" );
    if ( !ord && !cat )
        CV_ERROR( CV_StsBadArg, "types string is not correct" );

    if ( !ord && strlen(cat) == 3 ) // str == "cat"
    {
        cvSet( var_types, cvScalarAll(CV_VAR_CATEGORICAL) );
        return;
    }

    if ( !cat && strlen(ord) == 3 ) // str == "ord"
    {
        cvSet( var_types, cvScalarAll(CV_VAR_ORDERED) );
        return;
    }

    if ( ord ) // parse ord str
    {
        char* stopstring = NULL;
        if ( ord[3] != '[')
            CV_ERROR( CV_StsBadArg, "types string is not correct" );

        ord += 4; // pass "ord["
        do
        {
            int b1 = (int)strtod( ord, &stopstring );
            if ( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                CV_ERROR( CV_StsBadArg, "types string is not correct" );
            // the index is used to address var_types below, so it must be a real column
            if ( b1 < 0 || b1 >= var_count )
                CV_ERROR( CV_StsBadArg, "variable index in types string is out of range" );
            ord = stopstring + 1;
            if ( (stopstring[0] == ',') || (stopstring[0] == ']'))
            {
                if ( var_types->data.ptr[b1] == CV_VAR_CATEGORICAL)
                    CV_ERROR( CV_StsBadArg, "it`s impossible to assign CV_VAR_ORDERED type to categorical variable" );
                var_types->data.ptr[b1] = CV_VAR_ORDERED;
                set_var_type_count++;
            }
            else
            {
                if ( stopstring[0] == '-')
                {
                    int b2 = (int)strtod( ord, &stopstring);
                    if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                        CV_ERROR( CV_StsBadArg, "types string is not correct" );
                    // range end must be a real column and not precede the start
                    if ( b2 < b1 || b2 >= var_count )
                        CV_ERROR( CV_StsBadArg, "variable index in types string is out of range" );
                    ord = stopstring + 1;
                    for (int i = b1; i <= b2; i++)
                    {
                        if ( var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
                            CV_ERROR( CV_StsBadArg, "it`s impossible to assign CV_VAR_ORDERED type to categorical variable" );
                        var_types->data.ptr[i] = CV_VAR_ORDERED;
                    }
                    set_var_type_count += b2 - b1 + 1;
                }
                else
                    CV_ERROR( CV_StsBadArg, "types string is not correct" );

            }
        }
        while (*stopstring != ']');

        if ( stopstring[1] != '\0' && stopstring[1] != ',')
            CV_ERROR( CV_StsBadArg, "types string is not correct" );
    }

    if ( cat ) // parse cat str
    {
        char* stopstring = NULL;
        if ( cat[3] != '[')
            CV_ERROR( CV_StsBadArg, "types string is not correct" );

        cat += 4; // pass "cat["
        do
        {
            int b1 = (int)strtod( cat, &stopstring );
            if ( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                CV_ERROR( CV_StsBadArg, "types string is not correct" );
            // the index is used to address var_types below, so it must be a real column
            if ( b1 < 0 || b1 >= var_count )
                CV_ERROR( CV_StsBadArg, "variable index in types string is out of range" );
            cat = stopstring + 1;
            if ( (stopstring[0] == ',') || (stopstring[0] == ']'))
            {
                var_types->data.ptr[b1] = CV_VAR_CATEGORICAL;
                set_var_type_count++;
            }
            else
            {
                if ( stopstring[0] == '-')
                {
                    int b2 = (int)strtod( cat, &stopstring);
                    if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                        CV_ERROR( CV_StsBadArg, "types string is not correct" );
                    // range end must be a real column and not precede the start
                    if ( b2 < b1 || b2 >= var_count )
                        CV_ERROR( CV_StsBadArg, "variable index in types string is out of range" );
                    cat = stopstring + 1;
                    for (int i = b1; i <= b2; i++)
                        var_types->data.ptr[i] = CV_VAR_CATEGORICAL;
                    set_var_type_count += b2 - b1 + 1;
                }
                else
                    CV_ERROR( CV_StsBadArg, "types string is not correct" );

            }
        }
        while (*stopstring != ']');

        if ( stopstring[1] != '\0' && stopstring[1] != ',')
            CV_ERROR( CV_StsBadArg, "types string is not correct" );
    }

    // every variable has to be mentioned exactly once across both lists
    if (set_var_type_count != var_count)
        CV_ERROR( CV_StsBadArg, "types string is not correct" );

     __END__;
}

// Returns the types of the active variables followed by the response type.
//
// When every column is active, or every column but a response stored in the
// last column, the internal var_types matrix is returned directly. Otherwise
// a cached matrix is (re)built: types of the enabled non-response columns in
// column order, then the response type if a response column is set.
// Raises CV_StsInternal when no data is loaded.
const CvMat* CvMLData::get_var_types()
{
    CV_FUNCNAME( "CvMLData::get_var_types" );
    __BEGIN__;

    uchar *dst = 0;
    int avcount, vt_size;
    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    assert( var_idx_mask );

    // the mask holds 0/1 flags, so its L1 norm is the number of active variables
    avcount = cvFloor( cvNorm( var_idx_mask, 0, CV_L1 ) );
    vt_size = response_idx >= 0 ? avcount + 1 : avcount;

    if ( avcount == values->cols )
        return var_types;
    if ( avcount == values->cols-1 && response_idx == values->cols-1 )
        return var_types;

    // reuse the cached matrix only if it already has the right width
    if ( var_types_out && var_types_out->cols != vt_size )
        cvReleaseMat( &var_types_out );
    if ( !var_types_out )
        var_types_out = cvCreateMat( 1, vt_size, CV_8UC1 );

    dst = var_types_out->data.ptr;
    for( int col = 0; col < var_types->cols; col++)
    {
        if ( col != response_idx && var_idx_mask->data.ptr[col] )
            *dst++ = var_types->data.ptr[col];
    }
    if ( response_idx >= 0 )
        *dst = var_types->data.ptr[response_idx];

    __END__;

    return var_types_out;
}

int CvMLData::get_var_type( int var_idx ) const
{
    return var_types->data.ptr[var_idx];
}

// Returns a header describing the response column of values, or 0 when no
// valid response column is set. The header is created on first use and
// re-initialised on later calls before cvGetCol() points it at the column.
// Raises CV_StsInternal when no data is loaded.
const CvMat* CvMLData::get_responses()
{
    CV_FUNCNAME( "CvMLData::get_responses_ptr" );
    __BEGIN__;

    int var_count = 0;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }
    var_count = values->cols;

    if ( !( response_idx >= 0 && response_idx < var_count ) )
        return 0;

    if ( response_out )
        cvInitMatHeader( response_out, values->rows, 1, CV_32FC1);
    else
        response_out = cvCreateMatHeader( values->rows, 1, CV_32FC1 );

    cvGetCol( values, response_out, response_idx );

    __END__;

    return response_out;
}

void CvMLData::set_train_test_split( const CvTrainTestSplit * spl)
{
    CV_FUNCNAME( "CvMLData::set_division" );
    __BEGIN__;

    int sample_count = 0;

    if ( !values )
        CV_ERROR( CV_StsInternal, "data is empty" );

    sample_count = values->rows;

    float train_sample_portion;

    if (spl->train_sample_part_mode == CV_COUNT)
    {
        train_sample_count = spl->train_sample_part.count;
        if (train_sample_count > sample_count)
            CV_ERROR( CV_StsBadArg, "train samples count is not correct" );
        train_sample_count = train_sample_count<=0 ? sample_count : train_sample_count;
    }
    else // dtype.train_sample_part_mode == CV_PORTION
    {
        train_sample_portion = spl->train_sample_part.portion;
        if ( train_sample_portion > 1)
            CV_ERROR( CV_StsBadArg, "train samples count is not correct" );
        train_sample_portion = train_sample_portion <= FLT_EPSILON ||
            1 - train_sample_portion <= FLT_EPSILON ? 1 : train_sample_portion;
        train_sample_count = std::max(1, cvFloor( train_sample_portion * sample_count ));
    }

    if ( train_sample_count == sample_count )
    {
        free_train_test_idx();
        return;
    }

    if ( train_sample_idx && train_sample_idx->cols != train_sample_count )
        free_train_test_idx();

    if ( !sample_idx)
    {
        int test_sample_count = sample_count- train_sample_count;
        sample_idx = (int*)cvAlloc( sample_count * sizeof(sample_idx[0]) );
        for (int i = 0; i < sample_count; i++ )
            sample_idx[i] = i;
        train_sample_idx = cvCreateMatHeader( 1, train_sample_count, CV_32SC1 );
        *train_sample_idx = cvMat( 1, train_sample_count, CV_32SC1, &sample_idx[0] );

        CV_Assert(test_sample_count > 0);
        test_sample_idx = cvCreateMatHeader( 1, test_sample_count, CV_32SC1 );
        *test_sample_idx = cvMat( 1, test_sample_count, CV_32SC1, &sample_idx[train_sample_count] );
    }

    mix = spl->mix;
    if ( mix )
        mix_train_and_test_idx();

    __END__;
}

// Returns the header over the training sample indices; 0 when no split is
// active. Raises CV_StsInternal when no data is loaded.
const CvMat* CvMLData::get_train_sample_idx() const
{
    CV_FUNCNAME( "CvMLData::get_train_sample_idx" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    __END__;

    return train_sample_idx;
}

// Returns the header over the test sample indices; 0 when no split is
// active. Raises CV_StsInternal when no data is loaded.
const CvMat* CvMLData::get_test_sample_idx() const
{
    CV_FUNCNAME( "CvMLData::get_test_sample_idx" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    __END__;

    return test_sample_idx;
}

// Shuffles the shared index buffer so that samples move between the training
// and test parts. Does nothing when there is no index buffer or when the
// training part is empty or covers all rows. The shuffle performs one swap of
// two randomly chosen positions per row, using the object's generator.
// Raises CV_StsInternal when no data is loaded.
void CvMLData::mix_train_and_test_idx()
{
    CV_FUNCNAME( "CvMLData::mix_train_and_test_idx" );
    __BEGIN__;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    __END__;

    if ( sample_idx == 0 )
        return;

    const int n = values->rows;
    if ( train_sample_count <= 0 || train_sample_count >= n )
        return;

    for ( int pass = 0; pass < n; ++pass )
    {
        // draw order matters for reproducibility: first position, then second
        const int first = (*rng)(n);
        const int second = (*rng)(n);

        const int held = sample_idx[first];
        sample_idx[first] = sample_idx[second];
        sample_idx[second] = held;
    }
}

// Returns the indices of the active variables as a 1 x N CV_32SC1 matrix,
// or 0 when every column is active. The matrix is cached and re-created only
// when the number of active variables changes.
// Raises CV_StsInternal when no data is loaded.
//
// NOTE(review): the response column is switched off in the mask only inside
// the re-creation branch and only after the active count has been taken; if
// the response column was enabled at that point the last output element would
// stay unwritten. Looks unintended - confirm before relying on it.
const CvMat* CvMLData::get_var_idx()
{
     CV_FUNCNAME( "CvMLData::get_var_idx" );
    __BEGIN__;

    int avcount = 0;
    int* out = 0;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    assert( var_idx_mask );

    // the mask holds 0/1 flags, so its L1 norm is the number of active variables
    avcount = cvFloor( cvNorm( var_idx_mask, 0, CV_L1 ) );

    if ( values->cols == avcount )
        return 0;

    if ( var_idx_out == 0 || var_idx_out->cols != avcount )
    {
        cvReleaseMat( &var_idx_out );
        var_idx_out = cvCreateMat( 1, avcount, CV_32SC1);
        if ( response_idx >= 0 )
            var_idx_mask->data.ptr[response_idx] = 0;
    }

    out = var_idx_out->data.i;

    for ( int col = 0; col < var_idx_mask->cols; col++ )
    {
        if ( var_idx_mask->data.ptr[col] )
            *out++ = col;
    }

    __END__;

    return var_idx_out;
}

void CvMLData::chahge_var_idx( int vi, bool state )
{
    change_var_idx( vi, state );
}

// Enables or disables one variable in the active-variable mask.
//   vi    - column index, must be in [0, number of columns)
//   state - true to enable the variable, false to disable it
// Raises CV_StsInternal when no data is loaded and CV_StsBadArg for a bad index.
void CvMLData::change_var_idx( int vi, bool state )
{
     CV_FUNCNAME( "CvMLData::change_var_idx" );
    __BEGIN__;

    int var_count = 0;

    if ( values == 0 )
    {
        CV_ERROR( CV_StsInternal, "data is empty" );
    }

    var_count = values->cols;

    if ( !( vi >= 0 && vi < var_count ) )
    {
        CV_ERROR( CV_StsBadArg, "variable index is not correct" );
    }

    assert( var_idx_mask );
    *( var_idx_mask->data.ptr + vi ) = state;

    __END__;
}

/* End of file. */