From 07bd6f63156fb75c171c03341f0a200ec9c4472c Mon Sep 17 00:00:00 2001 From: Johannes Theiner Date: Tue, 14 Jan 2020 11:09:21 +0100 Subject: [PATCH] first commit --- clustering/DataFileReader.java | 40 +++++++ clustering/KMeans.java | 67 +++++++++++ clustering/VarMin.java | 210 +++++++++++++++++++++++++++++++++ clustering/show.py | 28 +++++ 4 files changed, 345 insertions(+) create mode 100644 clustering/DataFileReader.java create mode 100644 clustering/KMeans.java create mode 100644 clustering/VarMin.java create mode 100644 clustering/show.py diff --git a/clustering/DataFileReader.java b/clustering/DataFileReader.java new file mode 100644 index 0000000..d787f3d --- /dev/null +++ b/clustering/DataFileReader.java @@ -0,0 +1,40 @@ +import java.io.*; +import java.util.ArrayList; + +public class DataFileReader { + static ArrayList> readFile(String file){ + ArrayList> data = new ArrayList>(); + //Count: leere Zeilen werden nicht mitverwendet + int count = 0; + try(BufferedReader stream = new BufferedReader(new FileReader(file))){ + while(true){ + String line = stream.readLine(); + if(line == null){ + break; + } + if(!line.contains("#")){ + String[] parts = line.split("[ |\t]"); + data.add(new ArrayList()); + for(String s : parts){ + try{ + data.get(count).add(Double.parseDouble(s)); + }catch(NumberFormatException e){ + } + } + if(data.get(count).size() == 0){ + data.remove(count); + }else{ + count++; + } + } + } + } + catch(FileNotFoundException e){ + e.printStackTrace(); + System.out.println("file " + file + " not found"); + } catch (IOException e) { + e.printStackTrace(); + } + return data; + } +} diff --git a/clustering/KMeans.java b/clustering/KMeans.java new file mode 100644 index 0000000..f58230d --- /dev/null +++ b/clustering/KMeans.java @@ -0,0 +1,67 @@ +import java.util.ArrayList; + +public class KMeans extends VarMin{ + public KMeans(int k){ + super(k); + } + + @Override + public void clustering(){ + initCluster(); + System.out.println(getCompactness()); + boolean change = true; + while(change){ + change = false; + for(int i = 0; i < data.size();i++) { + int newCluster = getClosestCluster(data.get(i)); + int oldCluster = clusterMap.get(i); + if(oldCluster != newCluster){ + clusterMap.set(i, newCluster); + + //update centroids + ArrayList p = new ArrayList<>(data.get(i)); + ArrayList u1 = new ArrayList<>(centroids.get(oldCluster)); + ArrayList u2 = new ArrayList<>(centroids.get(newCluster)); + mulPoint(u1, (double)count[oldCluster]); + mulPoint(u2, (double)count[newCluster]); + + addPoint(u2, p); + mulPoint(p, -1.0); + addPoint(u1, p); + + mulPoint(u1,(1.0 / (count[oldCluster] - 1))); + mulPoint(u2,(1.0 / (count[newCluster] + 1))); + centroids.set(oldCluster, u1); + centroids.set(newCluster, u2); + + count[oldCluster]--; + count[newCluster]++; + + //centroids.get(oldCluster) = (1.0 / (count[oldCluster] - 1)) * (count[oldCluster] * centroids.get(oldCluster) - data.get(i)); + //centroids.get(newCluster) = (1.0 / (count[newCluster] + 1)) * (count[newCluster] * centroids.get(newCluster) + data.get(i)); + + //calcCluster(); + change = true; + } + } + System.out.println(getCompactness()); + } + } + + public static void main(String[] args) { + int k = 3; + if(args.length >= 1){ + k = Integer.parseInt(args[0]); + } + KMeans kmeans = new KMeans(k); + ArrayList> data = DataFileReader.readFile("input.txt"); + kmeans.setData(data); + + kmeans.clustering(); + ArrayList>> cluster1 = kmeans.getCluster(); + System.out.println("Result:"); + printCluster(cluster1); + + writeToFile(cluster1, "cluster.txt"); + } +} diff --git a/clustering/VarMin.java b/clustering/VarMin.java new file mode 100644 index 0000000..bed6cab --- /dev/null +++ b/clustering/VarMin.java @@ -0,0 +1,210 @@ +import java.io.BufferedWriter; +import java.io.FileNotFoundException; +import java.io.FileWriter; +import java.io.IOException; +import java.lang.reflect.Array; +import java.util.ArrayList; + +public class VarMin { + public ArrayList> data; + public ArrayList> centroids; + public ArrayList clusterMap; + int[] count; + int k; + int dim; + + public VarMin(int k){ + this.k = k; + this.dim = 0; + } + + //euklidische distanz + public double distance(ArrayList p1, ArrayList p2){ + int count = Math.min(p1.size(), p2.size()); + double sum = 0; + for(int i = 0; i < count;i++){ + sum += Math.pow(p1.get(i) - p2.get(i), 2); + } + return Math.sqrt(sum); + } + + //p1 += p2 + public void addPoint(ArrayList p1, ArrayList p2){ + int count = Math.min(p1.size(), p2.size()); + for(int i = 0; i < count;i++){ + p1.set(i, p1.get(i) + p2.get(i)); + } + } + + //p1 *= v + public void mulPoint(ArrayList p1, Double v){ + for(int i = 0; i < p1.size();i++){ + p1.set(i, p1.get(i) * v); + } + } + + private void setParams(int k, int dim){ + this.k = k; + this.dim = dim; + centroids = new ArrayList>(); + for(int i = 0; i < k;i++){ + centroids.add(new ArrayList()); + for(int j = 0; j < dim;j++){ + centroids.get(i).add(0.0); + } + } + clusterMap = new ArrayList(); + count = new int[k]; + } + + public void setData(ArrayList> data){ + dim = data.get(0).size(); + this.data = data; + setParams(k,dim); + for(int i = 0; i < data.size();i++){ + clusterMap.add(0); + } + } + + public int getClosestCluster(ArrayList p){ + double minDist = Double.MAX_VALUE; + int minIndex = 0; + for(int c = 0; c < k;c++){ + double dist = distance(p, centroids.get(c)); + if(dist < minDist) { + minDist = dist; + minIndex = c; + } + } + return minIndex; + } + + public void initCluster(){ + int pointsPerCluster = data.size() / k; + for(int i = 0; i < data.size();i++){ + int c = i / pointsPerCluster; + addPoint(centroids.get(c), data.get(i)); + clusterMap.set(i, c); + } + for(int c = 0; c < k;c++){ + mulPoint(centroids.get(c), 1.0 / pointsPerCluster); + count[c] = pointsPerCluster; + } + } + + public void calcCluster(){ + for(int c = 0; c < k;c++){ + mulPoint(centroids.get(c), 0.0); + } + for(int c = 0; c < k;c++){ + count[c] = 0; + } + for(int i = 0; i < data.size();i++){ + int c = clusterMap.get(i); + addPoint(centroids.get(c), data.get(i)); + count[c]++; + } + for(int c = 0; c < k;c++){ + if(count[c] == 0){ + mulPoint(centroids.get(c),0.0); + }else{ + mulPoint(centroids.get(c), 1.0 / count[c]); + } + } + } + + public void clustering(){ + initCluster(); + + System.out.println(getCompactness()); + //iteration + boolean change = true; + while(change){ + change = false; + for(int i = 0; i < data.size();i++) { + int newCluster = getClosestCluster(data.get(i)); + if(clusterMap.get(i) != newCluster){ + change = true; + } + clusterMap.set(i, newCluster); + } + calcCluster(); + System.out.println(getCompactness()); + } + } + + public double getCompactness(){ + double sum = 0; + for(int i = 0; i < data.size();i++){ + sum += distance(data.get(i), centroids.get(clusterMap.get(i))); + } + return sum; + } + + public ArrayList>> getCluster(){ + ArrayList>> cluster = new ArrayList<>(); + for(int i = 0; i < k;i++){ + cluster.add(new ArrayList<>()); + } + for(int i = 0; i < data.size();i++) { + int c = clusterMap.get(i); + cluster.get(c).add(data.get(i)); + } + return cluster; + } + + public static void printCluster(ArrayList>> cluster){ + int i = 0; + for(ArrayList> c : cluster){ + System.out.print("Cluster" + ++i + ": "); + for(ArrayList p : c){ + System.out.print("("); + for (int j = 0; j < p.size();j++) { + System.out.print(p.get(j)); + if(j < p.size() - 1){ + System.out.print(","); + } + } + System.out.print("); "); + } + System.out.println(); + } + } + + public static void writeToFile(ArrayList>> cluster, String file){ + try(BufferedWriter stream = new BufferedWriter(new FileWriter(file))) { + + for(ArrayList> c : cluster){ + for(ArrayList p : c){ + for (int j = 0; j < p.size();j++) { + stream.write(p.get(j).toString() + " "); + } + stream.write("\n"); + } + stream.write("\n"); + } + + } catch (FileNotFoundException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public static void main(String[] args) { + int k = 3; + if(args.length >= 1){ + k = Integer.parseInt(args[0]); + } + VarMin varmin = new VarMin(k); + ArrayList> data = DataFileReader.readFile("input.txt"); + varmin.setData(data); + + varmin.clustering(); + ArrayList>> cluster1 = varmin.getCluster(); + System.out.println("Result:"); + printCluster(cluster1); + + writeToFile(cluster1, "cluster.txt"); + } +} diff --git a/clustering/show.py b/clustering/show.py new file mode 100644 index 0000000..f1a98c0 --- /dev/null +++ b/clustering/show.py @@ -0,0 +1,28 @@ +import matplotlib.pyplot as plt +import numpy as np + +a = open("cluster.txt").read() +a = a.split("\n\n") + +data = [] +for b in a: + cluster = [] + for c in b.split("\n"): + d = c.split(" ") + vec = [] + for e in d: + if e != "": + vec.append(float(e)) + if len(vec) != 0: + cluster.append(vec) + if len(cluster) != 0: + data.append(cluster) + +for c in data: + xs = [x[0] for x in c] + ys = [x[1] for x in c] + plt.subplot() + plt.plot(xs, ys, "o") +plt.show() + +