一步生成树形结构(自定义工具类实现) 前言
在日常工作中,我们经常会遇到需要生成树形结构的需求,例如:部门树、菜单树等,我们以往的实现方式是写一个递归算法来实现,但是如果这样的需求多了,我们难不成要给每个需求都写一个递归算法来实现吗?显然这是不合理的,我们这样操作会造成很多的冗余代码。那么我们有没有更好的实现思路呢?在这里我分享一种思路,也欢迎大家来一起讨论
实现 思路剖析 我们理想状态是写一个通用的工具类,那么问题来了,到底要咋写?
所以大家别着急,搬个小板凳坐好了,听我给你娓娓道来。接下来我先解答一下大家可能会问的问题。
Q: 你这不对,方法里传入的对象都不一样,你怎么进行对比?
A: 关于这个问题,咱们能用泛型来解决,这样就解决了传入对象不一致导致不能通用的问题。
Q: 那你如果用泛型的话,不就拿不到参数值了吗,你怎么对比?
A: 你别说,这其实也是这个思路的核心,拿小本本记好了:我们可以在调用的时候把需要比对的名字传进去,再通过反射来拿到对应参数的值,拿到值之后我们再进行业务处理即可。
说干就干,兄弟们走,我们去实践一波。
实现 反射工具类:ReflectionUtils 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 import lombok.extern.slf4j.Slf4j;import java.lang.reflect.*;@Slf4j public class ReflectionUtils { public ReflectionUtils () { } private static String getExceptionMessage (String fieldName,Object object) { return "Could not find field [" + fieldName + "] on target [" + object + "]" ; } public static Object getFieldValue (Object object, String fieldName) { Field field = getDeclaredField(object, fieldName); if (field == null ) { throw new IllegalArgumentException (getExceptionMessage(fieldName,object)); } makeAccessible(field); Object result = null ; try { result = field.get(object); } catch (IllegalAccessException e) { log.error("getFieldValue:" , e); } return result; } public static void setFieldValue (Object object, String fieldName, Object value) { Field field = getDeclaredField(object, fieldName); if (field == null ) { throw new IllegalArgumentException (getExceptionMessage(fieldName,object)); } makeAccessible(field); try { field.set(object, value); } catch (IllegalAccessException e) { log.error("setFieldValue:" , e); } } public static Class getSuperClassGenericType (Class clazz, int index) { Type genType = clazz.getGenericSuperclass(); if (!(genType instanceof ParameterizedType)) { return Object.class; } Type[] params = ((ParameterizedType) genType).getActualTypeArguments(); if (index >= params.length || index < 0 ) { return Object.class; } if (!(params[index] instanceof Class)) { return Object.class; } return (Class) params[index]; } @SuppressWarnings("unchecked") public static <T> Class<T> getSuperGenericType (Class clazz) { return getSuperClassGenericType(clazz, 0 ); } public static Method getDeclaredMethod (Object object, String methodName, Class<?>[] parameterTypes) { for (Class<?> superClass = object.getClass(); superClass != Object.class; superClass = superClass.getSuperclass()) { try { return superClass.getDeclaredMethod(methodName, parameterTypes); } catch (NoSuchMethodException e) { } } return null ; } public static void makeAccessible (Field field) { if (!Modifier.isPublic(field.getModifiers())) { field.setAccessible(true ); } } public static Field getDeclaredField (Object object, String filedName) { for (Class<?> superClass = object.getClass(); superClass != Object.class; superClass = superClass.getSuperclass()) { try { return superClass.getDeclaredField(filedName); } catch (NoSuchFieldException e) { } } return null ; } public static Object invokeMethod (Object object, String methodName, Class<?>[] parameterTypes, Object[] parameters) { try { Method method = getDeclaredMethod(object, methodName, parameterTypes); if (method == null ) { throw new IllegalArgumentException ("Could not find method [" + methodName + "] on target [" + object + "]" ); } method.setAccessible(true ); return method.invoke(object, parameters); } catch (Exception e) { log.error("invokeMethod:" , e); } return null ; }
构建树工具类:TreeUtils 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 import cn.hutool.core.bean.BeanUtil;import java.util.List;import java.util.stream.Collectors;public class TreeUtils { private TreeUtils () { } public static <T, R, M> List<R> buildTree (Class<R> clazz, List<T> list, String idName, String parentIdName, String childName, M parentFlag) { List<R> resList = BeanUtil.copyToList(list, clazz); List<R> root = resList.stream() .filter(item -> "0" .equals(String.valueOf(ReflectionUtils.getFieldValue(item, "parentId" )))) .collect(Collectors.toList()); resList.removeAll(root); root.forEach(item -> getChildren(item, resList, idName)); return root; } public static <R> void getChildren (R r, List<R> list, String idName) { if (hasChildren(r, list, idName)) { List<R> collect = list.stream().filter(item -> String.valueOf(ReflectionUtils.getFieldValue(item, "parentId" )) .equals(String.valueOf(ReflectionUtils.getFieldValue(r, idName)))) .collect(Collectors.toList()); if (collect != null && collect.size() > 0 ) { ReflectionUtils.setFieldValue(r, "children" , collect); list.removeAll(collect); collect.forEach(item1 -> getChildren(item1, list, idName)); } } } public static <T> boolean hasChildren (T t, List<T> list, String idName) { return list.stream().anyMatch(item -> { String a = String.valueOf(ReflectionUtils.getFieldValue(item, "parentId" )); String b = String.valueOf(ReflectionUtils.getFieldValue(t, idName)); return a.equals(b); }); } }
调用 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 public class Test { public static void main (String[] args) { List<Dept> list = new ArrayList <>(); Dept dept1 = new Dept (); dept1.setDeptId(1 ); dept1.setParentId(0 ); list.add(dept1); Dept dept2 = new Dept (); dept2.setDeptId(2 ); dept2.setParentId(1 ); list.add(dept2); Dept dept3 = new Dept (); dept3.setDeptId(3 ); dept3.setParentId(1 ); list.add(dept3); Dept dept4 = new Dept (); dept4.setDeptId(4 ); dept4.setParentId(2 ); list.add(dept4); Dept dept5 = new Dept (); dept5.setDeptId(5 ); dept5.setParentId(4 ); list.add(dept5); List<DeptVo> tree = TreeUtils.buildTree(DeptVo.class, list, "deptId" , "parentId" , "children" , 0 ); System.out.println(tree); } }
改进一 思路剖析 虽然我们已经实现了我们需要的功能,但是以字符串的形式来传递名字似乎还是不够优雅,那我们应该做什么样的改进呢?
在我们使用Mybatis Plus的时候,里面提供了一种供lambda表达式传递参数的条件构造器:LambdaQueryWrapper,我们在使用这个lambda条件构造器的时候,需要比对哪个参数,就使用 类名::方法名
的形式即可完成参数传递或方法调用,这种参数传递的方式是不是优雅了很多呢?
其实现方式为使用了SFunction来接收参数,SFunction内部类如下:
::注意:这种写法与常规写法稍有不同,在Java里可以同时继承其他接口和没有方法的接口,这时就可以用英文逗号分隔。::
关于这个类我们就不过多阐述了,具体的可以看我的另一篇文章
封装自定义函数式接口
那我们就仿照它来自定义一个接口
MyFunction
1 2 3 4 5 6 7 8 9 10 11 import java.io.Serializable;import java.util.function.Function;@FunctionalInterface public interface MyFunction <T, R> extends Function <T, R>, Serializable {}
注意:Serializable一定要实现,它的作用是使该对象可序列化和反序列化,如果不实现后面无法对其反序列化。
定义好了之后,我们又如何提供这个参数来获取到传递的属性名呢?我们可以先通过反序列化拿到 SerializedLambda对象
,然后通过这个对象来获取到其内部的 methodName
,例如我们传入的对象是 Dept::getDeptId
,我们获取到的methodName就是 getDeptId
,然后我们期望获取到的是 deptId
,我们可以去截取到第三位,然后将截取后的字符串首字母转小写即可。
ConvertUtils
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 import com.baomidou.mybatisplus.core.toolkit.StringUtils;import com.bummon.lambda.MyFunction;import java.lang.invoke.SerializedLambda;import java.lang.reflect.InvocationTargetException;import java.lang.reflect.Method;import java.util.Map;import java.util.concurrent.ConcurrentHashMap;public class ConvertUtils { public ConvertUtils () { } public static final String GET = "get" ; public static final String IS = "is" ; private static final Map<Class<?>, SerializedLambda> CLASS_LAMBDA_CACHE = new ConcurrentHashMap <>(); private static SerializedLambda getSerializedLambda (MyFunction<?, ?> fn) { SerializedLambda lambda = CLASS_LAMBDA_CACHE.get(fn.getClass()); if (lambda == null ) { try { Method method = fn.getClass().getDeclaredMethod("writeReplace" ); method.setAccessible(Boolean.TRUE); lambda = (SerializedLambda) method.invoke(fn); CLASS_LAMBDA_CACHE.put(fn.getClass(), lambda); } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { throw new RuntimeException (e); } } return lambda; } public static <T, R> String getLambdaFieldName (MyFunction<T, R> fn) { SerializedLambda lambda = getSerializedLambda(fn); String methodName = lambda.getImplMethodName(); if (methodName.startsWith(GET)) { methodName = methodName.substring(3 ); } else if (methodName.startsWith(IS)) { methodName = methodName.substring(2 ); } else { throw new IllegalArgumentException ("无效的getter方法:" + methodName); } return StringUtils.firstToLowerCase(methodName); } }
最后我们调用 getLambdaFieldName
即可获取到我们通过lambda表达式传入的属性名,接下来我们去修改我们工具类中的入参,以及工具类内部方法传递的参数
TreeUtils
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 import cn.hutool.core.bean.BeanUtil;import cn.hutool.core.collection.CollUtil;import com.bummon.util.ConvertUtils;import com.bummon.lambda.MyFunction;import java.util.List;public class TreeUtils { public TreeUtils () { } public static <T, R, M, N> List<N> buildTree (Class<N> clazz, List<T> list, MyFunction<T, R> idNameFunc, MyFunction<T, R> parentIdNameFunc, MyFunction<T, R> childNameFunc, M parentFlag) { String idName = ConvertUtils.getLambdaFieldName(idNameFunc); String parentIdName = ConvertUtils.getLambdaFieldName(parentIdNameFunc); String childName = ConvertUtils.getLambdaFieldName(childNameFunc); List<N> treeList = BeanUtil.copyToList(list, clazz); List<N> root = treeList.stream() .filter(item -> String.valueOf(parentFlag).equals(String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)))) .toList(); treeList.removeAll(root); root.forEach(item -> getChildren(item, treeList, idName, parentIdName, childName)); return root; } public static <T, R, M> List<M> buildTree (Class<M> clazz, MyFunction<T, R> idNameFunc, List<T> list) { String idName = ConvertUtils.getLambdaFieldName(idNameFunc); List<M> treeList = BeanUtil.copyToList(list, clazz); List<M> root = treeList.stream() .filter(item -> "0" .equals(String.valueOf(ReflectionUtils.getFieldValue(item, "parentId" )))) .toList(); treeList.removeAll(root); root.forEach(item -> getChildren(item, treeList, idName, "parentId" , "children" )); return root; } public static <T> void getChildren (T t, List<T> list, String idName, String parentIdName, String childName) { if (hasChildren(t, list, idName, parentIdName)) { List<T> collect = list.stream().filter(item -> String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)) .equals(String.valueOf(ReflectionUtils.getFieldValue(t, idName)))) .toList(); if (CollUtil.isNotEmpty(collect)) { ReflectionUtils.setFieldValue(t, childName, collect); list.removeAll(collect); collect.forEach(item1 -> getChildren(item1, list, idName, parentIdName, childName)); } } } public static <T> boolean hasChildren (T t, List<T> list, String idName, String parentIdName) { return list.stream().anyMatch(item -> { String a = String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)); String b = String.valueOf(ReflectionUtils.getFieldValue(t, idName)); return a.equals(b); }); }
最后我们在调用的时候,只需要如下方式即可调用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 public class Test { public static void main (String[] args) { List<Dept> list = new ArrayList <>(); Dept dept1 = new Dept (); dept1.setDeptId(1 ); dept1.setParentId(0 ); list.add(dept1); Dept dept2 = new Dept (); dept2.setDeptId(2 ); dept2.setParentId(1 ); list.add(dept2); Dept dept3 = new Dept (); dept3.setDeptId(3 ); dept3.setParentId(1 ); list.add(dept3); Dept dept4 = new Dept (); dept4.setDeptId(4 ); dept4.setParentId(2 ); list.add(dept4); Dept dept5 = new Dept (); dept5.setDeptId(5 ); dept5.setParentId(4 ); list.add(dept5); List<DeptVo> tree = TreeUtils.buildTree(DeptVo.class, Dept::getDeptId, list); System.out.println(tree); } }
改进二 嘿,你还真别说,上面的调用方式确实优雅了不少,那么还有没有可以更简洁的方式来调用呢?可能大家会问了:这都这么简洁了,哪里还有什么继续简化的地方。
思路剖析 别急嘛,我来给你解答一番,我们在返回树形结构之前,肯定是先查询出来一个列表,然后通过对这个列表的一系列操作,最终形成了一个树形结构。但是如上文中的 children
在实体类中是不存在的,我们还需要再创建一个类来存放这个字段,大家思考一下,如果我们查询出来后可以直接调用,不需要再创建新的类,这样是不是很方便?那具体要如何实现呢?
没错!我们利用Map就可以实现这个需求,我们可以按照之前的思路,先将父节点集合查出来,然后对父节点集合遍历的时候,我们把父节点转换为一个Map对象,最后再把递归出来的子节点放入到这个Map对象中,并将key设置为children。理论存在,实践开始!
改造TreeUtils
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 import cn.hutool.core.bean.BeanUtil;import cn.hutool.core.collection.CollUtil;import com.bummon.lambda.MyFunction;import com.bummon.util.ConvertUtils;import com.bummon.util.tree.ReflectionUtils;import java.util.Collections;import java.util.HashMap;import java.util.List;import java.util.Map;import java.util.concurrent.atomic.AtomicInteger;public class TestTreeUtils { private static final String CHILD_NAME = "children" ; public static <T, R, M> Map<String, Object> buildTree (List<T> list, MyFunction<T, R> idNameFunc, MyFunction<T, R> parentIdNameFunc, M parentFlag) { String idName = ConvertUtils.getLambdaFieldName(idNameFunc); String parentIdName = ConvertUtils.getLambdaFieldName(parentIdNameFunc); List<T> root = list.stream() .filter(item -> String.valueOf(parentFlag).equals(String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)))) .toList(); list.removeAll(root); Map<String, Object> map = new HashMap <>(); AtomicInteger index = new AtomicInteger (0 ); root.forEach(item -> { Map<String, Object> itemMap = BeanUtil.beanToMap(item); Map<String, Object> children = getChildren(itemMap, list, idName, parentIdName); itemMap.put(CHILD_NAME, children); map.put(String.valueOf(index.get()), itemMap); index.getAndIncrement(); }); return map; } public static <T, R> Map<String, Object> buildTree (List<T> list, MyFunction<T, R> idNameFunc) { String idName = ConvertUtils.getLambdaFieldName(idNameFunc); String parentIdName = "parentId" ; String parentFlag = "0" ; List<T> root = list.stream() .filter(item -> parentFlag.equals(String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)))) .toList(); list.removeAll(root); Map<String, Object> map = new HashMap <>(); AtomicInteger index = new AtomicInteger (0 ); root.forEach(item -> { Map<String, Object> itemMap = BeanUtil.beanToMap(item); Map<String, Object> children = getChildren(itemMap, list, idName, parentIdName); itemMap.put(CHILD_NAME, children); map.put(String.valueOf(index.get()), itemMap); index.getAndIncrement(); }); return map; } public static <T> Map<String, Object> getChildren (Map<String, Object> itemMap, List<T> list, String idName, String parentIdName) { if (hasChildren(itemMap, list, idName, parentIdName)) { List<T> collect = list.stream().filter(item -> String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)) .equals(String.valueOf(itemMap.get(idName)))) .toList(); Map<String, Object> map = new HashMap <>(); if (CollUtil.isNotEmpty(collect)) { itemMap.put(CHILD_NAME, collect); list.removeAll(collect); AtomicInteger index = new AtomicInteger (0 ); collect.forEach(item -> { Map<String, Object> childItemMap = BeanUtil.beanToMap(item); Map<String, Object> children = getChildren(childItemMap, list, idName, parentIdName); childItemMap.put(CHILD_NAME, children); map.put(String.valueOf(index.get()), childItemMap); index.getAndIncrement(); }); } return map; } return Collections.emptyMap(); } public static <T> boolean hasChildren (Map<String, Object> itemMap, List<T> list, String idName, String parentIdName) { return list.stream().anyMatch(item -> { String a = String.valueOf(ReflectionUtils.getFieldValue(item, parentIdName)); String b = String.valueOf(itemMap.get(idName)); return a.equals(b); }); } }
调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 public class Test { public static void main (String[] args) { List<Dept> list = new ArrayList <>(); Dept dept1 = new Dept (); dept1.setDeptId(1 ); dept1.setParentId(0 ); list.add(dept1); Dept dept2 = new Dept (); dept2.setDeptId(2 ); dept2.setParentId(1 ); list.add(dept2); Dept dept3 = new Dept (); dept3.setDeptId(3 ); dept3.setParentId(1 ); list.add(dept3); Dept dept4 = new Dept (); dept4.setDeptId(4 ); dept4.setParentId(2 ); list.add(dept4); Dept dept5 = new Dept (); dept5.setDeptId(5 ); dept5.setParentId(4 ); list.add(dept5); List<DeptVo> tree = TreeUtils.buildTree(list, Dept::getDeptId); System.out.println(tree); } }
至此,我们的工具类就大功告成了,该方法也可以用于二次封装Hutool中的构建树形结构的工具类。作者学艺不精,在这里只是分享一种思路,大家也可以根据自己的需求来更改相关代码!