// Copyright 2022-2024 Luminary Cloud, Inc. All Rights Reserved.
import { CallbackInterface, SetterOrUpdater } from 'recoil';

import { CadMetadata } from '../proto/cadmetadata/cadmetadata_pb';
import * as simulationpb from '../proto/client/simulation_pb';
import { EntityType } from '../proto/entitygroup/entitygroup_pb';
import * as feoutputpb from '../proto/frontend/output/output_pb';
import * as meshgenerationpb from '../proto/meshgeneration/meshgeneration_pb';
import { ViewState } from '../pvproto/ParaviewRpc';
import { EntityGroupData, entityGroupDataSelector } from '../recoil/entityGroupState';
import { extractSurfaceState } from '../recoil/extractSurfacesState';
import { GeometryState, geometryState } from '../recoil/geometry/geometryState';
import { GeometryTags } from '../recoil/geometry/geometryTagsObject';
import { geometryTagsState } from '../recoil/geometry/geometryTagsState';
import { FrontendGeometryContacts, geometryContactsStateSelector } from '../recoil/geometryContactsState';
import { lcVisEnabledSelector } from '../recoil/lcvis/lcvisEnabledState';
import { outputNodesState } from '../recoil/outputNodes';
import { SimulationSubselect } from '../recoil/simulationTreeSubselect';
import { cadMetadataState } from '../recoil/useCadMetadata';
import { enabledExperimentsState } from '../recoil/useExperimentConfig';
import { meshGenerationState } from '../recoil/useMeshGeneration';
import { meshingMultiPartState } from '../recoil/useMeshingMultiPart';
import { viewStateAtomFamily } from '../recoil/useViewState';
import { StaticVolume, staticVolumesState } from '../recoil/volumes';
import { currentConfigSelector } from '../recoil/workflowConfig';

import { createParamScope } from './ParamScope';
import {
  findFluidBoundaryCondition,
  findHeatBoundaryCondition,
  findPeriodicPairById,
} from './boundaryConditionUtils';
import { findContactById } from './contactsUtils';
import { makeSetterOrUpdater } from './contextUtils';
import { EntityGroupMap } from './entityGroupMap';
import { bucketGroupsByType, expandGroups, rollupGroups, unwrapSurfaceIds } from './entityGroupUtils';
import { getAssignedMaterialDomains, getPhysicsDomains } from './entityRelationships';
import { cadIdsToVolumeNodeIds } from './geometryUtils';
import { findHeatSourceById } from './heatSourceUtils';
import { convertToPvTreeNodeNames } from './imposterFilteringUtils';
import { subtractArray, toggleSetItem } from './lang';
import {
  conflictingMeshParamVolumes,
  nullableMeshing,
  sizeHeading,
} from './mesh';
import { findMultiphysicsInterfaceById } from './multiphysicsInterfaceUtils';
import { NodeTableIdentifier, NodeTableType, SURFACES_TABLES } from './nodeTableUtil';
import { findOutputNodeById, getOutputNodeWarnings } from './outputNodeUtils';
import { OutputGraphId } from './outputUtils';
import { RecoilProjectKey } from './persist';
import {
  PHYSICAL_BEHAVIOR_LABEL,
  getParticleGroupMapByPhysicalBehavior,
  setParticleGroupsForBehavior,
} from './physicalBehaviorUtils';
import { parsePhysicsIdFromSubId } from './physicsUtils';
import { findPorousModelById } from './porousModelUtils';
import { getSimulationParam } from './simulationParamUtils';
import { NodeType, SimulationTreeNode } from './simulationTree/node';
import {
  getNonDiskWarning,
  setSlidingInterfaceSurfaces,
} from './simulationUtils';
import { findSlidingInterfaceById } from './slidingInterfaceUtils';
import { wordsToList } from './text';
import {
  mapDomainsToIds,
  mapIdsToDomains,
  mapIndicestoIds,
  mapVolumeIdsToIndices,
} from './volumeUtils';

// Application state (and setters) that the selection logic in this file reads and writes.
// getSelectionContext assembles these fields from recoil state.
export type SelectionContextType = {
  // Current simulation parameters; several node tables edit a clone of this.
  param: simulationpb.SimulationParam,
  // ParaView view state. Its root is used to translate entity ids into ParaView tree node names
  // when LCVis is not enabled.
  viewState?: ViewState,
  setOutputNodes: SetterOrUpdater<feoutputpb.OutputNodes>,
  outputNodes: feoutputpb.OutputNodes,
  // Selection backing the EXTRACT_BOUNDARY node table.
  setExtractSurfacesList: SetterOrUpdater<string[]>,
  extractSurfacesList: string[],
  // Entity group map plus the leaf map used to expand groups into their members.
  entityGroupData: EntityGroupData,
  // Multi-part meshing params (may be null); edited by the MESHING_SIZE node table.
  meshMultiPart: nullableMeshing,
  setMeshMultiPart: SetterOrUpdater<nullableMeshing>,
  // Enabled experiments (see getSelectionContext).
  experimentConfig: string[],
  staticVolumes: StaticVolume[],
  contacts: FrontendGeometryContacts,
  // When true, highlighting uses raw entity ids instead of ParaView tree node names.
  lcvisEnabled?: boolean,
  geometryTags: GeometryTags,
  geometryState: GeometryState,
  cadMetadata: CadMetadata,
};

