@@ -6,6 +6,15 @@ import {FlattenedNode} from './shapes/nodeShapes';
6
6
import TreeState , { State } from './state/TreeState' ;
7
7
8
8
export default class Tree extends React . Component {
9
+ constructor ( props ) {
10
+ super ( props ) ;
11
+ this . state = {
12
+ stickyHeaders : [ ] , // To track all visible group headers
13
+ topStickyHeader : null , // The header that should be sticky
14
+ } ;
15
+ this . _listRef = React . createRef ( ) ;
16
+ }
17
+
9
18
_cache = new CellMeasurerCache ( {
10
19
fixedWidth : true ,
11
20
minHeight : 20 ,
@@ -35,8 +44,101 @@ export default class Tree extends React.Component {
35
44
: nodes [ index ] ;
36
45
} ;
37
46
47
+ // Determine if a node is a group header
48
+ isGroupHeader = node => {
49
+ // Group headers are typically parent nodes with children
50
+ // and deepness of 0 (root level)
51
+ return node . children && node . children . length > 0 && node . deepness === 0 ;
52
+ } ;
53
+
54
+ componentDidMount ( ) {
55
+ // Initial check for headers after mounting
56
+ if ( this . _listRef . current ) {
57
+ const list = this . _listRef . current ;
58
+ const grid = list && list . Grid ;
59
+ if ( grid ) {
60
+ this . handleScroll ( {
61
+ scrollTop : grid . state . scrollTop ,
62
+ scrollHeight : grid . state . scrollHeight ,
63
+ clientHeight : grid . props . height ,
64
+ } ) ;
65
+ }
66
+ }
67
+ }
68
+
69
+ // Get all headers in the current data
70
+ getAllHeaders = ( ) => {
71
+ const rowCount = this . getRowCount ( ) ;
72
+ const headers = [ ] ;
73
+
74
+ for ( let i = 0 ; i < rowCount ; i ++ ) {
75
+ const node = this . getNode ( i ) ;
76
+ if ( this . isGroupHeader ( node ) ) {
77
+ // Calculate the position by summing heights of all rows before this one
78
+ let top = 0 ;
79
+ for ( let j = 0 ; j < i ; j ++ ) {
80
+ top += this . _cache . rowHeight ( { index : j } ) ;
81
+ }
82
+
83
+ headers . push ( {
84
+ node,
85
+ index : i ,
86
+ top,
87
+ } ) ;
88
+ }
89
+ }
90
+
91
+ return headers ;
92
+ } ;
93
+
94
+ // Handle scroll events to update sticky headers
95
+ handleScroll = ( { scrollTop, scrollHeight, clientHeight} ) => {
96
+ if ( ! this . _listRef . current ) return ;
97
+
98
+ // Get all headers in the tree
99
+ const allHeaders = this . getAllHeaders ( ) ;
100
+
101
+ // Find headers that should be visible based on scroll position
102
+ const visibleHeaders = allHeaders . filter ( header => {
103
+ // Calculate the bottom position of this header row
104
+ const headerHeight = this . _cache . rowHeight ( { index : header . index } ) ;
105
+ const headerBottom = header . top + headerHeight ;
106
+
107
+ // Header is visible if:
108
+ // 1. Its top is between scrollTop and scrollTop + clientHeight, OR
109
+ // 2. Its bottom is between scrollTop and scrollTop + clientHeight, OR
110
+ // 3. It starts before scrollTop and ends after scrollTop + clientHeight
111
+ return (
112
+ ( header . top >= scrollTop && header . top <= scrollTop + clientHeight ) ||
113
+ ( headerBottom >= scrollTop && headerBottom <= scrollTop + clientHeight ) ||
114
+ ( header . top <= scrollTop && headerBottom >= scrollTop + clientHeight )
115
+ ) ;
116
+ } ) ;
117
+
118
+ // Find the header that should be sticky
119
+ // It's the last header whose top position is less than or equal to scrollTop
120
+ const headersBeforeViewport = allHeaders . filter ( h => h . top <= scrollTop ) ;
121
+ const topStickyHeader =
122
+ headersBeforeViewport . length > 0 ? headersBeforeViewport [ headersBeforeViewport . length - 1 ] : null ;
123
+
124
+ // Only update state if something has changed
125
+ const currentStickyId = this . state . topStickyHeader && this . state . topStickyHeader . node && this . state . topStickyHeader . node . id ;
126
+ const newStickyId = topStickyHeader && topStickyHeader . node && topStickyHeader . node . id ;
127
+
128
+ if ( currentStickyId !== newStickyId || this . state . stickyHeaders . length !== visibleHeaders . length ) {
129
+ this . setState ( {
130
+ stickyHeaders : visibleHeaders ,
131
+ topStickyHeader,
132
+ } ) ;
133
+ }
134
+ } ;
135
+
38
136
rowRenderer = ( { node, key, measure, style, NodeRenderer, index} ) => {
39
137
const { nodeMarginLeft} = this . props ;
138
+ const isHeader = this . isGroupHeader ( node ) ;
139
+
140
+ // Add a class to identify group headers
141
+ const className = isHeader ? 'tree-group-header' : '' ;
40
142
41
143
return (
42
144
< NodeRenderer
@@ -47,14 +149,49 @@ export default class Tree extends React.Component {
47
149
userSelect : 'none' ,
48
150
cursor : 'pointer' ,
49
151
} }
152
+ className = { className }
50
153
node = { node }
51
154
onChange = { this . props . onChange }
52
155
measure = { measure }
53
156
index = { index }
157
+ isGroupHeader = { isHeader }
54
158
/>
55
159
) ;
56
160
} ;
57
161
162
+ // Render the sticky header
163
+ renderStickyHeader = ( ) => {
164
+ const { topStickyHeader} = this . state ;
165
+ if ( ! topStickyHeader ) return null ;
166
+
167
+ const { NodeRenderer, nodeMarginLeft} = this . props ;
168
+ // Always use the current node from the tree to ensure we have the latest state
169
+ const index = topStickyHeader . index ;
170
+ const currentNode = this . getNode ( index ) ;
171
+
172
+ return (
173
+ < div className = "tree-sticky-header" >
174
+ < NodeRenderer
175
+ key = { `sticky-header-${ currentNode . id } ` }
176
+ style = { {
177
+ marginLeft : currentNode . deepness * nodeMarginLeft ,
178
+ userSelect : 'none' ,
179
+ cursor : 'pointer' ,
180
+ width : '100%' ,
181
+ background : '#fff' , // Background to ensure visibility
182
+ zIndex : 10 ,
183
+ } }
184
+ className = "tree-group-header tree-sticky"
185
+ node = { currentNode }
186
+ onChange = { this . props . onChange }
187
+ index = { index }
188
+ isGroupHeader = { true }
189
+ isSticky = { true }
190
+ />
191
+ </ div >
192
+ ) ;
193
+ } ;
194
+
58
195
measureRowRenderer = nodes => ( { key, index, style, parent} ) => {
59
196
const { NodeRenderer} = this . props ;
60
197
const node = this . getNode ( index ) ;
@@ -66,25 +203,64 @@ export default class Tree extends React.Component {
66
203
) ;
67
204
} ;
68
205
206
+ componentDidUpdate ( prevProps ) {
207
+ // If nodes change, reset the cache
208
+ if ( prevProps . nodes !== this . props . nodes ) {
209
+ this . _cache . clearAll ( ) ;
210
+ if ( this . _listRef . current ) {
211
+ this . _listRef . current . recomputeRowHeights ( ) ;
212
+ }
213
+
214
+ // Force rerender of sticky header when nodes change
215
+ this . forceUpdate ( ) ;
216
+ }
217
+ }
218
+
69
219
render ( ) {
70
220
const { nodes, width, scrollToIndex, scrollToAlignment} = this . props ;
221
+ const { topStickyHeader} = this . state ;
222
+
223
+ // Calculate the height of the sticky header to properly offset the list
224
+ const stickyHeaderHeight = topStickyHeader ? this . _cache . rowHeight ( { index : topStickyHeader . index } ) : 0 ;
71
225
72
226
return (
73
- < AutoSizer disableWidth = { Boolean ( width ) } >
74
- { ( { height, width : autoWidth } ) => (
75
- < List
76
- deferredMeasurementCache = { this . _cache }
77
- ref = { r => ( this . _list = r ) }
78
- height = { height }
79
- rowCount = { this . getRowCount ( ) }
80
- rowHeight = { this . _cache . rowHeight }
81
- rowRenderer = { this . measureRowRenderer ( nodes ) }
82
- width = { width || autoWidth }
83
- scrollToIndex = { scrollToIndex }
84
- scrollToAlignment = { scrollToAlignment }
85
- />
227
+ < div className = "tree-container" style = { { position : 'relative' , height : '100%' } } >
228
+ { /* Sticky header container */ }
229
+ { topStickyHeader && (
230
+ < div
231
+ className = "tree-sticky-header-container"
232
+ style = { {
233
+ position : 'absolute' ,
234
+ top : 0 ,
235
+ left : 0 ,
236
+ right : 0 ,
237
+ zIndex : 100 ,
238
+ height : `${ stickyHeaderHeight } px` ,
239
+ } }
240
+ >
241
+ { this . renderStickyHeader ( ) }
242
+ </ div >
86
243
) }
87
- </ AutoSizer >
244
+
245
+ < AutoSizer disableWidth = { Boolean ( width ) } >
246
+ { ( { height, width : autoWidth } ) => (
247
+ < List
248
+ deferredMeasurementCache = { this . _cache }
249
+ ref = { this . _listRef }
250
+ height = { height }
251
+ rowCount = { this . getRowCount ( ) }
252
+ rowHeight = { this . _cache . rowHeight }
253
+ rowRenderer = { this . measureRowRenderer ( nodes ) }
254
+ width = { width || autoWidth }
255
+ scrollToIndex = { scrollToIndex }
256
+ scrollToAlignment = { scrollToAlignment }
257
+ onScroll = { this . handleScroll }
258
+ // Important: adds overscan to ensure we load enough rows to find headers
259
+ overscanRowCount = { 20 }
260
+ />
261
+ ) }
262
+ </ AutoSizer >
263
+ </ div >
88
264
) ;
89
265
}
90
266
}
0 commit comments