/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.examples.serialization;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.examples.Example;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
 * Example that round-trips a {@link RandomCutForest} through standard Java object serialization.
 *
 * <p>The forest is converted to a {@link RandomCutForestState} with a {@link RandomCutForestMapper},
 * written with an {@link ObjectOutputStream}, read back with an {@link ObjectInputStream}, and
 * converted back into a model. The anomaly scores of the original and restored forests are then
 * compared on fresh test data.
 */
public class ObjectStreamExample implements Example {

    /**
     * Runs the example from the command line.
     *
     * @param args ignored
     * @throws Exception if the example fails
     */
    public static void main(String[] args) throws Exception {
        new ObjectStreamExample().run();
    }

    /** Returns the command name used to select this example. */
    @Override
    public String command() {
        return "object_stream";
    }

    /** Returns a one-line human-readable description of this example. */
    @Override
    public String description() {
        return "serialize a Random Cut Forest with object stream";
    }

    /**
     * Trains a forest, serializes its state with an object stream, restores it, and checks that
     * the restored forest agrees with the original.
     *
     * @throws IllegalStateException if no test point was scored above 1.0 by either forest, or if
     *     the number of such points whose two scores differ by more than the tolerance reaches 1%
     *     of the test set
     * @throws RuntimeException if serializing or deserializing the forest state fails
     */
    @Override
    public void run() throws Exception {
        int dimensions = 10;
        int numberOfTrees = 50;
        int sampleSize = 256;
        Precision precision = Precision.FLOAT_32;
        RandomCutForest forest = RandomCutForest.builder()
                .compact(true)
                .dimensions(dimensions)
                .numberOfTrees(numberOfTrees)
                .sampleSize(sampleSize)
                .precision(precision)
                .build();

        // Train the original forest on 1000 * sampleSize generated points.
        int dataSize = 1000 * sampleSize;
        NormalMixtureTestData testData = new NormalMixtureTestData();
        for (double[] point : testData.generateTestData(dataSize, dimensions)) {
            forest.update(point);
        }

        RandomCutForestMapper mapper = new RandomCutForestMapper();
        // NOTE(review): presumably makes the mapper include executor settings in the saved
        // state — confirm against RandomCutForestMapper.
        mapper.setSaveExecutorContextEnabled(true);

        System.out.printf("dimensions = %d, numberOfTrees = %d, sampleSize = %d, precision = %s%n",
                dimensions, numberOfTrees, sampleSize, precision);

        // Round-trip: model -> state -> bytes -> state -> model.
        byte[] bytes = serialize(mapper.toState(forest));
        System.out.printf("Object output stream size = %d bytes%n", bytes.length);
        RandomCutForestState state2 = (RandomCutForestState) deserialize(bytes);
        RandomCutForest forest2 = mapper.toModel(state2);

        int testSize = 100;
        // Allowed score difference: 5% of log2(sampleSize).
        double delta = Math.log(sampleSize) / Math.log(2.0) * 0.05;
        int differences = 0;
        int anomalies = 0;
        for (double[] point : testData.generateTestData(testSize, dimensions)) {
            double score = forest.getAnomalyScore(point);
            double score2 = forest2.getAnomalyScore(point);
            // Only points that either forest scores above 1.0 are compared.
            if (score > 1.0 || score2 > 1.0) {
                ++anomalies;
                if (Math.abs(score - score2) > delta) {
                    ++differences;
                }
            }
            // Feed the same point to both forests so they continue from identical input.
            forest.update(point);
            forest2.update(point);
        }

        if (anomalies == 0) {
            throw new IllegalStateException("test data did not produce any anomalies");
        }
        if ((double) differences >= 0.01 * (double) testSize) {
            throw new IllegalStateException("restored forest does not agree with original forest");
        }
        System.out.println("Looks good!");
    }

    /**
     * Serializes {@code model} with Java object serialization.
     *
     * @param model the object to write; it must be serializable by {@link ObjectOutputStream}
     *     (e.g. implement {@link java.io.Serializable}), otherwise writing fails
     * @return the serialized bytes
     * @throws RuntimeException wrapping the underlying {@link IOException} if writing fails
     */
    private byte[] serialize(Object model) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream)) {
            objectOutputStream.writeObject(model);
            // Push any data buffered inside the object stream into the byte array.
            objectOutputStream.flush();
        } catch (IOException e) {
            // Keep the original exception as the cause; e.getCause() is often null and would
            // discard the actual failure reason.
            throw new RuntimeException("Failed to serialize model.", e);
        }
        // Closing the object stream also closes the ByteArrayOutputStream, which is a no-op for
        // that class, so its contents are still readable here.
        return byteArrayOutputStream.toByteArray();
    }

    /**
     * Deserializes a single object from bytes produced by {@link #serialize(Object)}.
     *
     * <p>Java native deserialization must never be applied to untrusted input. In this example
     * the bytes come from {@code serialize} in the same process.
     *
     * @param modelBin serialized bytes
     * @return the deserialized object
     * @throws RuntimeException wrapping the underlying {@link IOException} or
     *     {@link ClassNotFoundException} if the stream cannot be read or a class is missing
     */
    private Object deserialize(byte[] modelBin) {
        // The stream constructor itself reads a header and can throw, so it lives inside the
        // try-with-resources and its failures are wrapped the same way as readObject's.
        try (ObjectInputStream objectInputStream =
                new ObjectInputStream(new ByteArrayInputStream(modelBin))) {
            return objectInputStream.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to deserialize model.", e);
        }
    }
}