// Simulation-tree selection state plus the setters needed to change it.
export type NodeContextType = {
  // The single selected tree node, or null when zero or several nodes are selected.
  selectedNode: SimulationTreeNode | null;
  selectedNodeIds: string[];
  outputGraphId: OutputGraphId;
  setOutputGraphId: (id: OutputGraphId) => void;
  // The node table currently being edited (type NONE when no table is active).
  activeNodeTable: NodeTableIdentifier;
  setActiveNodeTable: (activeTable: NodeTableIdentifier) => void;
  // indicates that the tree is in a modal state, via NodeTable or NodeSubselect
  isTreeModal: boolean;
  nodeTableWarning: string;
  setNodeTableWarning: (warning: string) => void;
  highlightedInVisualizer: string[];
  setHighlightedInVisualizer: (ids: string[]) => void;
  highlightedInSimTree: Set<string>;
  // While `treeSubselect.active` is set, selections bypass node tables and go straight to
  // selectedNodeIds.
  treeSubselect: SimulationSubselect;
  // A hack to guarantee a smooth transition out of subselect mode
  setFinishingSubselectNodeIds: (ids: string[]) => void;
};

// Everything the selection helpers in this file need: tree selection state plus app state.
export type Contexts = {
  providedSelection: NodeContextType;
  selection: SelectionContextType;
};

// How an incoming list of ids is combined with the current selection. The numeric values are
// spelled out so that reordering members can never silently change the encoding.
export enum SelectionAction {
  // Add the ids to the current selection.
  ADD = 0,
  // Remove the ids from the current selection.
  SUBTRACT = 1,
  // Replace the current selection with the ids.
  OVERWRITE = 2,
  // Replace the current selection and also remove the ids from the other alternatives.
  OVERWRITE_EXCLUDE = 3,
  // Flip membership: ids already selected are removed, ids not yet selected are added.
  TOGGLE = 4,
  // Leave the selection alone; only work out which nodes should be highlighted for it.
  HIGHLIGHT_CURRENT = 5,
}

// The minimal subset of Contexts that getCurrentSelection reads. Callers that do not have a
// full Contexts object can pass an object with just these fields.
export interface GetCurrentSelectionContexts {
  providedSelection: Pick<
    NodeContextType,
    'activeNodeTable' | 'selectedNode' | 'selectedNodeIds' | 'treeSubselect'
  >;
  selection: Pick<
    SelectionContextType,
    'meshMultiPart' | 'experimentConfig' | 'outputNodes' | 'param' | 'extractSurfacesList' |
    'staticVolumes'
  >;
}

// Node tables whose selection must contain individual surfaces: any selected surface group is
// expanded into its members before being applied (see maybeUngroupSurfaces).
export const UNGROUP_TABLES = [
  NodeTableType.EXTRACT_BOUNDARY,
  NodeTableType.SLIDING_INTERFACE_SURFACES_A,
  NodeTableType.SLIDING_INTERFACE_SURFACES_B,
  NodeTableType.PHYSICAL_BEHAVIOR_ATTACH,
  NodeTableType.SENSITIVITY_SURFACES,
];

// Node tables whose selection consists of volumes. allowedSelection skips entity-type
// validation for these, and highlightedEntities maps their selection to volume domains.
export const VOLUME_TABLES = [
  NodeTableType.MESHING_SIZE,
];

// Shorthand for the meshing proto SelectionType enums, keyed by the kind of meshing params
// they belong to.
const SelectionTypes = {
  model: meshgenerationpb.MeshingMultiPart_ModelParams_SelectionType,
  boundary: meshgenerationpb.MeshingMultiPart_BoundaryLayerParams_SelectionType,
  volume: meshgenerationpb.MeshingMultiPart_VolumeParams_SelectionType,
};

// Returns the ids currently selected for the active node table.
//
// While the tree is in subselect mode, or when no node table is active, this is simply the
// selected node ids. Otherwise the selection is read back from whatever state the active table
// edits: an output node, the meshing volume params, the simulation param, or the
// extract-surfaces list. Throws for an unknown node table type, or when the physical-behavior
// table is active without a selected node.
export function getCurrentSelection(contexts: Contexts | GetCurrentSelectionContexts): string[] {
  const { providedSelection, selection } = contexts;
  const { activeNodeTable, selectedNode, selectedNodeIds } = providedSelection;
  const { meshMultiPart, outputNodes, param, staticVolumes } = selection;

  if (providedSelection.treeSubselect.active) {
    return selectedNodeIds;
  }

  switch (activeNodeTable.type) {
    case NodeTableType.NONE:
      return selectedNodeIds;

    case NodeTableType.POINTS:
    case NodeTableType.VOLUMES: {
      if (!selectedNode) {
        return [];
      }
      const outputNode = findOutputNodeById(outputNodes, selectedNode.id);
      return outputNode?.inSurfaces || [];
    }

    case NodeTableType.MESHING_SIZE: {
      if (!meshMultiPart) {
        return [];
      }
      const { volumes } = meshMultiPart.volumeParams[activeNodeTable.index!];
      return mapIndicestoIds(staticVolumes, volumes);
    }

    case NodeTableType.PHYSICAL_BEHAVIOR_ATTACH: {
      if (!selectedNode) {
        throw Error(`No selected node in ${PHYSICAL_BEHAVIOR_LABEL} attach`);
      }
      return Object.keys(getParticleGroupMapByPhysicalBehavior(param, selectedNode.id));
    }

    case NodeTableType.EXTRACT_BOUNDARY:
      return selection.extractSurfacesList || [];

    case NodeTableType.SLIDING_INTERFACE_SURFACES_A:
    case NodeTableType.SLIDING_INTERFACE_SURFACES_B: {
      const slidingInterface = selectedNode ?
        findSlidingInterfaceById(param, selectedNode.id) :
        null;
      const sideSurfaces = activeNodeTable.type === NodeTableType.SLIDING_INTERFACE_SURFACES_A ?
        slidingInterface?.slidingA :
        slidingInterface?.slidingB;
      return sideSurfaces || [];
    }

    case NodeTableType.SENSITIVITY_SURFACES:
      return param.adjoint?.surfaces ?? [];

    default:
      throw Error('Active node table is not defined.');
  }
}

