summaryrefslogtreecommitdiff
path: root/enigma-swing/src/main/java/cuchaz/enigma/gui/search/SearchUtil.java
blob: c8212ce5f00038689c50d725439fdcbb741b38a3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
package cuchaz.enigma.gui.search;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import cuchaz.enigma.utils.Pair;

/**
 * Ranks {@link SearchEntry} objects against a free-text search term.
 *
 * <p>Every searchable name of an entry is split into word-like components (see
 * {@link Entry#wordwiseSplit(String)}). A term matches a name when it can be consumed completely by
 * prefixes of that name's components, taken in order (components may be skipped). Entries previously
 * recorded through {@link #hit(SearchEntry)} get their score multiplied by {@code hits + 1}.
 *
 * <p>This class is not synchronized: the entry and hit-count maps are plain {@link HashMap}s, so the
 * mutating methods must not race with each other or with {@link #search(String)}.
 * {@link #asyncSearch(String, SearchResultConsumer)} works on copies of both maps taken when it is called.
 *
 * @param <T> the type of entry being searched
 */
public class SearchUtil<T extends SearchEntry> {
	private final Map<T, Entry<T>> entries = new HashMap<>();
	private final Map<String, Integer> hitCount = new HashMap<>();
	private final Executor searchExecutor = Executors.newWorkStealingPool();

	/**
	 * Adds (or replaces) a searchable entry, splitting its names into components.
	 *
	 * @param entry the entry to add
	 */
	public void add(T entry) {
		Entry<T> e = Entry.from(entry);
		entries.put(entry, e);
	}

	/**
	 * Adds (or replaces) an already prepared entry.
	 *
	 * @param entry the prepared entry; keyed by its {@link Entry#searchEntry}
	 */
	public void add(Entry<T> entry) {
		entries.put(entry.searchEntry, entry);
	}

	/**
	 * Adds (or replaces) all given entries, preparing them in parallel.
	 *
	 * @param entries the entries to add; duplicates are allowed and behave like repeated {@link #add(SearchEntry)}
	 */
	public void addAll(Collection<T> entries) {
		// Collectors.toMap without a merge function throws IllegalStateException on duplicate keys;
		// resolve them like repeated put() calls would (the later element wins).
		this.entries.putAll(entries.parallelStream().collect(Collectors.toMap(e -> e, Entry::from, (first, second) -> second)));
	}

	/**
	 * Removes an entry from the search set. Its hit count is kept.
	 *
	 * @param entry the entry to remove
	 */
	public void remove(T entry) {
		entries.remove(entry);
	}

	/** Removes all entries. Hit counts are kept; see {@link #clearHits()}. */
	public void clear() {
		entries.clear();
	}

	/** Forgets all recorded hits. */
	public void clearHits() {
		hitCount.clear();
	}

	/**
	 * Scores every entry against {@code term} and returns the matching ones, best first.
	 *
	 * <p>Scoring runs on a parallel stream, but the returned stream is sequential. Entries whose score is
	 * not positive are omitted.
	 *
	 * @param term the search term (matched case-insensitively)
	 * @return the matching entries ordered by descending score
	 */
	public Stream<T> search(String term) {
		return entries.values().parallelStream()
				.map(e -> new Pair<>(e, e.getScore(term, hitCount.getOrDefault(e.searchEntry.getIdentifier(), 0))))
				.filter(e -> e.b > 0)
				.sorted(Comparator.comparingDouble(o -> -o.b))
				.map(e -> e.a.searchEntry)
				.sequential();
	}

