diff --git a/src/Selection.tsx b/src/Selection.tsx index 5e3ae36..43ef77e 100644 --- a/src/Selection.tsx +++ b/src/Selection.tsx @@ -1,6 +1,7 @@ import * as THREE from 'three' -import React, { createContext, useState, useContext, useEffect, useRef, useMemo } from 'react' -import { type ThreeElements } from '@react-three/fiber' +import type React from 'react' +import { createContext, useState, useContext, useEffect, useRef, useMemo } from 'react' +import type { ThreeElements } from '@react-three/fiber' export type Api = { selected: THREE.Object3D[] @@ -11,33 +12,39 @@ export type SelectApi = Omit & { enabled?: boolean } -export const selectionContext = /* @__PURE__ */ createContext(null) +export const selectionContext = /* @__PURE__ */ createContext({ + select: () => {}, + enabled: true, + selected: [] +}) export function Selection({ children, enabled = true }: { enabled?: boolean; children: React.ReactNode }) { const [selected, select] = useState([]) - const value = useMemo(() => ({ selected, select, enabled }), [selected, select, enabled]) + const value = useMemo(() => ({ selected, select, enabled }), [selected, enabled]) return {children} } export function Select({ enabled = false, children, ...props }: SelectApi) { - const group = useRef(null!) - const api = useContext(selectionContext) + const group = useRef(new THREE.Group()) + const {select = () => {}} = useContext(selectionContext) + useEffect(() => { - if (api && enabled) { - let changed = false - const current: THREE.Object3D[] = [] - group.current.traverse((o) => { - o.type === 'Mesh' && current.push(o) - if (api.selected.indexOf(o) === -1) changed = true - }) - if (changed) { - api.select((state) => [...state, ...current]) - return () => { - api.select((state) => state.filter((selected) => !current.includes(selected))) - } - } + if (!enabled || !group.current) return + + const current: THREE.Object3D[] = [] + group.current.traverse((o) => { + if (o.type === 'Mesh') current.push(o) + }) + + select((prev) => { + const notIncluded = current.filter(obj => !prev.includes(obj)) + return notIncluded.length > 0 ? [...prev, ...notIncluded] : prev + }) + + return () => { + select((prev) => prev.filter(obj => !current.includes(obj))) } - }, [enabled, children, api]) + }, [enabled, select, children]) return ( {children}