/**
 * Apply some change to the current selection, treating every id as an independent entity (see
 * applyActionToLeafs for the group-aware variant).
 *
 * - OVERWRITE / OVERWRITE_EXCLUDE: the result is `modificationIds`.
 * - ADD: ids not yet selected are appended.
 * - SUBTRACT: ids in `modificationIds` are removed.
 * - TOGGLE: selected ids are removed and unselected ids are added.
 * - HIGHLIGHT_CURRENT: the current selection is returned unchanged.
 *
 * Duplicate ids in `modificationIds` are only added once. Neither input array is mutated.
 * @param currentSelection ids that are currently selected
 * @param modificationIds ids to combine with the current selection
 * @param action how to combine the two lists
 * @returns the new selection, sorted
 */
export function applyAction(
  currentSelection: string[],
  modificationIds: string[],
  action: SelectionAction,
): string[] {
  let newSelection: string[];
  if (action === SelectionAction.OVERWRITE || action === SelectionAction.OVERWRITE_EXCLUDE) {
    newSelection = modificationIds.slice();
  } else {
    const current = new Set(currentSelection);
    // Iterating a Set de-duplicates the modifications, so an id is never added twice.
    const modifications = new Set(modificationIds);
    const applyAdd = action === SelectionAction.ADD || action === SelectionAction.TOGGLE;
    const applyMinus = action === SelectionAction.SUBTRACT || action === SelectionAction.TOGGLE;
    newSelection = applyMinus ?
      currentSelection.filter((id) => !modifications.has(id)) :
      currentSelection.slice();
    if (applyAdd) {
      // Compare against the *original* selection: for TOGGLE, ids that were just removed must
      // not be added back.
      modifications.forEach((id) => {
        if (!current.has(id)) {
          newSelection.push(id);
        }
      });
    }
  }
  newSelection.sort();
  return newSelection;
}

/**
 * Where applyAction operates on node IDs as independent entities, applyActionToLeafs considers
 * groups. For instance if 'group1' = ['nodeA', 'nodeB'], then toggling 'group1' means (and is a
 * proxy for) toggling 'nodeA' and toggling 'nodeB'. Both the current selection and the
 * modification ids are expanded to leaf ids, so the result only ever contains leaves.
 *
 * OVERWRITE and OVERWRITE_EXCLUDE both replace the selection with the leaves of
 * `modificationIds`. HIGHLIGHT_CURRENT leaves the (expanded) selection unchanged.
 * @param currentSelection ids that are currently selected (groups or leaves)
 * @param modificationIds ids to combine with the current selection (groups or leaves)
 * @param action how to combine the two lists
 * @param entityGroupData provides `leafMap`, which maps a group id to its leaf ids
 * @returns the new selection as sorted leaf ids
 */
export function applyActionToLeafs(
  currentSelection: string[],
  modificationIds: string[],
  action: SelectionAction,
  entityGroupData: EntityGroupData,
): string[] {
  const { leafMap } = entityGroupData;
  const overwrite = [SelectionAction.OVERWRITE, SelectionAction.OVERWRITE_EXCLUDE].includes(action);

  // Overwriting starts from an empty selection; every other action edits the current leaves.
  const newSelectionSet = new Set(overwrite ? [] : expandGroups(leafMap)(currentSelection));

  modificationIds.forEach((id) => {
    const leafIds = leafMap.has(id) ? leafMap.get(id)! : [id];

    leafIds.forEach((leafId) => {
      switch (action) {
        case SelectionAction.ADD:
        case SelectionAction.OVERWRITE:
        // OVERWRITE_EXCLUDE must add the leaves too; otherwise it always yields an empty
        // selection because the set above starts empty.
        case SelectionAction.OVERWRITE_EXCLUDE:
          newSelectionSet.add(leafId);
          break;
        case SelectionAction.SUBTRACT:
          newSelectionSet.delete(leafId);
          break;
        case SelectionAction.TOGGLE:
          toggleSetItem(newSelectionSet, leafId);
          break;
        default:
        // no default
      }
    });
  });

  const newSelection = [...newSelectionSet];
  newSelection.sort();
  return newSelection;
}

// Entity types accepted by each node table, used for click evaluation and grouping. In
// allowedSelection, a node table without an entry here accepts no entity types (VOLUME_TABLES
// and the NONE table are handled before this map is consulted).
export const allowableEntityTypes = new Map<NodeTableType, EntityType[]>([
  [NodeTableType.PHYSICAL_BEHAVIOR_ATTACH, [EntityType.PARTICLE_GROUP]],
  [NodeTableType.POINTS, [EntityType.PROBE_POINTS]],
  [NodeTableType.VOLUMES, [EntityType.VOLUME, EntityType.BODY_TAG]],
  [NodeTableType.EXTRACT_BOUNDARY, [EntityType.SURFACE, EntityType.PARTICLE_GROUP]],
  [NodeTableType.SLIDING_INTERFACE_SURFACES_A, [EntityType.SURFACE]],
  [NodeTableType.SLIDING_INTERFACE_SURFACES_B, [EntityType.SURFACE]],
  [NodeTableType.SENSITIVITY_SURFACES, [EntityType.SURFACE]],
]);

// Returns a user-facing warning naming why `groupId` cannot be added to a node table of type
// `nodeTableType`, or an empty string when the group's own entity type is allowed.
// For a group with children, the warning lists the children whose entity type is not allowed.
export const groupsNotAllowedWarning = (
  nodeTableType: NodeTableType,
  groupId: string,
  entityGroupMap: EntityGroupMap,
) => {
  const group = entityGroupMap.get(groupId);
  const allowedEntityTypes = allowableEntityTypes.get(nodeTableType);
  const isAllowed = (entityType: EntityType) => !!allowedEntityTypes?.includes(entityType);

  if (isAllowed(group.entityType)) {
    return '';
  }
  if (group.children.size) {
    const groupNames = Array.from(group.children)
      .map((childId) => entityGroupMap.get(childId))
      .filter((child) => !isAllowed(child.entityType))
      .map((child) => child.name);
    // Only name children when there is something to name; otherwise fall through to the
    // generic message instead of producing "because  is incompatible".
    if (groupNames.length) {
      // Kept on one line: a multi-line template literal would embed a newline and indentation
      // in the message shown to the user.
      const verb = groupNames.length > 1 ? 'are' : 'is';
      return `Cannot add ${group.name} because ${wordsToList(groupNames)} ${verb} incompatible.`;
    }
  }
  return `Cannot add ${group.name} because it is incompatible.`;
};

