# -*- coding: utf-8 -*-

"""
Convenience tools for drawing stratified random samples of integers and floats.
"""

import sys
import os
import re
import struct
import numpy as np
import _hx_core 
from _hx_core import _tcl_interp, hx_project, hx_object_factory


def random_stratified_int_sample(min_value = 0, max_value=100, nb_intervals=1, nb_items=10):
    """Draw 'nb_items' distinct random integers from the half-open range [min_value, max_value).

    The range is split into 'nb_intervals' contiguous sub-ranges (strata) of ~equal
    size; the sub-range bounds are rounded to integers with numpy.round. Samples are
    spread as evenly as possible over the strata: each stratum contributes
    nb_items // nb_intervals integers, and the first nb_items % nb_intervals strata
    contribute one extra. The strata are disjoint and each one is sampled without
    replacement, so all returned integers are unique.

    The returned array is randomly permuted so that the strata ordering does not
    show up in the output.

    Parameters
    ----------
    min_value : int
        Inclusive lower bound.
    max_value : int
        Exclusive upper bound (max_value itself is never drawn).
    nb_intervals : int
        Number of strata; must be >= 1.
    nb_items : int
        Total number of integers to draw.

    Returns
    -------
    numpy.ndarray of int, shape (nb_items,)

    Raises
    ------
    ValueError
        If nb_intervals < 1, or if a stratum contains fewer integers than the
        number of items that must be drawn from it.
    """
    if nb_intervals < 1:
        raise ValueError("nb_intervals must be >= 1, got %r" % (nb_intervals,))

    # Idea: take the list of all integers in each sub-range, randomly permute it and
    # keep only its first elements. This guarantees no duplicate items without
    # complexifying the algorithm.
    partition_bounds = np.round(np.linspace(min_value, max_value, nb_intervals + 1)).astype(int)
    nb_items_per_interval, div_remainder = divmod(nb_items, nb_intervals)

    sample = np.zeros(nb_items, dtype=int)
    pos = 0
    for k in range(nb_intervals):
        low, high = partition_bounds[k], partition_bounds[k + 1]
        # The first 'div_remainder' strata absorb the remainder, one extra item each.
        count = nb_items_per_interval + 1 if k < div_remainder else nb_items_per_interval
        available = max(high - low, 0)
        if count > available:
            raise ValueError(
                "sub-interval [%d, %d) holds only %d integers but %d items were requested from it"
                % (low, high, available, count)
            )
        indices = np.random.permutation(np.arange(low, high))
        sample[pos:pos + count] = indices[:count]
        pos += count

    return np.random.permutation(sample)
    
    
def random_stratified_flt_sample(min_value = 0, max_value=100, nb_intervals=1, nb_items=10):
    """Draw 'nb_items' uniform random floats from [min_value, max_value), stratified.

    The range is split into 'nb_intervals' contiguous sub-intervals of equal width
    (numpy.linspace). Each sub-interval [lo, hi) contributes nb_items // nb_intervals
    uniformly drawn values, and the first nb_items % nb_intervals sub-intervals
    contribute one extra value.

    The returned array is randomly permuted so that the strata ordering does not
    show up in the output.

    Parameters
    ----------
    min_value : float
        Lower bound of the sampled range.
    max_value : float
        Upper bound of the sampled range (values are drawn as lo + (hi - lo) * u
        with u in [0, 1), so hi is not targeted).
    nb_intervals : int
        Number of strata; must be >= 1.
    nb_items : int
        Total number of values to draw.

    Returns
    -------
    numpy.ndarray of float, shape (nb_items,)

    Raises
    ------
    ValueError
        If nb_intervals < 1.
    """
    if nb_intervals < 1:
        raise ValueError("nb_intervals must be >= 1, got %r" % (nb_intervals,))

    partition_bounds = np.linspace(min_value, max_value, nb_intervals + 1)
    nb_items_per_interval, div_remainder = divmod(nb_items, nb_intervals)

    sample = np.zeros(nb_items, dtype=float)
    pos = 0
    for k in range(nb_intervals):
        # Named 'lower' (not 'min') so the builtin is not shadowed.
        lower = partition_bounds[k]
        spread = partition_bounds[k + 1] - lower
        # The first 'div_remainder' strata absorb the remainder, one extra item each.
        count = nb_items_per_interval + 1 if k < div_remainder else nb_items_per_interval
        sample[pos:pos + count] = spread * np.random.random(count) + lower
        pos += count

    return np.random.permutation(sample)