import numpy as np

# These functions may be customized to specific needs, such as accommodating the
# expectations of neural networks trained outside our environment.
def get_model_information():
    """Describe the model to the end user.

    Returns
    -------
        str HTML message displayed in the Deep-Learning models information port.
    """
    # Adjacent string literals are concatenated at compile time into a single HTML string.
    return (
        "<html>This network is a back scattered electrons denoiser<br>"
        "- The model is based on a "
        '<a href="https://arxiv.org/abs/1505.04597">U-Net architecture</a><br>'
        "- It is working on grayscale image only (1 channel)</html>"
    )

# This function pre-processes the data by linearly rescaling it to the [0, 255] range (min-max normalization).
def preprocess_input(x):
    """Rescale the input linearly to [0, 255] before prediction by the Deep Learning Prediction module.

    This normalization must be the same as the one applied during the training of the model.
    The minimum of the whole array maps to 0 and its maximum to 255. This is min-max scaling
    over all slices and channels at once, not per slice.

    Parameters
    ----------
        x: Array to pre-process.
            Numpy array of shape (depth, width, height, channels), integer or floating dtype.

    Returns
    -------
        processed array.
            float32 array with the same shape as x, values in [0, 255].
            A constant (zero-contrast) array maps to all zeros; an empty array is
            returned as an empty float32 array.

    Raises
    ------
        TypeError: if x is not a numpy array.
        ValueError: if x is not 4-dimensional.
    """
    # Check the type first so non-arrays get the intended TypeError rather than an AttributeError on `.shape`.
    if not isinstance(x, np.ndarray):
        raise TypeError("'x' must be a numpy array")
    if x.ndim != 4:
        raise ValueError("incorrect 'x' shape, must be a 4d array with shape=(depth, width, height, channels)")
    if x.size == 0:
        # min()/max() are undefined on an empty array.
        return x.astype(np.float32)

    # Compute in float64: subtracting the minimum in the input's own integer dtype can overflow
    # (e.g. an int8 array spanning [-128, 127] has a range of 255, which int8 cannot represent).
    values = x.astype(np.float64)
    lowest = values.min()
    value_range = values.max() - lowest
    if value_range == 0:
        # No contrast to stretch; avoid a 0/0 division that would fill the output with NaN.
        return np.zeros(x.shape, dtype=np.float32)

    target_max = float(np.iinfo(np.uint8).max)  # 255
    return ((values - lowest) * target_max / value_range).astype(np.float32)


# The denoising model was originally trained to generate RGB outputs. For grayscale images all channels are identical, so only the first one is kept.
def postprocess_local_output(predicted_tile):
    """Reduce a predicted tile to its first channel.

    Applied to each tile produced when the input data is large. It is meant for
    pixel-wise operations only, since no global context is available here.

    Parameters
    ----------
        predicted_tile: tile of the model prediction output
            Numpy array of shape (tile_depth, tile_width, tile_height, input_channels)

    Returns
    -------
        post-processed tile
            Array of shape (tile_depth, tile_width, tile_height, 1): a slice (view) of the input keeping channel 0.

    Raises
    ------
        ValueError: if the tile is not 4-dimensional.
    """
    rank = len(predicted_tile.shape)
    if rank == 4:
        # The model emits RGB; per the original author's note, the channels are identical
        # for grayscale data, so keeping the first one (as a length-1 axis) is enough.
        return predicted_tile[..., :1]
    raise ValueError("incorrect 'predicted_tile' shape, must be a 4d array with shape=(tile_depth, tile_width, tile_height, input_channels)")
    
    

# Two-step post-processing of the output array after model prediction:
# 1/ per-slice standardization to the same mean and std as the input image
# 2/ clipping and casting to the same pixel type as the input
def postprocess_global_output(list_predicted_arrays, input_array):
    """Match each predicted slice's intensity statistics to the input, then cast to the input dtype.

    Global post-processing, run after the tiles have been re-assembled. Each slice
    along axis 0 of the prediction is shifted and scaled so its mean and standard
    deviation equal those of the corresponding input slice. The result is clipped
    to the range representable by the input dtype and cast to that dtype.

    Parameters
    ----------
        list_predicted_arrays: list of model prediction output(s), each item of the list is a channel.
            List of numpy arrays, each with shape (depth, width, height, 1). Only the first item is used.

        input_array: input data array.
            Numpy array of shape (depth, input_width, input_height, channels)

    Returns
    -------
        A list of the processed arrays, each will be exposed as a 3D Data object in the application.
            Here a single-item list holding an array shaped like the first prediction,
            with dtype input_array.dtype.
        False if the first prediction or input_array is not a numpy array.

    Raises
    ------
        ValueError: if the prediction is not 4-dimensional or its depth differs from the input's.
    """
    pred = list_predicted_arrays[0]
    if not (isinstance(pred, np.ndarray) and isinstance(input_array, np.ndarray)):
        # Original contract: signal unsupported inputs with False rather than raising.
        return False
    if pred.ndim != 4:
        raise ValueError("incorrect prediction shape, must be a 4d array with shape=(depth, width, height, 1)")
    if pred.shape[0] != input_array.shape[0]:
        raise ValueError("depth mismatch between prediction ({}) and input ({})".format(
            pred.shape[0], input_array.shape[0]))

    voxel_type = input_array.dtype.type

    target_mean, target_std = _per_slice_mean_std(input_array)
    pred_mean, pred_std = _per_slice_mean_std(pred)

    # Mean/std normalization slice by slice; the floor on pred_std avoids dividing by zero on flat slices.
    pred_norm = (pred - pred_mean) * target_std / np.maximum(pred_std, 1e-7) + target_mean

    # Clip to the range representable by the output dtype so the cast below cannot wrap around.
    # For unsigned integers this is [0, max], as before; signed inputs keep their negative values.
    value_range = _dtype_value_range(input_array.dtype)
    if value_range is not None:
        pred_norm = np.clip(pred_norm, value_range[0], value_range[1])

    # Cast and return result within a list as expected
    return [pred_norm.astype(dtype=voxel_type)]


def _per_slice_mean_std(arr):
    """Return the mean and standard deviation of each slice along axis 0, each shaped (depth, 1, 1, 1)."""
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    flat = np.reshape(arr, (arr.shape[0], int(np.prod(arr.shape[1:]))))
    mean = flat.mean(axis=1).reshape((-1, 1, 1, 1))
    std = flat.std(axis=1).reshape((-1, 1, 1, 1))
    return mean, std


def _dtype_value_range(dtype):
    """Return (min, max) representable by an integer or floating dtype, or None for any other dtype."""
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
    elif np.issubdtype(dtype, np.floating):
        info = np.finfo(dtype)
    else:
        return None
    return info.min, info.max


# Declares the application object type of each array returned by postprocess_global_output.
def postprocess_output_type():
    """Declare the object type of each output of the Deep Learning Prediction module.

    Returns
    -------
        List of strings, one per data object exposed in the application.
            Each string is either HxUniformLabelField3 or HxUniformScalarField3.
    """
    # The denoised image is a continuous-valued image, not a label map.
    scalar_field = 'HxUniformScalarField3'
    return [scalar_field]