// Checks whether selecting `id` is allowed with the current active node table. Returns a
// user-facing warning message if the node is not allowed, otherwise returns an empty string.
//
// `selectedNodeId`, `param` and `staticVolumes` are currently unused; they are kept so the
// signature stays stable for callers.
export const allowedSelection = (
  id: string,
  selectedNodeId: string | undefined,
  nodeTableType: NodeTableType,
  param: simulationpb.SimulationParam,
  entityGroupMap: EntityGroupMap,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
) => {
  // Without an active node table any node may be selected.
  if (nodeTableType === NodeTableType.NONE) {
    return '';
  }

  if (VOLUME_TABLES.includes(nodeTableType)) {
    // Volume tables expect a Volume which has a number as its ID.
    return '';
  }
  const entityNode = entityGroupMap.has(id) && entityGroupMap.get(id);
  const allowedEntityType = (
    entityNode &&
    allowableEntityTypes.get(nodeTableType)?.includes(entityNode.entityType)
  );
  if (!allowedEntityType) {
    if (entityNode) {
      return groupsNotAllowedWarning(nodeTableType, entityNode.id, entityGroupMap);
    }
    return 'Cannot add item.';
  }

  // Validate that we can add a tag to the current selection based on whether we need surfaces or
  // volumes. The counts are compared with `=== 0` directly: the previous `n && n === 0` could
  // never be true (0 is falsy), so these warnings never fired. An undefined count (the lookup
  // returned nothing) is still not treated as an empty tag.
  if (geometryTags.isTagId(entityNode.id)) {
    const nBodies = geometryTags.domainsFromTagEntityGroupId(entityNode.id)?.length;
    const nFaces = geometryTags.surfacesFromTagEntityGroupId(entityNode.id)?.length;
    // NOTE: volume tables return early above, so this check is currently unreachable; it is kept
    // correct in case volume tables start going through entity validation.
    if (VOLUME_TABLES.includes(nodeTableType) && nBodies === 0) {
      return 'Cannot add a tag without volumes.';
    }
    if (SURFACES_TABLES.includes(nodeTableType) && nFaces === 0) {
      return 'Cannot add a tag without surfaces.';
    }
  }

  return '';
};

