/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.state;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.ComponentList;
import com.amazon.randomcutforest.IComponentModel;
import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.executor.PointStoreCoordinator;
import com.amazon.randomcutforest.executor.SamplerPlusTree;
import com.amazon.randomcutforest.sampler.CompactSampler;
import com.amazon.randomcutforest.sampler.IStreamSampler;
import com.amazon.randomcutforest.sampler.Weighted;
import com.amazon.randomcutforest.state.ExecutionContext;
import com.amazon.randomcutforest.state.IContextualStateMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.amazon.randomcutforest.state.sampler.CompactSamplerMapper;
import com.amazon.randomcutforest.state.sampler.CompactSamplerState;
import com.amazon.randomcutforest.state.store.PointStoreMapper;
import com.amazon.randomcutforest.state.store.PointStoreState;
import com.amazon.randomcutforest.state.tree.CompactRandomCutTreeContext;
import com.amazon.randomcutforest.state.tree.CompactRandomCutTreeState;
import com.amazon.randomcutforest.state.tree.RandomCutTreeMapper;
import com.amazon.randomcutforest.store.IPointStore;
import com.amazon.randomcutforest.store.PointStore;
import com.amazon.randomcutforest.tree.ITree;
import com.amazon.randomcutforest.tree.RandomCutTree;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import lombok.Generated;

/**
 * Converts a {@link RandomCutForest} to and from its serializable {@link RandomCutForestState}.
 *
 * <p>The boolean flags on this mapper decide which optional parts of the forest
 * {@link #toState(RandomCutForest)} writes: the tree structures, the point store held by the update
 * coordinator, the per-tree sampler states, and the execution context. Every flag except the
 * executor-context flag is also recorded in the produced state. The rebuild path reads the
 * tree/sampler/coordinator flags back from the state to decide what it can restore.
 *
 * <p>NOTE(review): the mapper holds plain mutable fields with no synchronization. Presumably it
 * should not be reconfigured while another thread is using it.
 */
