summaryrefslogtreecommitdiff
path: root/src/main/java/cuchaz/enigma/analysis/JarClassIterator.java
blob: 040042708f7d703abaec83ae3e873b851f5616ca (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*******************************************************************************
 * Copyright (c) 2015 Jeff Martin.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser General Public
 * License v3.0 which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/lgpl.html
 * <p>
 * Contributors:
 * Jeff Martin - initial API and implementation
 ******************************************************************************/
package cuchaz.enigma.analysis;

import com.google.common.collect.Lists;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

import cuchaz.enigma.Constants;
import cuchaz.enigma.mapping.ClassEntry;
import javassist.ByteArrayClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.NotFoundException;
import javassist.bytecode.Descriptor;

/**
 * Iterates over every {@code .class} entry of a {@link JarFile}, loading each one as a javassist
 * {@link CtClass}.
 *
 * <p>The list of class entries is snapshotted when the iterator is constructed; the class bytes
 * themselves are read lazily from the jar on each call to {@link #next()}, so the jar must stay
 * open while iterating. Every loaded class gets its own fresh {@link ClassPool} (with the system
 * class path appended).
 */
public class JarClassIterator implements Iterator<CtClass> {

    private final JarFile jar;
    private final Iterator<JarEntry> iter;

    /**
     * Creates an iterator over the class-file entries of {@code jar}.
     *
     * @param jar the jar to read; it is read lazily, so it must remain open while iterating
     */
    public JarClassIterator(JarFile jar) {
        this.jar = jar;

        // snapshot the jar entries that correspond to classes
        List<JarEntry> classEntries = Lists.newArrayList();
        Enumeration<JarEntry> entries = this.jar.entries();
        while (entries.hasMoreElements()) {
            JarEntry entry = entries.nextElement();
            if (isClassFile(entry)) {
                classEntries.add(entry);
            }
        }
        this.iter = classEntries.iterator();
    }

    /**
     * @return {@code true} if there is at least one more class entry to load
     */
    @Override
    public boolean hasNext() {
        return this.iter.hasNext();
    }

    /**
     * Loads the next class entry of the jar.
     *
     * @return the javassist handle for the next class
     * @throws Error if the class bytes cannot be read or javassist cannot load them
     */
    @Override
    public CtClass next() {
        JarEntry entry = this.iter.next();
        try {
            return getClass(this.jar, entry);
        } catch (IOException | NotFoundException ex) {
            // keep the underlying failure as the cause so it is not lost
            throw new Error("Unable to load class: " + entry.getName(), ex);
        }
    }

    /**
     * Not supported; the iterator is read-only.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    /**
     * Lists the class entries of a jar without loading the classes.
     *
     * @param jar the jar to scan
     * @return one {@link ClassEntry} per non-directory entry whose name ends in {@code .class},
     *     named by the entry path with the {@code .class} suffix removed
     */
    public static List<ClassEntry> getClassEntries(JarFile jar) {
        List<ClassEntry> classEntries = Lists.newArrayList();
        Enumeration<JarEntry> entries = jar.entries();
        while (entries.hasMoreElements()) {
            JarEntry entry = entries.nextElement();
            if (isClassFile(entry)) {
                classEntries.add(getClassEntry(entry));
            }
        }
        return classEntries;
    }

    /**
     * Adapts this iterator to an {@link Iterable} so it can be used in an enhanced for-loop.
     * Each call to {@link Iterable#iterator()} creates a new {@link JarClassIterator}.
     *
     * @param jar the jar to iterate over
     * @return an iterable over the classes of {@code jar}
     */
    public static Iterable<CtClass> classes(final JarFile jar) {
        return () -> new JarClassIterator(jar);
    }

    /**
     * Loads a single class from a jar.
     *
     * @param jar        the jar to read
     * @param classEntry the class to load; its name is used as the entry path (plus {@code .class})
     * @return the javassist handle for the class
     * @throws Error if the entry is missing, cannot be read, or cannot be loaded by javassist
     */
    public static CtClass getClass(JarFile jar, ClassEntry classEntry) {
        try {
            return getClass(jar, new JarEntry(classEntry.getName() + ".class"));
        } catch (IOException | NotFoundException ex) {
            // keep the underlying failure as the cause so it is not lost
            throw new Error("Unable to load class: " + classEntry.getName(), ex);
        }
    }

    /**
     * Reads {@code entry} from {@code jar} and loads it into a fresh javassist {@link ClassPool}.
     *
     * @throws IOException       if reading the entry fails
     * @throws NotFoundException if the jar has no such entry or javassist cannot resolve the class
     */
    private static CtClass getClass(JarFile jar, JarEntry entry) throws IOException, NotFoundException {
        byte[] classBytes = readEntryBytes(jar, entry);

        // get a javassist handle for the class
        String className = Descriptor.toJavaName(getClassEntry(entry).getName());
        ClassPool classPool = new ClassPool();
        classPool.appendSystemPath();
        classPool.insertClassPath(new ByteArrayClassPath(className, classBytes));
        return classPool.get(className);
    }

    /**
     * Reads the full contents of a jar entry, closing the stream afterwards.
     *
     * @throws IOException       if reading fails
     * @throws NotFoundException if the jar yields no stream for the entry
     * @throws Error             if the entry exceeds {@code Constants.MiB} bytes (sanity check)
     */
    private static byte[] readEntryBytes(JarFile jar, JarEntry entry) throws IOException, NotFoundException {
        try (InputStream in = jar.getInputStream(entry)) {
            if (in == null) {
                // ZipFile.getInputStream may return null when the jar has no entry with this name
                throw new NotFoundException(entry.getName());
            }

            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            byte[] buf = new byte[Constants.KiB];
            int totalNumBytesRead = 0;
            int numBytesRead;
            // read until EOF (-1); available() is only an estimate and cannot signal end of stream
            while ((numBytesRead = in.read(buf)) != -1) {
                bos.write(buf, 0, numBytesRead);

                // sanity checking
                totalNumBytesRead += numBytesRead;
                if (totalNumBytesRead > Constants.MiB) {
                    throw new Error("Class file " + entry.getName() + " larger than 1 MiB! Something is wrong!");
                }
            }
            return bos.toByteArray();
        }
    }

    /** Returns whether {@code entry} is a (non-directory) {@code .class} file. */
    private static boolean isClassFile(JarEntry entry) {
        return !entry.isDirectory() && entry.getName().endsWith(".class");
    }

    /** Converts a jar entry path like {@code a/b/C.class} into a {@link ClassEntry} named {@code a/b/C}. */
    private static ClassEntry getClassEntry(JarEntry entry) {
        return new ClassEntry(entry.getName().substring(0, entry.getName().length() - ".class".length()));
    }
}