// Applies a new selection to whichever node table is currently active. If no table is active
// (or the tree is in subselect mode) it changes the selected node IDs instead. If `isExclusive`
// is set, tables that support it also remove the selection from the other alternatives.
//
// Returns a list of user-facing warnings; a non-empty list means the new selection was not
// applied. Throws if state required by the active node table is missing.
export function applyNewSelection(
  newSelection: string[],
  contexts: Contexts,
  isExclusive: boolean,
  setSelectedNodeIds: (nodeIds: string[]) => void,
  onParamUpdate: (newParam: simulationpb.SimulationParam) => void,
): string[] {
  const {
    entityGroupData,
    meshMultiPart,
    outputNodes,
    param,
    setExtractSurfacesList,
    setMeshMultiPart,
    setOutputNodes,
    staticVolumes,
    geometryTags,
  } = contexts.selection;
  const {
    activeNodeTable,
    selectedNode,
    setNodeTableWarning,
    treeSubselect,
  } = contexts.providedSelection;

  // In subselect mode the selection goes straight to the tree; node tables are not involved.
  if (treeSubselect.active) {
    setSelectedNodeIds(newSelection);
    return [];
  }

  const warnings = newSelection.map(
    (id) => allowedSelection(
      id,
      selectedNode?.id,
      activeNodeTable.type,
      param,
      entityGroupData.groupMap,
      geometryTags,
      staticVolumes,
    ),
  ).filter((warning) => warning.length);

  // If any node is not allowed for the active node table, reject the whole selection.
  if (warnings.length) {
    return warnings;
  }

  switch (activeNodeTable.type) {
    case NodeTableType.NONE:
      setSelectedNodeIds(newSelection);
      break;
    case NodeTableType.POINTS:
    case NodeTableType.VOLUMES: {
      const outputNode = selectedNode ? findOutputNodeById(outputNodes, selectedNode.id) : null;
      if (!outputNode) {
        throw Error('Error updating output');
      }
      const outputNodeWarning = getNonDiskWarning(
        newSelection,
        outputNode,
        param,
        entityGroupData.groupMap,
      );
      if (outputNodeWarning) {
        return [outputNodeWarning];
      }
      // Only update the table if no warning should be shown
      setOutputNodes((oldOutputNodes: feoutputpb.OutputNodes) => {
        // Look the node up by id in the state being updated rather than by object identity. The
        // updater may receive a newer OutputNodes than the one captured in `contexts`; in that
        // case indexOf(outputNode) would be -1 and the update would be silently dropped.
        const index = oldOutputNodes.nodes.findIndex((node) => node.id === outputNode.id);
        if (index < 0) {
          return oldOutputNodes;
        }
        const outputList = oldOutputNodes.nodes.slice();
        const newOutput = outputList[index].clone();
        newOutput.inSurfaces = newSelection;
        outputList[index] = newOutput;
        const newOutputNodes = oldOutputNodes.clone();
        newOutputNodes.nodes = outputList;
        return newOutputNodes;
      });
      break;
    }
    case NodeTableType.MESHING_SIZE: {
      if (!meshMultiPart) {
        return [];
      }
      const index = activeNodeTable.index!;
      // Presumably the indices of other volume params that share volumes with the new selection.
      const overlapIndices = conflictingMeshParamVolumes(
        index,
        newSelection,
        meshMultiPart,
        staticVolumes,
      );
      const headings = overlapIndices.map((idx) => sizeHeading(idx));
      setNodeTableWarning(
        overlapIndices.length ? `This will remove some volumes from ${wordsToList(headings)}.` : '',
      );
      setMeshMultiPart((oldMeshMultiPart: nullableMeshing) => {
        if (!oldMeshMultiPart) {
          return null;
        }
        const newMultiPart = oldMeshMultiPart.clone();
        const paramsList = newMultiPart.volumeParams;
        const indices = mapVolumeIdsToIndices(newSelection, staticVolumes);
        paramsList[index].volumes = indices;
        if (isExclusive) {
          paramsList.forEach((params, i) => {
            if (i !== index) {
              params.volumes = subtractArray(params.volumes, indices);
            }
          });
          // Switch the default params (index 0) to SELECTED if any of its volumes were removed.
          if (overlapIndices.includes(0)) {
            paramsList[0].selection = SelectionTypes.volume.SELECTED;
          }
        }
        newMultiPart.volumeParams = paramsList;
        return newMultiPart;
      });
      break;
    }
    case NodeTableType.PHYSICAL_BEHAVIOR_ATTACH: {
      if (!selectedNode || !onParamUpdate) {
        throw Error(`Error updating ${PHYSICAL_BEHAVIOR_LABEL} attach`);
      }
      const newParam = param.clone();
      const warning = setParticleGroupsForBehavior(
        newSelection,
        selectedNode.id,
        newParam,
      );
      if (warning) {
        return [warning];
      }
      onParamUpdate(newParam);
      break;
    }
    case NodeTableType.EXTRACT_BOUNDARY: {
      if (!setExtractSurfacesList) {
        throw Error('Error updating Extract Boundary');
      }
      setExtractSurfacesList(newSelection);
      break;
    }
    case NodeTableType.SLIDING_INTERFACE_SURFACES_A:
    case NodeTableType.SLIDING_INTERFACE_SURFACES_B: {
      const sideA = (activeNodeTable.type === NodeTableType.SLIDING_INTERFACE_SURFACES_A);

      if (!onParamUpdate) {
        throw Error('Error updating motion interfaces');
      }
      if (!selectedNode) {
        throw Error('No selected node');
      }
      const newParam = param.clone();
      const warning = setSlidingInterfaceSurfaces(
        newSelection,
        selectedNode.id,
        sideA,
        newParam,
        geometryTags,
        staticVolumes,
        entityGroupData,
        false,
      );
      // Only side A publishes (or clears) the node table warning.
      if (sideA) {
        setNodeTableWarning(warning);
      }
      if (warning) {
        return [warning];
      }
      onParamUpdate(newParam);
      break;
    }

    case NodeTableType.SENSITIVITY_SURFACES: {
      const newParam = param.clone();
      // getCurrentSelection treats `adjoint` as optional, so fail with a clear error instead of
      // a TypeError when it is missing.
      if (!newParam.adjoint) {
        throw Error('Error updating sensitivity surfaces');
      }
      newParam.adjoint.surfaces = newSelection;
      onParamUpdate(newParam);
      break;
    }
    default:
      throw Error('Active node table is not defined.');
  }
  return [];
}

// maybeUngroupSurfaces expands groups into their leaf ids when the active node table is one of
// UNGROUP_TABLES, i.e. a table whose selection must contain individual surfaces. Geometry tag
// ids are not expanded: they are passed through unchanged and appended after the expanded ids.
// For every other node table the list (which may mix surfaces and surface groups) is returned
// as-is.
export function maybeUngroupSurfaces(
  modificationIds: string[],
  contexts: Contexts,
): string[] {
  const { activeNodeTable } = contexts.providedSelection;
  if (!UNGROUP_TABLES.includes(activeNodeTable.type)) {
    return modificationIds;
  }

  const { entityGroupData, geometryTags } = contexts.selection;
  const tagIds: string[] = [];
  const otherIds: string[] = [];
  modificationIds.forEach((id) => {
    (geometryTags.isTagId(id) ? tagIds : otherIds).push(id);
  });
  const expanded = expandGroups(entityGroupData.leafMap)(otherIds);
  return [...expanded, ...tagIds];
}

// Node types that can be selected together using shift or ctrl. Entity types listed in the
// same inner array are considered compatible with each other.
const multiSelectGroups = [
  [EntityType.SURFACE, EntityType.MIXED, EntityType.PARTICLE_GROUP],
  [EntityType.PROBE_POINTS],
  [EntityType.VOLUME],
];
// Maps each entity type in multiSelectGroups to the full group it belongs to, so the set of
// types compatible with a given type can be looked up directly.
export const allowableMultiSelect = new Map(
  multiSelectGroups.flatMap((group) => group.map((type) => [type, group])),
);

