Find the median of N^2 numbers with memory for only N of them
I was trying to learn about distributed computing and came across this problem of finding the median of a large set of numbers:
Assume that we have a large set of numbers (let's say the number of elements is N*K) that cannot fit into memory (of size N). How do we find the median of this data? Assume that the operations performed on the memory are independent, i.e. we can consider that there are K machines, each of which can process at most N elements.
I thought that median of medians can be used for this purpose. We can load N numbers at a time into memory, find the median of that set in O(N) time with an in-memory selection algorithm, and save it. Then we save all these K medians and find the median of those medians, another O(K), so the total work so far is O(N*K).
But this median of medians is just an approximate median. I think it would be best to use it as a pivot to get good performance, but for that we would need to fit all N*K numbers in memory.
How can we find the actual median of the set now that we have a good approximate pivot?
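For what it's worth, here is a minimal sketch of the chunked pass described above. The chunk container and the use of std::nth_element for the in-memory selection are my own assumptions; in practice each chunk would be read from disk or the network rather than already being in a vector of vectors.

#include <algorithm>
#include <vector>

//Sketch: compute an approximate median by taking the median of each N-sized
//chunk and then the median of those K chunk medians. Assumes the data arrives
//chunk by chunk (e.g. read from disk or received over the network).
double approximate_median (const std::vector<std::vector<double> >& chunks)
{
    std::vector<double> medians;
    for (std::vector<double> chunk : chunks) {   //copy: nth_element reorders the chunk
        std::nth_element (chunk.begin(), chunk.begin() + chunk.size()/2, chunk.end());
        medians.push_back (chunk[chunk.size()/2]);
    }
    std::nth_element (medians.begin(), medians.begin() + medians.size()/2, medians.end());
    return medians[medians.size()/2];            //approximate pivot only, not the exact median
}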
Why don't you build a histogram? That is, count the number of values that fall into each of several categories, where the categories are consecutive, non-overlapping intervals of the variable.
With this histogram you can make a first estimation of the median (i.e. the median lies in some interval [a,b]) and know how many values fall into this interval (H). If H<=N, read the numbers again, ignoring those outside the interval and moving the numbers within it to RAM, and find the median there.
If H>N, partition the interval again and repeat the procedure. It shouldn't take more than 2 or 3 iterations.
Note that for each partition you only need to store a, b, the subinterval width Delta, and the array with the number of values that fall into each subinterval.
EDIT: It turned out to be a bit more complicated than I expected. In each iteration, after estimating the interval the median falls into, we also have to keep track of how much of the histogram we leave to the right and to the left of that interval. I changed the stop condition too. Anyway, here is a C++ implementation.
#include <iostream>
#include <algorithm>
#include <cfloat>    //for DBL_MIN
#include <time.h>
#include <stdlib.h>

//This is N^2... or just the number of values in your array;
//note that we never modify it except at the end (just for sorting
//and testing purposes).
#define N2 1000000
//Number of intervals in the histogram. Must be >2.
#define HISTN 1000

double findmedian (double *values, double min, double max);
int getindex (int *hist);
void put (int *hist, double min, double max, double val, double delta);

int main ()
{
    //Set max and min to the max/min values your array variables can hold,
    //calculate them, or maybe we know that they are bounded.
    double max = 1000.0;
    double min = 0.0;
    double delta;
    //Static storage: an array this large would likely overflow the stack.
    static double values[N2];
    int hist[HISTN];
    int ind;
    double median;
    int iter = 0;

    //Initialize with random values.
    srand ((unsigned) time(0));
    for (int i = 0; i < N2; ++i)
        values[i] = (double)rand() / (double)RAND_MAX;

    double imin = min;
    double imax = max;
    clock_t begin = clock();
    while (1) {
        iter++;
        //Build the histogram of the current interval [imin, imax].
        for (int i = 0; i < HISTN; ++i)
            hist[i] = 0;
        delta = (imax - imin) / HISTN;
        for (int j = 0; j < N2; ++j)
            put (hist, imin, imax, values[j], delta);
        //Pick the subinterval that contains the median and zoom into it.
        ind = getindex (hist);
        imax = imin;
        imin = imin + delta * ind;
        imax = imax + delta * (ind + 1);
        if (hist[ind] == 1 || imax - imin <= DBL_MIN) {
            median = findmedian (values, imin, imax);
            break;
        }
    }
    clock_t end = clock();
    std::cout << "Median with our algorithm: " << median << " - " << iter
              << " iterations of the algorithm" << std::endl;
    double time = (double)(end - begin) / CLOCKS_PER_SEC;
    std::cout << "Time: " << time << std::endl;

    //Let's compare our result with the median calculated after sorting
    //the array.
    //Should be values[(int)N2/2] if N2 is odd.
    begin = clock();
    std::sort (values, values + N2);
    std::cout << "Median after sorting: " << values[(int)N2/2 - 1] << std::endl;
    end = clock();
    time = (double)(end - begin) / CLOCKS_PER_SEC;
    std::cout << "Time: " << time << std::endl;
    return 0;
}

//Return any value that falls inside the final (tiny) interval [min, max].
double findmedian (double *values, double min, double max)
{
    for (int i = 0; i < N2; ++i)
        if (values[i] >= min && values[i] <= max)
            return values[i];
    return 0;
}

//Find the histogram bin that contains the median; pd keeps track of how much
//of the histogram has been left outside the current interval in previous
//iterations.
int getindex (int *hist)
{
    static int pd = 0;
    int left = 0;
    int right = 0;
    int i;
    for (int k = 0; k < HISTN; k++)
        right += hist[k];
    for (i = 0; i < HISTN; i++) {
        right -= hist[i];
        if (i > 0)
            left += hist[i-1];
        if (hist[i] > 0) {
            if (pd + right - left <= hist[i]) {
                pd = pd + right - left;
                break;
            }
        }
    }
    return i;
}

//Add val to the histogram if it falls inside [min, max].
void put (int *hist, double min, double max, double val, double delta)
{
    int pos;
    if (val < min || val > max)
        return;
    pos = (val - min) / delta;
    if (pos >= HISTN)    //guard: val == max would otherwise index out of bounds
        pos = HISTN - 1;
    hist[pos]++;
}
I also included a naive calculation of the median (sorting the whole array) in order to compare with the result of the algorithm. 4 or 5 iterations are enough, which means we only need to read the set from the network or HDD 4-5 times.
Some results:
N2=10000
HISTN=100
Median with our algorithm: 0.497143 - 4 iterations of the algorithm
Time: 0.000787
Median after sorting: 0.497143
Time: 0.001626
(Algorithm is 2 times faster)
N2=1000000
HISTN=1000
Median with our algorithm: 0.500665 - 4 iterations of the algorithm
Time: 0.028874
Median after sorting: 0.500665
Time: 0.097498
(Algorithm is ~3 times faster)
If you want to parallelize the algorithm, each machine can hold N elements and calculate a histogram over them. Once it is calculated, they send it to the master machine, which sums all the histograms (easy, and they can be really small... the algorithm even works with histograms of only 2 intervals). Then it sends new instructions (i.e. the new interval) to the slave machines so they can compute new histograms. Note that no machine needs any knowledge of the N elements the other machines own.
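As a rough illustration of the merge step on the master (the container layout and the function name are just for this sketch, not part of the code above):

#include <cstddef>
#include <vector>

//Master side: sum the K per-machine histograms bin by bin. The summed
//histogram is then used to pick the next interval exactly as getindex()
//does in the single-machine code, and the new [imin, imax] is broadcast
//back to the slaves.
std::vector<int> merge_histograms (const std::vector<std::vector<int> >& per_machine)
{
    std::vector<int> total (per_machine.front().size(), 0);
    for (std::size_t m = 0; m < per_machine.size(); ++m)
        for (std::size_t i = 0; i < per_machine[m].size(); ++i)
            total[i] += per_machine[m][i];
    return total;
}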
Take a random sample of N of them. With constant probability dependent on c, the median of this random sample is within c*N places of the median. If you do this twice, then, with constant probability, you've narrowed the possible positions of the median down to linearly many. Do whatever horrible thing you like to select the element of the appropriate rank.
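A possible sketch of this idea in C++; the bracket width slack, the choice of two sample order statistics as the bracket, and the exact-selection fallback are assumptions of the sketch, not part of the answer (it assumes slack < sample_size/2, and in a real setting the two passes over data would be streamed from disk):

#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

//Bracket the median between two order statistics of a random sample, then
//keep only the in-bracket elements (plus a count of everything below the
//bracket) and select the exact median among them.
double sample_and_select (const std::vector<double>& data,
                          std::size_t sample_size, std::size_t slack)
{
    std::mt19937 gen (std::random_device{}());
    std::uniform_int_distribution<std::size_t> pick (0, data.size() - 1);

    std::vector<double> sample (sample_size);
    for (std::size_t i = 0; i < sample_size; ++i)   //pass 1: random sample
        sample[i] = data[pick (gen)];
    std::sort (sample.begin(), sample.end());
    double lo = sample[sample_size/2 - slack];      //lower end of the bracket
    double hi = sample[sample_size/2 + slack];      //upper end of the bracket

    std::size_t below = 0;
    std::vector<double> kept;
    for (std::size_t i = 0; i < data.size(); ++i) { //pass 2: filter around the bracket
        if (data[i] < lo) ++below;
        else if (data[i] <= hi) kept.push_back (data[i]);
    }
    std::size_t target = data.size() / 2;           //rank of the (upper) median
    if (target < below || target >= below + kept.size()) {
        //Bracket missed the median (low probability): fall back to exact selection.
        std::vector<double> all (data);
        std::nth_element (all.begin(), all.begin() + target, all.end());
        return all[target];
    }
    std::nth_element (kept.begin(), kept.begin() + (target - below), kept.end());
    return kept[target - below];
}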
If you assume that your numbers are B-bit binary integers (floating point is fine too, because you can sort based on the sign, then on the exponent, then on the mantissa), then you can solve the problem in O(N^2 * B / K) time if you have K processors and N^2 numbers. You basically do a binary search: start with a pivot equal to the middle of the range, and use your K processors to count how many numbers are less than, equal to, and greater than the pivot. Then you know whether the median is equal to the pivot, greater than it, or less than it. Continue the binary search. Each binary search step takes O(N^2 / K) time to go through the list of numbers, giving O(N^2 * B / K) overall running time.
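A sequential sketch of this value-range binary search, assuming unsigned 64-bit integers; the parallel version would simply split the counting loop over the K processors and add up the per-processor counters on a master:

#include <cstdint>
#include <vector>

//Binary search on the value range: each step scans the data once, counting
//how many elements are below and equal to the pivot (the rest are above),
//and halves the candidate value range. O(B) passes over the data in total.
std::uint64_t median_by_value_search (const std::vector<std::uint64_t>& data)
{
    std::uint64_t lo = 0, hi = UINT64_MAX;   //current candidate value range
    std::uint64_t target = data.size() / 2;  //rank of the (upper) median
    while (lo < hi) {
        std::uint64_t pivot = lo + (hi - lo) / 2;
        std::uint64_t less = 0, equal = 0;   //this loop is what the K processors would split
        for (std::uint64_t v : data) {
            if (v < pivot) ++less;
            else if (v == pivot) ++equal;
        }
        if (target < less)              hi = pivot - 1;  //median is below the pivot
        else if (target < less + equal) return pivot;    //median equals the pivot
        else                            lo = pivot + 1;  //median is above the pivot
    }
    return lo;
}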