lcsim/src/org/lcsim/contrib/JanStrube/vtxFitter
diff -u -r1.1 -r1.2
--- Fitter.java 26 Jun 2006 02:54:31 -0000 1.1
+++ Fitter.java 27 Jun 2006 01:52:30 -0000 1.2
@@ -1,6 +1,7 @@
package vtxFitter;
import org.lcsim.event.ReconstructedParticle;
+import org.lcsim.event.Track;
import Jama.Matrix;
@@ -20,8 +21,6 @@
Matrix D_k, D_prev, D;
//covariance for x and q
Matrix E_k, E_prev, E;
- // track measrement
- double[] m_k, m_prev, m;
// covariance for track measurement
Matrix V_k, V_prev, V;
// inverse of covariance matrix
@@ -38,6 +37,11 @@
// residual
double[] c_k0;
+ // chi2 of the fit; filter() adds the current track's contribution to chi2_prev
+ double chi2, chi2_prev;
+ // residual of the current track, assigned in filter()
+ double[] r_k;
+
/**
* Definitions:
* x_k = estimate of the vertex position after using the information of k tracks
@@ -65,37 +69,63 @@
C = new Matrix(C0);
}
- private void filter() {
- x = multiply(C_k
- , plus(multiply(C_prev.inverse(), x_prev)
- , multiply(A_k.transpose().times(G_B)
- , minus(m_k, c_k0))));
+
+ public void fit() { // runs filter() over every track of the particle, then smoothe() over every track
+ // pass 1: filter each track in turn
+ for (Track track : particle.getTracks())
+ filter(track);
+ // pass 2: smoothe each track, only after all tracks have been filtered
+ for (Track track : particle.getTracks())
+ smoothe(track);
+ }
+ private void filter(Track t) { // filter step for one track: updates C_k, x, q, D, E, W_k, G_B, r_k and chi2
+ double[] m_k = t.getTrackParameters();
+ C_k = C_prev.inverse().plus(A_k.transpose().times(G_B.times(A_k))).inverse();
+
+ x = C_k.times(add(C_prev.inverse().times(x_prev)
+ , A_k.transpose().times(G_B).times(subtract(m_k, c_k0))));
- q = multiply(W_k.times(B_k.transpose().times(G_k)), minus(m_k, plus(c_k0, multiply(A_k, x_k))));
- C = C_prev.inverse().plus(A_k.transpose().times(G_B.times(A_k))).inverse();
+ q = W_k.times(B_k.transpose().times(G_k)).times(subtract(m_k, add(c_k0, A_k.times(x_k))));
- D = W_k.plus(W_k.times(B_k.transpose().plus(G_k.times(A_k.times(C_k.times(A_k.transpose().times(G_k.times(B_k.times(W_k)))))))));
+ D = W_k.plus(W_k.times(B_k.transpose().times(G_k.times(A_k.times(C_k.times(A_k.transpose().times(G_k.times(B_k.times(W_k))))))))); // product chain, same form as D_kN in smoothe()
- E = C_k.times(A_k.transpose().times(G_k.times(B_k.times(W_k)))).times(-1);
+ E = C_k.times(A_k.transpose().times(G_k.times(B_k.times(W_k)))).uminus();
W_k = (B_k.transpose().times(G_k.times(B_k))).times(-1);
G_B = G_k.minus((G_k.times(B_k.times(W_k.times(B_k.transpose().times(G_k))))));
- }
-
- // returns y = A.x
- static double[] multiply(Matrix A, double[] x) {
- if (A.getColumnDimension() != x.length)
- throw new IllegalArgumentException("dimensions do not match");
- double[] result = new double[A.getRowDimension()];
+ r_k = subtract(m_k, add(c_k0, add(A_k.times(x_k), B_k.times(q_k)))); // residual of this track; assigned first because chi2 reads it
- for (int i=0; i<A.getRowDimension(); i++) {
- for (int j=0; j<A.getColumnDimension(); j++) {
- result[i] += A.get(i, j) * x[j];
- }
- }
- return result;
+ chi2 = chi2_prev + dot(r_k, G_k.times(r_k)) + dot(subtract(x_k, x_prev), C_prev.inverse().times(subtract(x_k, x_prev)));
}
- static double[] minus(double[] a, double[] b) {
+ private void smoothe(Track t) { // smoothing step for one track; the results stay in locals and are not stored yet
+ double[] measured = t.getTrackParameters();
+ double[] q_kN = W_k.times(B_k.transpose().times(G_k)).times(subtract(measured, add(c_k0, A_k.times(x_k))));
+ Matrix C_kN = C_k;
+ Matrix shared = A_k.transpose().times(G_k.times(B_k.times(W_k))); // factor common to D_kN and E_kN
+ Matrix D_kN = W_k.plus(W_k.times(B_k.transpose().times(G_k.times(A_k.times(C_kN.times(shared))))));
+ Matrix E_kN = C_kN.times(shared).uminus();
+ }
+
+// // returns y = A.x
+// static double[] multiply(Matrix A, double[] x) {
+// if (A.getColumnDimension() != x.length)
+// throw new IllegalArgumentException("dimensions do not match");
+// double[] result = new double[A.getRowDimension()];
+//
+// for (int i=0; i<A.getRowDimension(); i++) {
+// for (int j=0; j<A.getColumnDimension(); j++) {
+// result[i] += A.get(i, j) * x[j];
+// }
+// }
+// return result;
+// }
+//
+// // for notational convenience
+// static Matrix multiply(Matrix a, Matrix b) {
+// return a.times(b);
+// }
+
+ static double[] subtract(double[] a, double[] b) {
if (a.length != b.length)
throw new IllegalArgumentException("dimensions do not match");
double[] result = new double[a.length];
@@ -105,7 +135,7 @@
return result;
}
- static double[] plus(double[] a, double[] b) {
+ static double[] add(double[] a, double[] b) {
if (a.length != b.length)
throw new IllegalArgumentException("dimensions do not match");
double[] result = new double[a.length];
@@ -114,4 +144,19 @@
}
return result;
}
+
+ // for notational convenience: Matrix counterpart of add(double[], double[]); returns a.plus(b)
+ static Matrix add(Matrix a, Matrix b) {
+ return a.plus(b);
+ }
+
+ static double dot(double[] a, double[] b) { // sum of a[i]*b[i]; throws IllegalArgumentException when the lengths differ
+ if (a.length != b.length)
+ throw new IllegalArgumentException("dimensions don't match");
+ double sum = 0;
+ for (int k = 0, n = a.length; k < n; k++)
+ sum += a[k] * b[k];
+ return sum;
+ }
+
}
lcsim/src/org/lcsim/contrib/JanStrube/vtxFitter
diff -u -r1.1 -r1.2
--- FitterTest.java 26 Jun 2006 02:54:31 -0000 1.1
+++ FitterTest.java 27 Jun 2006 01:52:30 -0000 1.2
@@ -2,9 +2,10 @@
import Jama.Matrix;
import junit.framework.TestCase;
-import static vtxFitter.Fitter.minus;
-import static vtxFitter.Fitter.plus;
-import static vtxFitter.Fitter.multiply;
+import static vtxFitter.Fitter.subtract;
+import static vtxFitter.Fitter.add;
+//import static vtxFitter.Fitter.multiply;
+import static vtxFitter.Fitter.dot;
public class FitterTest extends TestCase {
Matrix A;
@@ -21,21 +22,21 @@
super.tearDown();
}
- /*
- * Test method for 'vtxFitter.Fitter.multiply(Matrix, double[])'
- */
- public void testMultiply() {
- double[] r = multiply(A, x);
- assertEquals(r.length, A.getRowDimension());
- assertEquals(r[0], 15.);
- assertEquals(r[1], 15.);
- }
+// /*
+// * Test method for 'vtxFitter.Fitter.multiply(Matrix, double[])'
+// */
+// public void testMultiply() {
+// double[] r = multiply(A, x);
+// assertEquals(r.length, A.getRowDimension());
+// assertEquals(r[0], 15.);
+// assertEquals(r[1], 15.);
+// }
/*
- * Test method for 'vtxFitter.Fitter.minus(double[], double[])'
+ * Test method for 'vtxFitter.Fitter.subtract(double[], double[])'
*/
public void testMinus() {
- double[] r = minus(x, y);
+ double[] r = subtract(x, y);
assertEquals(r.length, x.length);
assertEquals(r[0], 0.);
assertEquals(r[1], 0.);
@@ -46,11 +47,15 @@
- * Test method for 'vtxFitter.Fitter.plus(double[], double[])'
+ * Test method for 'vtxFitter.Fitter.add(double[], double[])'
*/
public void testPlus() {
- double[] r = plus(x, y);
+ double[] r = add(x, y); // add() is the renamed plus() helper
assertEquals(r.length, x.length);
assertEquals(r[0], 2.);
assertEquals(r[1], 4.);
assertEquals(r[2], 6.);
}
+ public void testDot() { // checks dot() on the x and y fixtures; 14 is consistent with x = y = (1, 2, 3)
+ double product = dot(x, y);
+ assertEquals(product, 14.);
+ }
}