public class RandomCutForestMapper
        implements IContextualStateMapper<RandomCutForest, RandomCutForestState, ExecutionContext> {

    /** When true, {@code toState} also serializes every tree; only allowed for compact forests. */
    private boolean saveTreeStateEnabled = false;

    /** When true, {@code toState} serializes the point store owned by the update coordinator. */
    private boolean saveCoordinatorStateEnabled = true;

    /** When true, {@code toState} serializes the sampler of every tree. */
    private boolean saveSamplerStateEnabled = true;

    /** When true, {@code toState} records the forest's parallel-execution settings. */
    private boolean saveExecutorContextEnabled = false;

    /** Forwarded to the point-store and sampler mappers and recorded as the state's compressed flag. */
    private boolean compressionEnabled = true;

    /** Recorded in the state as its partial-tree-state flag; not otherwise read by this class. */
    private boolean partialTreeStateEnabled = false;

    /**
     * Captures the given forest as a {@link RandomCutForestState}.
     *
     * <p>Scalar configuration and this mapper's save/compression flags are always copied. The
     * execution context, the coordinator's point store, the sampler states and the tree states are
     * included only when the matching flag is enabled. When sampler saving is disabled, the
     * sampler-state list on the result is {@code null}.
     *
     * @param forest the forest to serialize
     * @return a new state object describing {@code forest}
     * @throws IllegalArgumentException if tree saving is enabled but the forest is not compact
     *     (assuming {@code CommonUtils.checkArgument} throws this type — confirm)
     */
    @Override
    public RandomCutForestState toState(RandomCutForest forest) {
        if (saveTreeStateEnabled) {
            CommonUtils.checkArgument(forest.isCompact(), "tree state cannot be saved for noncompact forests");
        }

        RandomCutForestState state = new RandomCutForestState();
        state.setNumberOfTrees(forest.getNumberOfTrees());
        state.setDimensions(forest.getDimensions());
        state.setTimeDecay(forest.getTimeDecay());
        state.setSampleSize(forest.getSampleSize());
        state.setShingleSize(forest.getShingleSize());
        state.setCenterOfMassEnabled(forest.isCenterOfMassEnabled());
        state.setOutputAfter(forest.getOutputAfter());
        state.setStoreSequenceIndexesEnabled(forest.isStoreSequenceIndexesEnabled());
        state.setTotalUpdates(forest.getTotalUpdates());
        state.setCompact(forest.isCompact());
        state.setInternalShinglingEnabled(forest.isInternalShinglingEnabled());
        state.setBoundingBoxCacheFraction(forest.getBoundingBoxCacheFraction());
        state.setSaveSamplerStateEnabled(saveSamplerStateEnabled);
        state.setSaveTreeStateEnabled(saveTreeStateEnabled);
        state.setSaveCoordinatorStateEnabled(saveCoordinatorStateEnabled);
        state.setPrecision(forest.getPrecision().name());
        state.setCompressed(compressionEnabled);
        state.setPartialTreeState(partialTreeStateEnabled);

        if (saveExecutorContextEnabled) {
            ExecutionContext executionContext = new ExecutionContext();
            executionContext.setParallelExecutionEnabled(forest.isParallelExecutionEnabled());
            executionContext.setThreadPoolSize(forest.getThreadPoolSize());
            state.setExecutionContext(executionContext);
        }

        if (saveCoordinatorStateEnabled) {
            PointStoreCoordinator<?> coordinator = (PointStoreCoordinator<?>) forest.getUpdateCoordinator();
            PointStoreMapper pointStoreMapper = new PointStoreMapper();
            pointStoreMapper.setCompressionEnabled(compressionEnabled);
            pointStoreMapper.setNumberOfTrees(forest.getNumberOfTrees());
            state.setPointStoreState(pointStoreMapper.toState((PointStore) coordinator.getStore()));
        }

        // A null list means "not requested". Samplers are serialized while walking the components;
        // trees are only collected here and serialized after the walk.
        List<CompactSamplerState> samplerStates = saveSamplerStateEnabled ? new ArrayList<>() : null;
        List<RandomCutTree> trees = saveTreeStateEnabled ? new ArrayList<>() : null;

        CompactSamplerMapper samplerMapper = new CompactSamplerMapper();
        samplerMapper.setCompressionEnabled(compressionEnabled);
        for (IComponentModel component : forest.getComponents()) {
            SamplerPlusTree<?, ?> samplerPlusTree = (SamplerPlusTree<?, ?>) component;
            // The cast is unconditional: every component must carry a CompactSampler.
            CompactSampler sampler = (CompactSampler) samplerPlusTree.getSampler();
            if (samplerStates != null) {
                samplerStates.add(samplerMapper.toState(sampler));
            }
            if (trees != null) {
                trees.add((RandomCutTree) samplerPlusTree.getTree());
            }
        }
        state.setCompactSamplerStates(samplerStates);

        if (trees != null) {
            RandomCutTreeMapper treeMapper = new RandomCutTreeMapper();
            List<CompactRandomCutTreeState> treeStates =
                    trees.stream().map(treeMapper::toState).collect(Collectors.toList());
            state.setCompactRandomCutTreeStates(treeStates);
        }
        return state;
    }

    /**
     * Rebuilds a forest from {@code state}.
     *
     * <p>Execution settings come from {@code executionContext} when it is non-null. Otherwise they
     * come from the context stored in the state. A FLOAT_32 state is delegated to
     * {@link #singlePrecisionForest}. Any other precision goes through this path:
     * <ul>
     *   <li>The saved point store is loaded through {@code PointStoreMapper#convertFromDouble}.</li>
     *   <li>For every tree index, a new tree and a new sampler are built.</li>
     *   <li>Each point of the saved sampler is re-inserted into both.</li>
     *   <li>The new sampler keeps the saved max-sequence-index and time-decay bookkeeping.</li>
     * </ul>
     *
     * @param state the saved forest state
     * @param executionContext execution settings to use, or {@code null} to use those stored in
     *     {@code state}
     * @param seed seed for the forest builder; the tree and sampler seeds are drawn from the
     *     builder's random generator
     * @return the reconstructed forest
     * @throws NullPointerException if no execution context is available, or if the non-FLOAT_32
     *     path finds no saved point store (assuming {@code CommonUtils.checkNotNull} throws this type
     *     — confirm)
     * @throws IllegalArgumentException if the non-FLOAT_32 path finds no saved sampler state
     *     (assuming {@code CommonUtils.checkArgument} throws this type — confirm)
     */
    @Override
    public RandomCutForest toModel(RandomCutForestState state, ExecutionContext executionContext, long seed) {
        ExecutionContext ec = executionContext;
        if (ec == null) {
            CommonUtils.checkNotNull(state.getExecutionContext(), "The executor context in the state object is null, an executor context must be passed explicitly to toModel()");
            ec = state.getExecutionContext();
        }

        RandomCutForest.Builder<?> builder = RandomCutForest.builder()
                .numberOfTrees(state.getNumberOfTrees())
                .dimensions(state.getDimensions())
                .timeDecay(state.getTimeDecay())
                .sampleSize(state.getSampleSize())
                .centerOfMassEnabled(state.isCenterOfMassEnabled())
                .outputAfter(state.getOutputAfter())
                .parallelExecutionEnabled(ec.isParallelExecutionEnabled())
                .threadPoolSize(ec.getThreadPoolSize())
                .storeSequenceIndexesEnabled(state.isStoreSequenceIndexesEnabled())
                .shingleSize(state.getShingleSize())
                .boundingBoxCacheFraction(state.getBoundingBoxCacheFraction())
                .compact(state.isCompact())
                .internalShinglingEnabled(state.isInternalShinglingEnabled())
                .randomSeed(seed);

        if (Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32) {
            return singlePrecisionForest(builder, state, null, null, null);
        }

        Random random = builder.getRandom();
        // Fail with a descriptive message rather than deep inside the conversion.
        CommonUtils.checkNotNull(state.getPointStoreState(), " need coordinator state ");
        PointStore pointStore = new PointStoreMapper().convertFromDouble(state.getPointStoreState());
        PointStoreCoordinator<float[]> coordinator = new PointStoreCoordinator<float[]>(pointStore);
        coordinator.setTotalUpdates(state.getTotalUpdates());

        CommonUtils.checkArgument(state.isSaveSamplerStateEnabled(), " conversion cannot proceed without samplers");
        List<CompactSamplerState> samplerStates = state.getCompactSamplerStates();
        CompactSamplerMapper samplerMapper = new CompactSamplerMapper();
        ComponentList components = new ComponentList();

        for (int i = 0; i < state.getNumberOfTrees(); ++i) {
            CompactSampler savedSampler = samplerMapper.toModel(samplerStates.get(i));
            // Seed order matters for reproducibility: the tree seed is drawn before the sampler seed.
            RandomCutTree tree = RandomCutTree.builder()
                    .capacity(state.getSampleSize())
                    .pointStoreView(pointStore)
                    .storeSequenceIndexesEnabled(state.isStoreSequenceIndexesEnabled())
                    .outputAfter(state.getOutputAfter())
                    .centerOfMassEnabled(state.isCenterOfMassEnabled())
                    .randomSeed(random.nextLong())
                    .build();
            CompactSampler sampler = CompactSampler.builder()
                    .capacity(state.getSampleSize())
                    .timeDecay(state.getTimeDecay())
                    .randomSeed(random.nextLong())
                    .build();
            sampler.setMaxSequenceIndex(savedSampler.getMaxSequenceIndex());
            sampler.setMostRecentTimeDecayUpdate(savedSampler.getMostRecentTimeDecayUpdate());

            for (Weighted<Integer> sample : savedSampler.getWeightedSample()) {
                Integer reference = sample.getValue();
                Integer newReference = tree.addPoint(reference, sample.getSequenceIndex());
                if (newReference.intValue() != reference.intValue()) {
                    // The tree answered with a different point-store index, presumably an existing
                    // duplicate. Move one reference count from the old index to the returned one.
                    pointStore.incrementRefCount(newReference);
                    pointStore.decrementRefCount(reference);
                }
                sampler.addPoint(newReference, sample.getWeight(), sample.getSequenceIndex());
            }
            components.add(new SamplerPlusTree<Integer, float[]>(sampler, tree));
        }
        return new RandomCutForest(builder, coordinator, components, random);
    }

    /**
     * Rebuilds a forest from {@code state}, taking execution settings from the state itself.
     *
     * @param state the saved forest state
     * @param seed seed for the forest builder's random generator
     * @return the reconstructed forest
     */
    @Override
    public RandomCutForest toModel(RandomCutForestState state, long seed) {
        return toModel(state, null, seed);
    }

    /**
     * Rebuilds a forest from {@code state} through the interface's two-argument overload with a
     * {@code null} execution context, so the context stored in the state is used.
     *
     * @param state the saved forest state
     * @return the reconstructed forest
     */
    public RandomCutForest toModel(RandomCutForestState state) {
        return toModel(state, null);
    }

    /**
     * Builds a single-precision forest from {@code state}, optionally substituting externally
     * supplied components.
     *
     * <p>For each tree index, the sampler is taken from {@code extSamplers} when given. Otherwise it
     * is restored from the saved sampler state. The tree is chosen as follows:
     * <ol>
     *   <li>From {@code extTrees}, when given.</li>
     *   <li>Otherwise, from the saved tree state when present. The sampler's points are added with
     *       {@code addPointToPartialTree}, the tree's bounding-box cache fraction is applied, and the
     *       tree is validated.</li>
     *   <li>Otherwise, a new tree is built and the sampler's points are added to it.</li>
     * </ol>
     *
     * @param builder forest builder. Its random generator seeds the restored components. Note that
     *     this method sets the builder's precision to FLOAT_32.
     * @param state the saved forest state
     * @param extPointStore point store to use instead of the saved one, or {@code null}
     * @param extTrees trees to use instead of saved/new ones, or {@code null}; if non-null its size
     *     must equal the number of trees in {@code state}
     * @param extSamplers samplers to use instead of saved ones, or {@code null}; if non-null its
     *     size must equal the number of trees in {@code state}
     * @return the reconstructed forest
     * @throws IllegalArgumentException on a null builder, mismatched list sizes, or when neither
     *     external nor saved samplers / point store are available (assuming
     *     {@code CommonUtils.checkArgument} throws this type — confirm)
     */
    public RandomCutForest singlePrecisionForest(RandomCutForest.Builder<?> builder, RandomCutForestState state, IPointStore<Integer, float[]> extPointStore, List<ITree<Integer, float[]>> extTrees, List<IStreamSampler<Integer>> extSamplers) {
        CommonUtils.checkArgument(builder != null, "builder cannot be null");
        CommonUtils.checkArgument(extTrees == null || extTrees.size() == state.getNumberOfTrees(), "incorrect number of trees");
        CommonUtils.checkArgument(extSamplers == null || extSamplers.size() == state.getNumberOfTrees(), "incorrect number of samplers");
        // Short-circuit OR (the decompiled form used the non-short-circuit '|').
        CommonUtils.checkArgument(extSamplers != null || state.isSaveSamplerStateEnabled(), " need samplers ");
        CommonUtils.checkArgument(extPointStore != null || state.isSaveCoordinatorStateEnabled(), " need coordinator state ");

        Random random = builder.getRandom();
        ComponentList components = new ComponentList();

        IPointStore<Integer, float[]> pointStore;
        if (extPointStore != null) {
            pointStore = extPointStore;
        } else {
            pointStore = new PointStoreMapper().toModel(state.getPointStoreState());
        }
        PointStoreCoordinator<float[]> coordinator = new PointStoreCoordinator<float[]>(pointStore);
        coordinator.setTotalUpdates(state.getTotalUpdates());

        CompactRandomCutTreeContext context = new CompactRandomCutTreeContext();
        context.setPointStore(pointStore);
        context.setMaxSize(state.getSampleSize());

        RandomCutTreeMapper treeMapper = new RandomCutTreeMapper();
        List<CompactRandomCutTreeState> treeStates = state.isSaveTreeStateEnabled() ? state.getCompactRandomCutTreeStates() : null;
        CompactSamplerMapper samplerMapper = new CompactSamplerMapper();
        List<CompactSamplerState> samplerStates = state.isSaveSamplerStateEnabled() ? state.getCompactSamplerStates() : null;

        for (int i = 0; i < state.getNumberOfTrees(); ++i) {
            // A seed is consumed only when the sampler is restored from saved state.
            IStreamSampler<Integer> sampler = extSamplers != null
                    ? extSamplers.get(i)
                    : samplerMapper.toModel(samplerStates.get(i), random.nextLong());

            ITree<Integer, float[]> tree;
            if (extTrees != null) {
                tree = extTrees.get(i);
            } else if (treeStates != null) {
                RandomCutTree restored = treeMapper.toModel(treeStates.get(i), context, random.nextLong());
                sampler.getSample().forEach(s -> restored.addPointToPartialTree(s.getValue(), s.getSequenceIndex()));
                restored.setConfig("bounding_box_cache_fraction", treeStates.get(i).getBoundingBoxCacheFraction());
                restored.validateAndReconstruct();
                tree = restored;
            } else {
                RandomCutTree fresh = RandomCutTree.builder()
                        .capacity(state.getSampleSize())
                        .randomSeed(random.nextLong())
                        .pointStoreView(pointStore)
                        .boundingBoxCacheFraction(state.getBoundingBoxCacheFraction())
                        .centerOfMassEnabled(state.isCenterOfMassEnabled())
                        .storeSequenceIndexesEnabled(state.isStoreSequenceIndexesEnabled())
                        .build();
                sampler.getSample().forEach(s -> fresh.addPoint(s.getValue(), s.getSequenceIndex()));
                tree = fresh;
            }
            components.add(new SamplerPlusTree<Integer, float[]>(sampler, tree));
        }
        builder.precision(Precision.FLOAT_32);
        return new RandomCutForest(builder, coordinator, components, random);
    }

    // Lombok-generated accessors for the configuration flags above.

    @Generated
    public boolean isSaveTreeStateEnabled() {
        return this.saveTreeStateEnabled;
    }

    @Generated
    public boolean isSaveCoordinatorStateEnabled() {
        return this.saveCoordinatorStateEnabled;
    }

    @Generated
    public boolean isSaveSamplerStateEnabled() {
        return this.saveSamplerStateEnabled;
    }

    @Generated
    public boolean isSaveExecutorContextEnabled() {
        return this.saveExecutorContextEnabled;
    }

    @Generated
    public boolean isCompressionEnabled() {
        return this.compressionEnabled;
    }

    @Generated
    public boolean isPartialTreeStateEnabled() {
        return this.partialTreeStateEnabled;
    }

    @Generated
    public void setSaveTreeStateEnabled(boolean saveTreeStateEnabled) {
        this.saveTreeStateEnabled = saveTreeStateEnabled;
    }

    @Generated
    public void setSaveCoordinatorStateEnabled(boolean saveCoordinatorStateEnabled) {
        this.saveCoordinatorStateEnabled = saveCoordinatorStateEnabled;
    }

    @Generated
    public void setSaveSamplerStateEnabled(boolean saveSamplerStateEnabled) {
        this.saveSamplerStateEnabled = saveSamplerStateEnabled;
    }

    @Generated
    public void setSaveExecutorContextEnabled(boolean saveExecutorContextEnabled) {
        this.saveExecutorContextEnabled = saveExecutorContextEnabled;
    }

    @Generated
    public void setCompressionEnabled(boolean compressionEnabled) {
        this.compressionEnabled = compressionEnabled;
    }

    @Generated
    public void setPartialTreeStateEnabled(boolean partialTreeStateEnabled) {
        this.partialTreeStateEnabled = partialTreeStateEnabled;
    }
}