	/**
	 * Scores every entry against {@code term} on a background executor, reporting matches incrementally.
	 *
	 * <p>Each match is passed to {@code consumer} together with the index at which it belongs in a list
	 * sorted by descending score. Consumer calls happen while holding an internal lock, so they never
	 * overlap, but they run on executor threads.
	 *
	 * @param term     the search term (matched case-insensitively)
	 * @param consumer receives each match and its insertion index
	 * @return a handle to stop the search and query its progress
	 */
	public SearchControl asyncSearch(String term, SearchResultConsumer<T> consumer) {
		// Copy the mutable state so later add/remove/hit calls do not affect (or race with) this search.
		Map<String, Integer> hitCount = new HashMap<>(this.hitCount);
		Map<T, Entry<T>> entries = new HashMap<>(this.entries);
		int total = entries.size();
		// Negated scores of the matches reported so far, kept in ascending order (= descending score).
		float[] scores = new float[total];
		Lock scoresLock = new ReentrantLock();
		AtomicInteger size = new AtomicInteger();
		AtomicBoolean control = new AtomicBoolean(false);
		AtomicInteger elapsed = new AtomicInteger();

		for (Entry<T> value : entries.values()) {
			searchExecutor.execute(() -> {
				try {
					if (control.get()) {
						return;
					}

					float score = value.getScore(term, hitCount.getOrDefault(value.searchEntry.getIdentifier(), 0));

					if (score <= 0) {
						return;
					}

					score = -score; // sort descending

					// Acquire before entering the try so unlock() only runs when the lock is actually held.
					scoresLock.lock();

					try {
						// Re-check under the lock so no result is reported after stop() has been observed here.
						if (control.get()) {
							return;
						}

						int dataSize = size.getAndIncrement();
						int index = Arrays.binarySearch(scores, 0, dataSize, score);

						if (index < 0) {
							index = ~index;
						}

						System.arraycopy(scores, index, scores, index + 1, dataSize - index);
						scores[index] = score;
						consumer.add(index, value.searchEntry);
					} finally {
						scoresLock.unlock();
					}
				} finally {
					elapsed.incrementAndGet();
				}
			});
		}

		return new SearchControl() {
			@Override
			public void stop() {
				control.set(true);
			}

			@Override
			public boolean isFinished() {
				return total == elapsed.get();
			}

			@Override
			public float getProgress() {
				// An empty search is already finished; avoid returning 0/0 = NaN.
				return total == 0 ? 1f : (float) elapsed.get() / total;
			}
		};
	}

	/**
	 * Records that the user picked {@code entry}, boosting its future scores. Ignored for entries that are
	 * not currently in the search set.
	 *
	 * @param entry the picked entry
	 */
	public void hit(T entry) {
		if (entries.containsKey(entry)) {
			hitCount.merge(entry.getIdentifier(), 1, Integer::sum);
		}
	}

	/**
	 * A search entry together with its searchable names pre-split into word components.
	 *
	 * @param <T> the type of the wrapped entry
	 */
	public static final class Entry<T extends SearchEntry> {
		public final T searchEntry;
		/** One array of word components per searchable name of {@link #searchEntry}. */
		private final String[][] components;

		private Entry(T searchEntry, String[][] components) {
			this.searchEntry = searchEntry;
			this.components = components;
		}

		/**
		 * Scores this entry against a search term.
		 *
		 * @param term the search term (any case)
		 * @param hits how often this entry was picked before; the score is multiplied by {@code hits + 1}
		 * @return the best score over all searchable names, or 0 if no name fully consumes the term
		 */
		public float getScore(String term, int hits) {
			String ucTerm = term.toUpperCase(Locale.ROOT);
			float maxScore = (float) Arrays.stream(components).mapToDouble(name -> getScoreFor(ucTerm, name)).max().orElse(0.0);
			return maxScore * (hits + 1);
		}

		/**
		 * Computes the score for the given <code>name</code> against the given search term.
		 *
		 * <p>Matched characters are weighted by how early their component appears in the name, and longer
		 * runs matched within one component earn a chain bonus.
		 *
		 * @param term the search term (expected to be upper-case)
		 * @param name the entry name, split at word boundaries (see {@link Entry#wordwiseSplit(String)})
		 * @return the computed score for the entry, or 0 if the term cannot be fully consumed
		 */
		private static float getScoreFor(String term, String[] name) {
			int totalLength = Arrays.stream(name).mapToInt(String::length).sum();
			float scorePerChar = 1f / totalLength;

			// This map contains a snapshot of all the states the search has
			// been in. The keys are the remaining characters of the search
			// term, the values are the maximum scores for that remaining
			// search term part.
			Map<String, Float> snapshots = new HashMap<>();
			snapshots.put(term, 0f);

			// For each component, start at each existing snapshot, searching
			// for the next longest match, and calculate the new score for each
			// match length until the maximum. Then the new scores are put back
			// into the snapshot map (old snapshots stay, so components may be skipped).
			for (int componentIndex = 0; componentIndex < name.length; componentIndex++) {
				// Upper-cased once per component rather than once per snapshot.
				String component = name[componentIndex].toUpperCase(Locale.ROOT);
				// Earlier components weigh more.
				float posMultiplier = (name.length - componentIndex) * 0.3f;
				Map<String, Float> newSnapshots = new HashMap<>();

				for (Map.Entry<String, Float> snapshot : snapshots.entrySet()) {
					String remaining = snapshot.getKey();
					float score = snapshot.getValue();
					int matchLength = commonPrefixLength(remaining, component);

					for (int i = 1; i <= matchLength; i++) {
						float baseScore = scorePerChar * i;
						float chainBonus = (i - 1) * 0.5f;
						newSnapshots.merge(remaining.substring(i), score + baseScore * posMultiplier + chainBonus, Math::max);
					}
				}

				// Merged only after iterating, so a component is never applied twice along one path.
				newSnapshots.forEach((key, value) -> snapshots.merge(key, value, Math::max));
			}

			// Only return the score for when the search term was completely
			// consumed.
			return snapshots.getOrDefault("", 0f);
		}

