golib  0.5
sumproduct/sp.cpp

The sum-product algorithm.

See also
goMaxSum
Template parameters
T: Type of the node values; currently must be an integer type
Tfloat: Floating-point type, typically goFloat or goDouble
References:
  Bishop, C.M. 
  Pattern Recognition and Machine Learning 
  Springer, 2006 
Author
Christian Gosch
/* Copyright (C) 1998-2011 Christian Gosch, golib at goschs dot de
This file is part of the golib library.
For license regulations, see the file COPYING in the main
directory of the golib source tree. */
#include <gosumproduct.h>
#include <gomaxsum.h>
#include <goautoptr.h>
#include <gofunctor.h>
#include <goplot.h>
/**
 * @brief Test factor node whose factor function is MyFactor::my_f.
 *
 * The constructor wraps my_f in a goFunctor1 bound to this object and
 * installs it with setFunctor().
 */
class MyFactor : public goFGNodeFactor<goSize_t,float>
{
public:
    /**
     * @param valence Passed on to the goFGNodeFactor base constructor;
     *                presumably the number of adjacency slots (main()
     *                connects slots 0 and 1 of each MyFactor(2)) — TODO confirm.
     */
    explicit MyFactor (goSize_t valence)
        : goFGNodeFactor<goSize_t,float> (valence), ff (0)
    {
        this->ff.set (new goFunctor1<float, MyFactor, const goVector<goSize_t>&> (this, &MyFactor::my_f));
        this->setFunctor (&*this->ff);
    }

    virtual ~MyFactor () {}

