/*
 * DataScience/clustering/VarMin.java
 *
 * 211 lines
 * 6.1 KiB
 * Java
 */

import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
/**
 * Iterative clustering that minimises within-cluster spread (Lloyd-style k-means).
 *
 * <p>Each point is assigned to its nearest centroid (Euclidean distance), and each
 * centroid is then recomputed as the mean of its assigned points. This repeats until
 * no assignment changes.
 *
 * <p>Usage: construct with the number of clusters, call {@link #setData}, then
 * {@link #clustering()}, then read the result via {@link #getCluster()}.
 */
public class VarMin {
    /** Input points; each inner list holds one point's coordinates. */
    public ArrayList<ArrayList<Double>> data;
    /** Current centroid of each cluster, indexed by cluster id. */
    public ArrayList<ArrayList<Double>> centroids;
    /** Cluster id assigned to each point; parallel to {@link #data}. */
    public ArrayList<Integer> clusterMap;
    /** Number of points currently assigned to each cluster. */
    int[] count;
    /** Number of clusters. */
    int k;
    /** Dimensionality of the points (taken from the first data point). */
    int dim;

    /**
     * Creates a clusterer for {@code k} clusters.
     *
     * @param k number of clusters; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public VarMin(int k) {
        // k < 1 previously failed later with a division by zero or a negative array size.
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        this.k = k;
        this.dim = 0;
    }

    /**
     * Euclidean distance between two points. If the points differ in length, only
     * the common leading coordinates are compared.
     *
     * @return the Euclidean distance over the shared coordinates
     */
    public double distance(ArrayList<Double> p1, ArrayList<Double> p2) {
        int n = Math.min(p1.size(), p2.size());
        double sum = 0;
        for (int i = 0; i < n; i++) {
            double d = p1.get(i) - p2.get(i);
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Adds {@code p2} to {@code p1} in place ({@code p1 += p2}) over the shared coordinates. */
    public void addPoint(ArrayList<Double> p1, ArrayList<Double> p2) {
        int n = Math.min(p1.size(), p2.size());
        for (int i = 0; i < n; i++) {
            p1.set(i, p1.get(i) + p2.get(i));
        }
    }

    /** Scales {@code p1} in place by {@code v} ({@code p1 *= v}). */
    public void mulPoint(ArrayList<Double> p1, Double v) {
        for (int i = 0; i < p1.size(); i++) {
            p1.set(i, p1.get(i) * v);
        }
    }

    /** Resets centroids to {@code k} zero vectors of length {@code dim}, and resets the map and counts. */
    private void setParams(int k, int dim) {
        this.k = k;
        this.dim = dim;
        centroids = new ArrayList<ArrayList<Double>>();
        for (int i = 0; i < k; i++) {
            ArrayList<Double> centroid = new ArrayList<Double>();
            for (int j = 0; j < dim; j++) {
                centroid.add(0.0);
            }
            centroids.add(centroid);
        }
        clusterMap = new ArrayList<Integer>();
        count = new int[k];
    }

    /**
     * Sets the points to cluster and resets all clustering state. Every point
     * starts out assigned to cluster 0.
     *
     * @param data the points; may be empty (dimension is then 0)
     */
    public void setData(ArrayList<ArrayList<Double>> data) {
        // Guard against empty input; data.get(0) used to throw here.
        dim = data.isEmpty() ? 0 : data.get(0).size();
        this.data = data;
        setParams(k, dim);
        for (int i = 0; i < data.size(); i++) {
            clusterMap.add(0);
        }
    }

    /**
     * Returns the index of the centroid nearest to {@code p}. Ties go to the
     * lowest index.
     */
    public int getClosestCluster(ArrayList<Double> p) {
        double minDist = Double.MAX_VALUE;
        int minIndex = 0;
        for (int c = 0; c < k; c++) {
            double dist = distance(p, centroids.get(c));
            if (dist < minDist) {
                minDist = dist;
                minIndex = c;
            }
        }
        return minIndex;
    }

    /**
     * Builds the initial partition: the points are split, in input order, into
     * {@code k} contiguous runs. The centroids are set to the mean of each run.
     *
     * <p>Run sizes differ by at most one when the number of points is not a multiple
     * of {@code k}. With fewer points than clusters, some clusters start empty.
     */
    public void initCluster() {
        int n = data.size();
        for (int i = 0; i < n; i++) {
            // floor(i*k/n) maps [0, n) evenly onto [0, k) and never reaches k.
            // It equals i / (n/k) when n is a multiple of k. Long arithmetic avoids overflow.
            int c = (int) ((long) i * k / n);
            clusterMap.set(i, c);
        }
        // Recompute centroids and counts from scratch, using the real cluster sizes,
        // rather than accumulating into whatever the centroids already held.
        calcCluster();
    }

    /**
     * Recomputes every centroid as the mean of the points currently assigned to it
     * and refreshes {@link #count}. The centroid of an empty cluster becomes the zero vector.
     */
    public void calcCluster() {
        for (int c = 0; c < k; c++) {
            mulPoint(centroids.get(c), 0.0);
            count[c] = 0;
        }
        for (int i = 0; i < data.size(); i++) {
            int c = clusterMap.get(i);
            addPoint(centroids.get(c), data.get(i));
            count[c]++;
        }
        for (int c = 0; c < k; c++) {
            // Empty clusters are already zero from the reset above.
            if (count[c] != 0) {
                mulPoint(centroids.get(c), 1.0 / count[c]);
            }
        }
    }

    /**
     * Runs the clustering until no point changes cluster. Prints the compactness
     * after initialisation and after every iteration.
     */
    public void clustering() {
        clustering(Integer.MAX_VALUE);
    }

    /**
     * Runs the clustering until no point changes cluster or {@code maxIterations}
     * reassignment passes have run, whichever comes first. Prints the compactness
     * after initialisation and after every iteration.
     *
     * @param maxIterations upper bound on reassignment passes; a safety net against
     *     floating-point oscillation
     */
    public void clustering(int maxIterations) {
        initCluster();
        System.out.println(getCompactness());
        boolean changed = true;
        for (int iter = 0; changed && iter < maxIterations; iter++) {
            changed = false;
            for (int i = 0; i < data.size(); i++) {
                int newCluster = getClosestCluster(data.get(i));
                if (clusterMap.get(i).intValue() != newCluster) {
                    changed = true;
                    clusterMap.set(i, newCluster);
                }
            }
            calcCluster();
            System.out.println(getCompactness());
        }
    }

    /**
     * Sum of the Euclidean distances from every point to its assigned centroid.
     * Lower means tighter clusters.
     */
    public double getCompactness() {
        double sum = 0;
        for (int i = 0; i < data.size(); i++) {
            sum += distance(data.get(i), centroids.get(clusterMap.get(i)));
        }
        return sum;
    }

    /**
     * Groups the points by their assigned cluster.
     *
     * @return a list of {@code k} lists. Entry c holds the points of cluster c in input
     *     order; the point lists are shared with {@link #data}, not copied.
     */
    public ArrayList<ArrayList<ArrayList<Double>>> getCluster() {
        ArrayList<ArrayList<ArrayList<Double>>> cluster = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<>());
        }
        for (int i = 0; i < data.size(); i++) {
            int c = clusterMap.get(i);
            cluster.get(c).add(data.get(i));
        }
        return cluster;
    }