		/**
		 * Prepares a search entry by splitting each of its searchable names into word components.
		 *
		 * @param e the entry to wrap
		 * @param <T> the entry type
		 * @return the prepared entry
		 */
		public static <T extends SearchEntry> Entry<T> from(T e) {
			// Sequential: the per-entry name list is small, and this is typically called from addAll's parallel stream.
			String[][] components = e.getSearchableNames().stream().map(Entry::wordwiseSplit).toArray(String[][]::new);
			return new Entry<>(e, components);
		}

		/** Returns the length of the longest common prefix of {@code s1} and {@code s2}. */
		private static int commonPrefixLength(String s1, String s2) {
			int limit = Math.min(s1.length(), s2.length());
			int len = 0;

			while (len < limit && s1.charAt(len) == s2.charAt(len)) {
				len += 1;
			}

			return len;
		}

		/**
		 * Splits the given input into components, trying to detect word parts.
		 *
		 * <p>Example of how words get split (using <code>|</code> as separator):
		 * <p><code>MinecraftClientGame -> Minecraft|Client|Game</code></p>
		 * <p><code>HTTPInputStream -> HTTP|Input|Stream</code></p>
		 * <p><code>class_932 -> class|_|932</code></p>
		 * <p><code>X11FontManager -> X|11|Font|Manager</code></p>
		 * <p><code>openHTTPConnection -> open|HTTP|Connection</code></p>
		 * <p><code>open_http_connection -> open|_|http|_|connection</code></p>
		 *
		 * @param input the input to split
		 * @return the resulting components (never containing empty strings)
		 */
		private static String[] wordwiseSplit(String input) {
			List<String> parts = new ArrayList<>();
			int start = 0;

			while (start < input.length()) {
				int end = segmentEnd(input, start);
				parts.add(input.substring(start, end));
				start = end;
			}

			return parts.toArray(new String[0]);
		}

		/**
		 * Finds the exclusive end index of the word segment beginning at {@code start}.
		 * The result is always greater than {@code start}, so splitting always makes progress.
		 */
		private static int segmentEnd(String input, int start) {
			int length = input.length();
			char first = input.charAt(start);
			int end = start + 1;

			if (Character.isDigit(first)) {
				// A run of digits is one segment: "932", "11".
				while (end < length && Character.isDigit(input.charAt(end))) {
					end += 1;
				}
			} else if (Character.isLetter(first)) {
				if (Character.isUpperCase(first) && end < length && Character.isUpperCase(input.charAt(end))) {
					// Acronym: take the whole upper-case run...
					while (end < length && Character.isUpperCase(input.charAt(end))) {
						end += 1;
					}

					// ...except when a lower-case letter follows, in which case the last capital
					// starts the next word ("HTTPInput" -> "HTTP" + "Input"). The run has at least
					// two characters here, so this never yields an empty segment.
					if (end < length && Character.isLowerCase(input.charAt(end))) {
						end -= 1;
					}
				} else {
					// One leading letter followed by lower-case letters: "Minecraft", "open".
					while (end < length && Character.isLowerCase(input.charAt(end))) {
						end += 1;
					}
				}
			}
			// Any other character (e.g. '_', '$') forms a one-character segment.

			return end;
		}
	}

	/**
	 * Receives results of {@link SearchUtil#asyncSearch(String, SearchResultConsumer)} as they are found.
	 *
	 * @param <T> the entry type
	 */
	@FunctionalInterface
	public interface SearchResultConsumer<T extends SearchEntry> {
		/**
		 * Called once per match.
		 *
		 * @param index the position at which {@code entry} belongs in a list of the matches reported so far,
		 *              sorted by descending score
		 * @param entry the matching entry
		 */
		void add(int index, T entry);
	}

	/** Handle for a running {@link SearchUtil#asyncSearch(String, SearchResultConsumer)}. */
	public interface SearchControl {
		/** Requests that no further results be reported; pending tasks finish without scoring. */
		void stop();

		/** Returns whether every entry has been processed (or skipped after {@link #stop()}). */
		boolean isFinished();

		/** Returns the processed fraction of entries, in {@code [0, 1]}. */
		float getProgress();
	}
}