    /**
     * @brief Indicator factor on two adjacent variables.
     * @param X Values of the adjacent variables; X[0] and X[1] are read.
     * @return 1.0f if X[1] == X[0] + 2, otherwise 0.0f.
     */
    float my_f (const goVector<goSize_t>& X)
    {
        // goSize_t is unsigned: compare against X[0] + 2 instead of writing
        // X[1] - X[0] == 2, which reads as if the difference could be negative.
        // Both forms are equivalent in unsigned (modular) arithmetic.
        if (X[1] == X[0] + 2)
            return 1.0f;
        return 0.0f;
    }

private:
    //= Owns the functor; setFunctor() only receives a raw pointer to it,
    //= so it is kept alive as a member for the lifetime of this node.
    goAutoPtr<goFunctor1<float, MyFactor, const goVector<goSize_t>&> > ff;
};
int main ()
{
factors.setSize (1);
vars.setSize (3);
vars[0].set (new goFGNodeVariable<goSize_t,goFloat> (2));
vars[1].set (new goFGNodeVariable<goSize_t,goFloat> (2));
vars[2].set (new goFGNodeVariable<goSize_t,goFloat> (1));
factors[0].set (new goFGNodeFactor<goSize_t,goFloat> (1));
vars[0]->value = 0;
vars[1]->value = 1;
vars[2]->value = 2;
factors[0]->value = 3;
fg.connect (vars[0], 0, vars[1], 0);
fg.connect (vars[1], 1, vars[2], 0);
fg.connect (factors[0], 0, vars[0], 1);
// printf ("First test graph:\n");
sp.setValueCount (10);
// sp.run (fg.myVariables[0], fg);
//sp.flooding (fg.myVariables[0], fg);
printf ("\nSecond test graph:\n");
//= Build a second graph to test:
{
vars.setSize (11);
factors.setSize (0);
for (goSize_t i = 0; i < 11; ++i)
{
vars[i].set (new goFGNodeVariable<goSize_t,goFloat> (0));
vars[i]->value = i + 1;
}
//= TODO: Anzahl edges setzen und testen.
vars[0]->adj.setSize (1);
vars[1]->adj.setSize (3);
vars[2]->adj.setSize (4);
vars[3]->adj.setSize (1);
vars[4]->adj.setSize (2);
vars[5]->adj.setSize (1);
vars[6]->adj.setSize (2);
vars[7]->adj.setSize (1);
vars[8]->adj.setSize (2);
vars[9]->adj.setSize (2);
vars[10]->adj.setSize (1);
fg.connect (vars[0], 0, vars[1], 0);
fg.connect (vars[1], 1, vars[8], 0);
fg.connect (vars[1], 2, vars[2], 0);
fg.connect (vars[2], 1, vars[3], 0);
fg.connect (vars[2], 2, vars[4], 0);
fg.connect (vars[2], 3, vars[5], 0);
fg.connect (vars[4], 1, vars[6], 0);
fg.connect (vars[6], 1, vars[7], 0);
fg.connect (vars[8], 1, vars[9], 0);
fg.connect (vars[9], 1, vars[10], 0);
//= Insert a loop
// fg.connect (nodelist(2), nodelist(8));
//nodelist(2)->elem->adj.append (nodelist(8)->elem);
//nodelist(8)->elem->adj.append (nodelist(2)->elem);
sp.run (fg.myVariables[0], fg);
}
#if 1
{
vars.setSize (4);
vars[0].set (new goFGNodeVariable<goSize_t,float> (2));
vars[1].set (new goFGNodeVariable<goSize_t,float> (2));
vars[2].set (new goFGNodeVariable<goSize_t,float> (3));
vars[3].set (new goFGNodeVariable<goSize_t,float> (1));
// factors.setSize (4);
factors.setSize (3);
factors[0].set (new MyFactor(2));
factors[1].set (new MyFactor(2));
factors[2].set (new MyFactor(2));
fg.connect (vars[0], 0, factors[0], 0);
fg.connect (vars[0], 1, factors[1], 0);
fg.connect (vars[1], 0, factors[0], 1);
fg.connect (vars[2], 0, factors[1], 1);
fg.connect (vars[2], 1, factors[2], 0);
fg.connect (vars[3], 0, factors[2], 1);
//= Add a loop and see what happens ...
//factors[3].set (new MyFactor(2));
//fg.connect (vars[1], 1, factors[3], 0);
//fg.connect (vars[2], 2, factors[3], 1);
FILE* f = fopen ("graph2.dot","w");
if (!f)
{
printf ("Could not open graph2.dot for writing.\n");
exit(0);
}
goFGGraphWriteDOT<goSize_t,float> (fg.myVariables[0], f);
fclose (f);
// goSumProduct<goSize_t,float> sp;
// sp.flooding (fg.myVariables[0], fg);
sp.run (fg.myVariables[0], fg);
for (goSize_t i = 0; i < fg.myVariables.getSize(); ++i)
{
goVectorf marginal;
sp.marginal (fg.myVariables[i], sp.getValueCount(), marginal);
goVectorf x (marginal.getSize());
x.fillRange (0.0f, 1.0f, (float)(marginal.getSize()));
goString s = "Marginal var["; s += (int)i; s += "]";
goPlot::plot (x, marginal, s.toCharPtr());
}
ms.setValueCount (10);
printf ("Normal operation:\n");
ms.run (fg.myVariables[0], fg);
printf ("Values: ");
for (goSize_t i = 0; i < fg.myVariables.getSize(); ++i)
{
printf ("%d ", fg.myVariables[i]->value);
}
printf ("\n");
printf ("Flooding:\n");
ms.flooding (fg.myVariables[0], fg, 50);
printf ("Values: ");
for (goSize_t i = 0; i < fg.myVariables.getSize(); ++i)
{
printf ("%d ", fg.myVariables[i]->value);
}
printf ("\n");
}
#endif
//= Make one factor graph just for drawing:
{
fg.myVariables.setSize (3);
fg.myFactors.setSize (3);
fg.myVariables[0].set (new goFGNodeVariable<goSize_t,float> (3));
fg.myVariables[1].set (new goFGNodeVariable<goSize_t,float> (1));
fg.myVariables[2].set (new goFGNodeVariable<goSize_t,float> (2));
fg.myFactors[0].set (new goFGNodeFactor<goSize_t,float> (1));
fg.myFactors[1].set (new goFGNodeFactor<goSize_t,float> (3));
fg.myFactors[2].set (new goFGNodeFactor<goSize_t,float> (2));
fg.connect (fg.myVariables[0], 0, fg.myFactors[0], 0);
fg.connect (fg.myVariables[0], 1, fg.myFactors[1], 0);
fg.connect (fg.myVariables[1], 0, fg.myFactors[1], 1);
fg.connect (fg.myVariables[2], 0, fg.myFactors[1], 2);
fg.connect (fg.myVariables[2], 1, fg.myFactors[2], 0);
fg.connect (fg.myVariables[0], 2, fg.myFactors[2], 1);
FILE* f = fopen ("fg.dot","w");
if (!f)
{
printf ("Could not open fg.dot for writing.\n");
exit(0);
}
goFGGraphWriteDOT<goSize_t,float> (fg.myVariables[0], f);
fclose (f);
}
//= Make a circular factor graph just for drawing:
{
fg.myVariables.setSize (5);
fg.myFactors.setSize (10);
for (goSize_t i = 0; i < 5; ++i)
{
fg.myVariables[i].set (new goFGNodeVariable<goSize_t,float> (3));
}
for (goSize_t i = 0; i < 5; ++i)
{
fg.myFactors[i].set (new goFGNodeFactor<goSize_t,float> (2));
fg.connect (fg.myFactors[i], 0, fg.myVariables[i], 0);
fg.connect (fg.myFactors[i], 1, fg.myVariables[(i+1) % 5], 1);
}
for (goSize_t i = 5; i < 10; ++i)
{
fg.myFactors[i].set (new goFGNodeFactor<goSize_t,float> (1));
fg.connect (fg.myFactors[i], 0, fg.myVariables[i-5], 2);
}
FILE* f = fopen ("fg2.dot","w");
if (!f)
{
printf ("Could not open fg2.dot for writing.\n");
exit(0);
}
goFGGraphWriteDOT<goSize_t,float> (fg.myVariables[0], f);
fclose (f);
}
exit(1);
}
goMaxSum::flooding
bool flooding(goFGNode< T, Tfloat > *startNode, goFactorGraph< T, Tfloat > &fg, goSize_t maxPasses=0)
"Flooding" type scheme for graphs with loops.
Definition: gomaxsum.h:425
goSumProduct
Definition: gosumproduct.h:64
goVector
Definition: gomatrixsignal.h:20
goFunctor1
Member function representation for class members with 1 arguments.
Definition: gofunctor.h:767
goFixedArray::getSize
goSize_t getSize() const
Get the size of the array in number of elements.
Definition: gofixedarray.h:193
goMaxSum
The max-sum algorithm.
Definition: gomaxsum.h:69
goFGNodeVariable
Variable class for goFactorGraph.
Definition: gofactorgraph.h:142
goSize_t
size_t goSize_t
Definition: gotypes.h:96
goMath::Vector< goFloat >
goSumProduct::run
virtual bool run(goFGNode< T, Tfloat > *root, goFactorGraph< T, Tfloat > &fg)
Run the sum-product algorithm.
Definition: gosumproduct.h:84
goAutoPtr
"Smart pointer". Wrapper that automatically deletes its managed pointer when the internal reference count reaches zero.
Definition: goautoptr.h:127
goFactorGraph::connect
void connect(goFGNode< T, Tfloat > *n1, goSize_t adjIndex1, goFGNode< T, Tfloat > *n2, goSize_t adjIndex2)
Connect n1, "slot" adjIndex1 to n2, "slot" adjIndex2.
Definition: gofactorgraph.h:268
goFixedArray::setSize
void setSize(goSize_t newSize, goSize_t reserve=0, goSize_t resize_overhead=0)
Set the size of the array, deleting old content.
Definition: gofixedarray.h:254
goFixedArray
Array class.
Definition: gofixedarray.h:40
goMaxSum::run
virtual bool run(goFGNode< T, Tfloat > *root, goFactorGraph< T, Tfloat > &fg)
Run the max-sum algorithm.
Definition: gomaxsum.h:84
goPlot::plot
goAutoPtr< goPlot::Graph > plot(const goMatrixf &curve, goAutoPtr< goPlot::Graph > g=0)
Plot a curve given as configuration matrix.
Definition: src/plot/cairoplot.cpp:80
goFactorGraph
Factor graph for use with goMaxSum and goSumProduct.
Definition: gofactorgraph.h:233
goString
String class.
Definition: gostring.h:28
MyFactor
Definition: sp.cpp:13
goFGNodeFactor
Factor node for goFactorGraph.
Definition: gofactorgraph.h:75