    /** Prints each cluster on one line as {@code ClusterN: (x,y,...); (x,y,...); }. */
    public static void printCluster(ArrayList<ArrayList<ArrayList<Double>>> cluster) {
        int i = 0;
        for (ArrayList<ArrayList<Double>> c : cluster) {
            System.out.print("Cluster" + ++i + ": ");
            for (ArrayList<Double> p : c) {
                System.out.print("(");
                for (int j = 0; j < p.size(); j++) {
                    System.out.print(p.get(j));
                    if (j < p.size() - 1) {
                        System.out.print(",");
                    }
                }
                System.out.print("); ");
            }
            System.out.println();
        }
    }

    /**
     * Writes the clusters to {@code file}. Each point goes on its own line as
     * space-separated coordinates, and clusters are separated by a blank line. I/O
     * errors are reported to stderr rather than thrown (best effort).
     */
    public static void writeToFile(ArrayList<ArrayList<ArrayList<Double>>> cluster, String file) {
        try (BufferedWriter stream = new BufferedWriter(new FileWriter(file))) {
            for (ArrayList<ArrayList<Double>> c : cluster) {
                for (ArrayList<Double> p : c) {
                    for (int j = 0; j < p.size(); j++) {
                        stream.write(p.get(j).toString() + " ");
                    }
                    stream.write("\n");
                }
                stream.write("\n");
            }
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so one handler covers both.
            e.printStackTrace();
        }
    }

    /**
     * Entry point. Clusters the points from {@code input.txt} into {@code args[0]}
     * clusters (default 3), prints the result and writes it to {@code cluster.txt}.
     */
    public static void main(String[] args) {
        int k = 3;
        if (args.length >= 1) {
            k = Integer.parseInt(args[0]);
        }
        VarMin varmin = new VarMin(k);
        ArrayList<ArrayList<Double>> data = DataFileReader.readFile("input.txt");
        varmin.setData(data);
        varmin.clustering();
        ArrayList<ArrayList<ArrayList<Double>>> cluster1 = varmin.getCluster();
        System.out.println("Result:");
        printCluster(cluster1);
        writeToFile(cluster1, "cluster.txt");
    }
}