用while循环+堆栈创建
我有点尴尬承认这一点,但我似乎很难被一个简单的编程问题所困扰。 我正在构建一个决策树实现,并一直使用递归来获取标记样本的列表,递归地将列表分成两半,然后将它变成一棵树。
不幸的是,对于深度树,我遇到了堆栈溢出错误(ha!),所以我的第一个想法是使用continuations将其变为尾递归。 不幸的是,Scala不支持这种TCO,所以唯一的解决方案是使用蹦床。 蹦床看起来效率不高,我希望有一些简单的基于堆栈的命令式解决方案来解决这个问题,但是我很难找到它。
递归版本看起来有点像(简化):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
所以基本上,我根据数据的某些特性递归地将列表细分为两部分,并传递一个已使用特性的列表,所以我不重复 - 这些都是在“getSplittingFeature”函数中处理的,因此我们可以忽略它。 代码非常简单! 不过,我很难找出一个基于堆栈的解决方案,它不仅仅使用闭包并且有效地成为蹦床。 我知道我们至少必须在堆栈中保留很少的参数“框架”,但我想避免关闭调用。
我明白,我应该明确地写出了callstack和程序计数器在递归解决方案中隐含地处理了什么,但是我无法在没有延续的情况下这样做。 在这一点上,它甚至不是效率,我只是好奇。 所以,请不要提醒我,过早优化是万恶之源,而基于蹦床的解决方案可能会工作得很好。 我知道它可能会 - 它本身就是一个难题。
任何人都可以告诉我这种基于规范的基于循环和堆栈的解决方案是什么?
更新:基于Thipor Kong出色的解决方案,我编写了一个基于while循环/ stacks / hashtable的算法实现,该算法应该是递归版本的直接转换。 这正是我所期待的:
最后更新:我已经使用了顺序整数索引,以及将所有内容都放回到数组中,而不是用于性能的映射,增加了maxDepth支持,并最终获得了与递归版本相同性能的解决方案(不确定内存使用情况,但是我会猜测更少):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
// Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}
正如Wikipedia上所描述的那样,将二叉树存储在数组中:对于节点i
,左边的孩子进入2*i+1
,右边的孩子进入2*i+2
。 当“下”时,你会收集一些待办事项,但仍然需要分解才能到达一片叶子。 一旦你只有叶子,向上(在数组中从右到左)构建决策节点:
更新:清理后的版本,也支持分支中存储的功能(类型参数B),功能更强大/完全纯净,并支持ron建议的具有地图的稀疏树。
Update2-3:经济地使用节点ID的名字空间和抽象的ID类型来允许大树。 从Stream中获取节点ID。
sealed trait DTree[A, B]
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B]
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B]
def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = {
@tailrec
def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) =
todo match {
case Nil => (branches, leafs)
case (a, b, id) :: rest =>
split(a, b) match {
case None =>
goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids)
case Some((left, right, b2)) =>
val leftId #:: rightId #:: idRest = ids
goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest)
}
}
@tailrec
def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] =
branches match {
case Nil => nodes
case (id, b, leftId, rightId) :: rest =>
goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b)))
}
val rootId #:: restIds = ids
val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds)
goUp(branches, leafs)(rootId)
}
// try it out
def split(xs: Seq[Int], b: Int) =
if (xs.size > 1) {
val (left, right) = xs.splitAt(xs.size / 2)
Some((left, right, b + 1))
} else {
None
}
val tree = mktree(0 to 1000, 0, split _, Stream.from(0))
println(tree)
链接地址: http://www.djcxy.com/p/10761.html