// Converts a list of group (or leaf) ids into the ids the visualizer highlights.
// Groups are expanded to their leaves and the leaves are bucketed by entity type, because each
// type is translated differently: surfaces are used as-is, volumes and body tags become
// domains, face tags become surfaces, and particle groups / monitor planes / probe points are
// used as-is with LCVis or converted to ParaView tree node names otherwise.
export const highlightGroups = (
  groups: string[],
  contexts: Contexts,
) => {
  const { selection } = contexts;
  const { entityGroupData, geometryTags, staticVolumes } = selection;

  const leafIds = expandGroups(entityGroupData.leafMap)(groups);
  const byType = bucketGroupsByType(leafIds, entityGroupData.groupMap);
  const result: string[] = [];

  const surfaceIds = byType.get(EntityType.SURFACE);
  if (surfaceIds) {
    result.push(...surfaceIds);
  }

  const volumeIds = byType.get(EntityType.VOLUME);
  if (volumeIds) {
    result.push(...mapIdsToDomains(staticVolumes, volumeIds));
  }

  // Tags that cannot be resolved are skipped.
  if (byType.has(EntityType.FACE_TAG)) {
    (byType.get(EntityType.FACE_TAG) ?? []).forEach((faceTag) => {
      const surfaceId = geometryTags.surfaceFromTagEntityGroupId(faceTag);
      if (surfaceId) {
        result.push(surfaceId);
      }
    });
  }
  if (byType.has(EntityType.BODY_TAG)) {
    byType.get(EntityType.BODY_TAG)!.forEach((bodyTag) => {
      const domainId = geometryTags.domainFromTagEntityGroupId(bodyTag);
      if (domainId) {
        result.push(domainId);
      }
    });
  }

  const imposterIds = [
    ...(byType.get(EntityType.PARTICLE_GROUP) ?? []),
    ...(byType.get(EntityType.MONITOR_PLANE) ?? []),
    ...(byType.get(EntityType.PROBE_POINTS) ?? []),
  ];
  const { lcvisEnabled, viewState } = selection;
  if (lcvisEnabled) {
    // LCVis understands entity ids directly.
    result.push(...imposterIds);
  } else if (viewState?.root && imposterIds.length > 0) {
    // ParaView needs its own tree node names for these entities.
    result.push(...convertToPvTreeNodeNames(imposterIds, viewState.root));
  }
  return result;
};

// Returns a list of entity ids that should be highlighted in the visualizer based on the
// currently selected node or active node table.
//
// - If a volume node table (VOLUME_TABLES) is active, the selected node ids are mapped to volume
//   domains.
// - If there is no single selected node, every selected id present in the entity group map is
//   highlighted through highlightGroups.
// - Otherwise the result depends on the selected node's type: associated surfaces, domains,
//   particle groups or the node itself are resolved from the simulation param, output nodes,
//   geometry contacts, geometry tags or geometry features. Unknown types fall back to
//   highlighting the node if it is in the entity group map.
export const highlightedEntities = (
  newSelectedNode: SimulationTreeNode | null,
  selectedNodeIds: string[],
  contexts: Contexts,
) => {
  const {
    entityGroupData,
    outputNodes,
    param: simParam,
    viewState,
    contacts,
    lcvisEnabled,
    staticVolumes,
    geometryTags,
    geometryState: geometryStateIn,
    cadMetadata,
  } = contexts.selection;
  const { activeNodeTable } = contexts.providedSelection;

  // Expands group ids into leaf ids using the entity group leaf map.
  const expandGroupsCallback = expandGroups(entityGroupData.leafMap);

  // For some node tables we have to translate the current selection to the entitity ids
  if (VOLUME_TABLES.includes(activeNodeTable.type)) {
    return mapIdsToDomains(staticVolumes, selectedNodeIds);
  }

  // If more than one node is in the new selection and no node table is active that requires special
  // treatment we highlight those that can be found in the group map.
  if (!newSelectedNode) {
    return highlightGroups(
      selectedNodeIds.filter((id) => entityGroupData.groupMap.has(id)),
      contexts,
    );
  }
  // For a single selected node we usually highlight other associated entities based on the
  // node type
  switch (newSelectedNode.type) {
    case NodeType.OUTPUT: {
      // Highlight the output's input surfaces, but only when the output has no warnings.
      const output = outputNodes.nodes.find(
        (outputA: feoutputpb.OutputNode) => outputA.id === newSelectedNode.id,
      );
      const paramScope = createParamScope(simParam, []);
      if (output && !getOutputNodeWarnings(
        output,
        outputNodes,
        simParam,
        entityGroupData,
        paramScope,
        staticVolumes,
        geometryTags,
        outputNodes.referenceValues?.referenceValueType,
      ).length) {
        const outputIds = output.inSurfaces;
        return highlightGroups(outputIds, contexts);
      }
      return [];
    }
    case NodeType.SURFACE:
    case NodeType.SURFACE_GROUP:
    case NodeType.MONITOR_PLANE:
    case NodeType.PARTICLE_GROUP:
    case NodeType.PROBE_POINT:
    case NodeType.TAGS_FACE:
      return highlightGroups([newSelectedNode.id], contexts);
    case NodeType.VOLUME:
      // For volumes, the id is of the form 'volume-' followed by the index.
      // Paraview is expecting the string of the volume index.
      return mapIdsToDomains(staticVolumes, [newSelectedNode.id]);
    case NodeType.TAGS_BODY: {
      const domain = geometryTags.domainFromTagEntityGroupId(newSelectedNode.id);
      return domain ? [domain] : [];
    }
    case NodeType.FAR_FIELD: {
      // Highlight every leaf under the far-field node's children.
      const childIds = newSelectedNode.children.map((node) => node.id);
      return expandGroupsCallback(childIds);
    }
    case NodeType.PHYSICS_FLUID_BOUNDARY_CONDITION:
    case NodeType.PHYSICS_HEAT_BOUNDARY_CONDITION: {
      const boundary = (
        newSelectedNode.type === NodeType.PHYSICS_HEAT_BOUNDARY_CONDITION ?
          findHeatBoundaryCondition(simParam, newSelectedNode.id) :
          findFluidBoundaryCondition(simParam, newSelectedNode.id)
      );
      return unwrapSurfaceIds(boundary?.surfaces || [], geometryTags, entityGroupData);
    }
    case NodeType.PHYSICAL_BEHAVIOR:
      // The ParaView tree root is only needed when LCVis is disabled.
      if (viewState?.root || lcvisEnabled) {
        // NOTE(review): porous-model zone ids are returned as-is here, whereas the POROUS_MODEL
        // case expands tag zone ids into volume ids; confirm whether this should match.
        const porousModel = findPorousModelById(simParam, newSelectedNode.id);
        if (porousModel) {
          return porousModel.zoneIds;
        }
        // A physical behavior can have multiple particle groups attached to it,
        // so we have to find all the ids of the groups attached to the behavior.
        // Then, find all the imposters associated with the particle groups.
        const particleGroupMap = getParticleGroupMapByPhysicalBehavior(
          simParam,
          newSelectedNode.id,
        );
        const particleGroupIds = Object.keys(particleGroupMap);
        if (lcvisEnabled) {
          return particleGroupIds;
        }
        // Safe: the enclosing condition guarantees viewState.root once lcvisEnabled is false.
        return convertToPvTreeNodeNames(particleGroupIds, viewState!.root);
      }
      return [];
    case NodeType.POROUS_MODEL: {
      // Zone ids that resolve to tag domains are expanded to volume ids; other zone ids are
      // returned unchanged.
      const porousModel = findPorousModelById(simParam, newSelectedNode.id);
      const volumes = porousModel?.zoneIds.flatMap((zoneId) => {
        const domains = geometryTags.domainsFromTag(zoneId);
        if (domains.length > 0) {
          return mapDomainsToIds(staticVolumes, domains);
        }
        return zoneId;
      });
      return volumes ?? [];
    }
    case NodeType.PHYSICS_PERIODIC_PAIR: {
      // Highlight boundA and boundB from the periodic pair proto definition.  Note that the
      // periodic bounds do not use groups yet.
      const pair = findPeriodicPairById(simParam, newSelectedNode.id);
      if (pair) {
        return unwrapSurfaceIds([...pair.boundA, ...pair.boundB], geometryTags, entityGroupData);
      }
      return [];
    }
    case NodeType.PHYSICS_SLIDING_INTERFACE: {
      // Highlight all the surfaces defined on sides A and B of interfaces. Note that these
      // interfaces allow inputting groups and so we have to expand such groups since ParaView
      // is not aware of the grouping.
      const slidingInterface = findSlidingInterfaceById(simParam, newSelectedNode.id);
      const surfaceList = slidingInterface ? [
        ...slidingInterface.slidingA, ...slidingInterface.slidingB,
      ] : [];
      return unwrapSurfaceIds(surfaceList, geometryTags, entityGroupData);
    }
    case NodeType.GEOMETRY_CONTACT: {
      // Highlight all the surfaces defined on sides A and B of interfaces. Note that these
      // interfaces allow inputting groups and so we have to expand such groups since ParaView
      // is not aware of the grouping.
      const contact = findContactById(contacts, newSelectedNode.id);
      const surfaceList = [
        ...contact?.sideA || [],
        ...contact?.sideB || [],
      ];
      return expandGroupsCallback(surfaceList);
    }
    case NodeType.FILTER:
      return [newSelectedNode.id];
    case NodeType.PHYSICS_FLUID:
    case NodeType.PHYSICS_HEAT: {
      return [...getPhysicsDomains(simParam, newSelectedNode.id, geometryTags, staticVolumes)];
    }
    case NodeType.PHYSICS_VOLUME_SELECTION: {
      // The node id is a sub-id; recover the owning physics id first.
      const physicsId = parsePhysicsIdFromSubId(newSelectedNode.id);
      return [...getPhysicsDomains(simParam, physicsId, geometryTags, staticVolumes)];
    }
    case NodeType.PHYSICS_HEAT_HEAT_SOURCE: {
      // Same zone-id expansion as POROUS_MODEL: tag zones become volume ids, others pass through.
      const zoneIds = findHeatSourceById(simParam, newSelectedNode.id)?.heatSourceZoneIds || [];
      return zoneIds.flatMap((zoneId) => {
        const domains = geometryTags.domainsFromTag(zoneId);
        if (domains.length > 0) {
          return mapDomainsToIds(staticVolumes, domains);
        }
        return zoneId;
      });
    }
    case NodeType.PHYSICS_MULTI_INTERFACE: {
      const couplingInterface = findMultiphysicsInterfaceById(simParam, newSelectedNode.id);
      return unwrapSurfaceIds([
        ...couplingInterface?.slidingA || [],
        ...couplingInterface?.slidingB || [],
      ], geometryTags, entityGroupData);
    }
    case NodeType.MATERIAL_FLUID:
    case NodeType.MATERIAL_SOLID: {
      const domains = getAssignedMaterialDomains(simParam, geometryTags, newSelectedNode.id, true);
      return Array.from(domains);
    }
    case NodeType.GEOMETRY_MODIFICATION: {
      // Highlight the domains of the CAD bodies (and, for subtraction/chop, the tools) that the
      // geometry feature operates on.
      const features = geometryStateIn.geometryFeatures;
      const feature = features.find((feat) => feat.id === newSelectedNode.id);
      if (!feature) {
        return [];
      }
      switch (feature.operation.case) {
        case 'boolean':
          switch (feature.operation.value.op.case) {
            case 'regUnion':
            case 'regIntersection': {
              const cadIds = feature.operation.value.op.value.bodies;
              const volumeIds = cadIdsToVolumeNodeIds(cadIds, staticVolumes, cadMetadata);
              return mapIdsToDomains(staticVolumes, volumeIds);
            }
            case 'regSubtraction':
            case 'regChop': {
              const cadIdsBodies = feature.operation.value.op.value.bodies;
              const cadIdsTools = feature.operation.value.op.value.tools;
              const volumeIdsBodies =
                cadIdsToVolumeNodeIds(cadIdsBodies, staticVolumes, cadMetadata);
              const volumeIdsTools = cadIdsToVolumeNodeIds(cadIdsTools, staticVolumes, cadMetadata);
              const domainsBodies = mapIdsToDomains(staticVolumes, volumeIdsBodies);
              const domainsTools = mapIdsToDomains(staticVolumes, volumeIdsTools);
              // De-duplicate domains shared by bodies and tools.
              return Array.from(new Set([...domainsBodies, ...domainsTools]));
            }
            default:
              return [];
          }
        case 'transform': {
          const cadIds = feature.operation.value.body;
          const volumeIds = cadIdsToVolumeNodeIds(cadIds, staticVolumes, cadMetadata);
          return mapIdsToDomains(staticVolumes, volumeIds);
        }
        default:
          break;
      }
      return [];
    }
    default:
      // By default highlight the selected node if it can be found in the group map.
      if (entityGroupData.groupMap.has(newSelectedNode.id)) {
        return highlightGroups([newSelectedNode.id], contexts);
      }
      return [];
  }
};

/**
 * Collect the Recoil state used by the selection helpers into one `SelectionContextType`.
 *
 * Every value is read through `cbInterface.snapshot.getPromise`, so the result reflects the
 * snapshot the callback was invoked with. Setters for the output nodes, meshing multi-part and
 * extract-surfaces state are built from `cbInterface.set` via `makeSetterOrUpdater`.
 *
 * @param recoilKey Project key passed to the project-scoped selectors (its `projectId` is used
 *   for the families keyed by project).
 * @param cbInterface Recoil callback interface supplying the snapshot and `set`.
 * @param geometryId Geometry whose `geometryState` entry is loaded.
 * @returns The assembled selection context.
 */
export const getSelectionContext = async (
  recoilKey: RecoilProjectKey,
  cbInterface: CallbackInterface,
  geometryId: string,
): Promise<SelectionContextType> => {
  const { snapshot: { getPromise }, set } = cbInterface;
  const { projectId } = recoilKey;

  // Output nodes are read and written through the same family entry (empty workflow/job ids),
  // so build the key once and reuse it for both the getter and the setter.
  const outputNodesAtom = outputNodesState({ projectId, workflowId: '', jobId: '' });
  const setOutputNodes = makeSetterOrUpdater(set, outputNodesAtom);
  const setMeshMultiPart = makeSetterOrUpdater(set, meshingMultiPartState(projectId));
  const setExtractSurfacesList = makeSetterOrUpdater(set, extractSurfaceState);

  // None of these reads depends on another, so resolve them concurrently instead of awaiting
  // each selector in sequence.
  const [
    outputNodes,
    config,
    meshGeneration,
    meshMultiPart,
    entityGroupData,
    extractSurfacesList,
    viewState,
    experimentConfig,
    staticVolumes,
    contacts,
    lcvisEnabled,
    geometryTags,
    geometryStateRecoil,
    cadMetadata,
  ] = await Promise.all([
    getPromise(outputNodesAtom),
    getPromise(currentConfigSelector(recoilKey)),
    getPromise(meshGenerationState(projectId)),
    getPromise(meshingMultiPartState(projectId)),
    getPromise(entityGroupDataSelector(recoilKey)),
    getPromise(extractSurfaceState),
    getPromise(viewStateAtomFamily(projectId)),
    getPromise(enabledExperimentsState),
    getPromise(staticVolumesState(projectId)),
    getPromise(geometryContactsStateSelector({ projectId })),
    getPromise(lcVisEnabledSelector(projectId)),
    getPromise(geometryTagsState({ projectId })),
    getPromise(geometryState({ projectId, geometryId })),
    getPromise(cadMetadataState(projectId)),
  ]);

  const param = getSimulationParam(config);

  return {
    param,
    outputNodes,
    setOutputNodes,
    meshGeneration,
    entityGroupData,
    setExtractSurfacesList,
    extractSurfacesList,
    // Normalize a null view state to undefined for consumers of the context.
    viewState: viewState ?? undefined,
    meshMultiPart,
    setMeshMultiPart,
    experimentConfig,
    staticVolumes,
    contacts,
    lcvisEnabled,
    geometryTags,
    geometryState: geometryStateRecoil,
    cadMetadata,
  } as SelectionContextType;
};

/**
 * Apply a tree selection change while a subselect is active, notify the subselect through its
 * `onChange` callback, and return the selection to use for highlighting.
 *
 * @param treeSubselect Active subselect; its `independentSelection` flag picks the mode and its
 *   `onChange` receives the new selection.
 * @param currentSelection IDs selected before this change.
 * @param modificationIds IDs the action is applied to.
 * @param action Selection action to apply.
 * @param entityGroupData Group data used to expand to leaves and roll leaves back into groups.
 * @returns In independent mode, the same IDs passed to `onChange`; otherwise the leaf IDs passed
 *   to `onChange`, rolled up into groups.
 */
export function applySubselections(
  treeSubselect: SimulationSubselect,
  currentSelection: string[],
  modificationIds: string[],
  action: SelectionAction,
  entityGroupData: EntityGroupData,
) {
  const isIndependent = treeSubselect.independentSelection;

  if (!isIndependent) {
    // Grouped mode: a group stands in for its descendants. Selecting a group selects only its
    // leaf descendants, and a group whose leaves are all selected is shown as selected. The
    // subselect is told about leaf IDs; the caller gets them rolled up into groups so the
    // highlighting can show whole groups.
    const leafSelection = applyActionToLeafs(
      currentSelection,
      modificationIds,
      action,
      entityGroupData,
    );
    treeSubselect.onChange(leafSelection);
    return rollupGroups(entityGroupData)(leafSelection);
  }

  // Independent mode: groups and their children are separate selectable items (e.g. a surface
  // group and one of its surfaces in surface outputs), so no expansion or rollup happens.
  const plainSelection = applyAction(currentSelection, modificationIds, action);
  treeSubselect.onChange(plainSelection);
  return plainSelection;
}
