Commit in lcsim/src/org/lcsim/contrib/uiowa on MAIN
HandleMultiTrackClusters.java+271added 1.1
MJC: Crude code to split clusters with multiple tracks

lcsim/src/org/lcsim/contrib/uiowa
HandleMultiTrackClusters.java added at 1.1
diff -N HandleMultiTrackClusters.java
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ HandleMultiTrackClusters.java	10 May 2007 02:04:25 -0000	1.1
@@ -0,0 +1,271 @@
+package org.lcsim.contrib.uiowa;
+
+import java.util.*; 
+import org.lcsim.util.*;
+import org.lcsim.event.*;
+import org.lcsim.event.base.*;
+import org.lcsim.event.util.*;
+import org.lcsim.recon.cluster.util.*;
+import org.lcsim.mc.fast.tracking.ReconTrack;
+import hep.physics.vec.*;
+
+/**
+ * This module takes as input from the event a list of
+ * ReconstructedParticle objects. It scans these for cases where
+ * a single particle has more than one track attached to it.
+ * For those cases, it splits them into new particles, one
+ * per track.
+ *
+ * The output is a list of ReconstructedParticle objects,
+ * written to the event. This list includes (unmodified) all of
+ * the input particles with 0 or 1 track, plus the new, split-up
+ * particles.
+ */
+
+public class HandleMultiTrackClusters extends Driver
+{
+    protected String m_inputParticleListName;
+    protected String m_outputParticleListName;
+
+    /** Convenience constructor: records the event-list names that process() reads from and writes to. */
+    public HandleMultiTrackClusters(String inputParticleList, String outputParticleList) {
+        this.m_outputParticleListName = outputParticleList;
+        this.m_inputParticleListName = inputParticleList;
+    }
+
+    /** Passes particles with 0 or 1 track through unchanged and splits the rest into one particle per track. */
+    public void process(EventHeader event) {
+        super.process(event);
+        List<ReconstructedParticle> particlesIn = event.get(ReconstructedParticle.class, m_inputParticleListName);
+        List<ReconstructedParticle> particlesOut = new Vector<ReconstructedParticle>();
+        for (ReconstructedParticle candidate : particlesIn) {
+            if (candidate.getTracks().size() < 2) {
+                particlesOut.add(candidate);
+                continue;
+            }
+            // NOTE(review): a new matcher is created and run on the event for each multi-track particle.
+            LocalHelixExtrapolationTrackClusterMatcher extrapolator = new LocalHelixExtrapolationTrackClusterMatcher();
+            extrapolator.process(event);
+            List<Cluster> pool = makeFlatClusterList(candidate);
+            Map<Track, Cluster> seedByTrack = new HashMap<Track, Cluster>();
+            Map<Track, List<Cluster>> clustersByTrack = new HashMap<Track, List<Cluster>>();
+            // Step 1: every track gets the cluster the matcher picks for it from the pool.
+            for (Track track : candidate.getTracks()) {
+                Cluster seed = extrapolator.matchTrackToCluster(track, pool);
+                if (seed == null) { throw new AssertionError("Help: didn't find a cluster match"); }
+                List<Cluster> owned = new Vector<Cluster>();
+                owned.add(seed);
+                seedByTrack.put(track, seed);
+                clustersByTrack.put(track, owned);
+                pool.remove(seed);
+            }
+            // Step 2: distribute whatever is still in the pool among the tracks.
+            while (!pool.isEmpty()) {
+                findNextMatch(seedByTrack, clustersByTrack, pool);
+            }
+            // Step 3: build one output particle per track from the clusters it collected.
+            for (Track track : candidate.getTracks()) {
+                BasicCluster merged = new BasicCluster();
+                for (Cluster piece : clustersByTrack.get(track)) {
+                    merged.addCluster(piece);
+                }
+                // Energy comes from the track momentum and the mass of the track's MC particle;
+                // the cast means every track of a multi-track particle must be a ReconTrack.
+                double mass = ((ReconTrack) track).getMCParticle().getMass();
+                Hep3Vector momentum = new BasicHep3Vector(track.getMomentum());
+                double energy = Math.sqrt(momentum.magnitudeSquared() + mass * mass);
+                BaseReconstructedParticle split = new BaseReconstructedParticle(energy, momentum);
+                split.addTrack(track);
+                split.addCluster(merged);
+                particlesOut.add(split);
+            }
+        }
+        event.put(m_outputParticleListName, particlesOut, ReconstructedParticle.class, 0);
+    }
+
+    /** Alternates cone and radius passes until unmatchedClusters is empty; the distance cuts grow
+     *  whenever the cone pass assigns nothing and the radius pass assigns at most two clusters. */
+    void findNextMatch(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters)  {
+        // Start tight; only the two distance cuts are ever relaxed, the angular cut stays fixed.
+        double coneCosine = 0.8;
+        double coneReach = 30.0;
+        double radius = 13.0;
+        while (!unmatchedClusters.isEmpty()) {
+            boolean coneAssigned = findNextMatch_cone(coreMap, fullMap, unmatchedClusters, coneCosine, coneReach) > 0;
+            if (!coneAssigned && findNextMatch_radius(coreMap, fullMap, unmatchedClusters, radius) <= 2) {
+                coneReach += 10.0;
+                radius += 5.0;
+            }
+        }
+    }
+
+    /**
+     * One cone pass. An unmatched cluster is claimed for a track when, for some cluster the
+     * track already owns, the displacement from the owned cluster to the unmatched one is
+     * shorter than cutDisplacementMagnitude and its unit vector has a dot product above
+     * cutDotProduct with the unit position vector of the owned cluster. A track is recorded
+     * once per owned cluster that passes, so it can appear several times for one cluster.
+     * Claimed clusters are moved from unmatchedClusters into fullMap.
+     *
+     * @return the number of clusters assigned in this pass
+     */
+    int findNextMatch_cone(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters, double cutDotProduct, double cutDisplacementMagnitude) {
+        // Gather the claims first; nothing is moved until all of them are known.
+        Map<Cluster, List<Track>> claims = new HashMap<Cluster, List<Track>>();
+        for (Track track : coreMap.keySet()) {
+            for (Cluster owned : fullMap.get(track)) {
+                Hep3Vector ownedPos = estimatePosition(owned);
+                for (Cluster free : unmatchedClusters) {
+                    Hep3Vector step = VecOp.sub(estimatePosition(free), ownedPos);
+                    double cosine = VecOp.dot(VecOp.unit(ownedPos), VecOp.unit(step));
+                    double reach = step.magnitude();
+                    if (!(cosine > cutDotProduct && reach < cutDisplacementMagnitude)) { continue; }
+                    List<Track> claimants = claims.get(free);
+                    if (claimants == null) {
+                        claimants = new Vector<Track>();
+                        claims.put(free, claimants);
+                    }
+                    claimants.add(track);
+                }
+            }
+        }
+        int assigned = 0;
+        // Pass 1: a cluster with exactly one claim goes straight to the claiming track.
+        for (Map.Entry<Cluster, List<Track>> claim : claims.entrySet()) {
+            int claimCount = claim.getValue().size();
+            if (claimCount < 1) { throw new AssertionError("bug"); }
+            if (claimCount == 1) {
+                fullMap.get(claim.getValue().get(0)).add(claim.getKey());
+                unmatchedClusters.remove(claim.getKey());
+                assigned++;
+            }
+        }
+        // Pass 2: a cluster with several claims goes to the track owning the nearest cluster.
+        for (Map.Entry<Cluster, List<Track>> claim : claims.entrySet()) {
+            if (claim.getValue().size() <= 1) { continue; }
+            Track winner = null;
+            double winnerDistance = 0;
+            for (Track rival : claim.getValue()) {
+                for (Cluster owned : fullMap.get(rival)) {
+                    double gap = proximity(owned, claim.getKey());
+                    if (winner == null || gap < winnerDistance) {
+                        winner = rival;
+                        winnerDistance = gap;
+                    }
+                }
+            }
+            fullMap.get(winner).add(claim.getKey());
+            unmatchedClusters.remove(claim.getKey());
+            assigned++;
+        }
+        return assigned;
+    }
+
+    /** Radius pass: each unmatched cluster within cutRadius (see proximity()) of a track's clusters goes to
+     *  the nearest such track; if none is that close, only the overall closest cluster is assigned. Returns the count. */
+    int findNextMatch_radius(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters, double cutRadius) {
+        Track overallTrack = null;
+        Cluster overallCluster = null;
+        double overallDistance = 0;
+        Map<Cluster,Track> pendingMatches = new HashMap<Cluster,Track>();
+        for (Cluster cand : unmatchedClusters) {
+            Track nearestTrack = null;
+            double nearestDistance = 0;
+            for (Track tr : coreMap.keySet()) {
+                for (Cluster trackClus : fullMap.get(tr)) {
+                    double currentDist = proximity(trackClus, cand);
+                    if (nearestTrack == null || currentDist < nearestDistance) {
+                        nearestTrack = tr;
+                        nearestDistance = currentDist;
+                    }
+                }
+            }
+            if (nearestTrack != null && (overallTrack == null || nearestDistance < overallDistance)) {
+                overallTrack = nearestTrack;
+                overallCluster = cand;
+                overallDistance = nearestDistance;
+            }
+            if (nearestTrack != null && nearestDistance < cutRadius) { pendingMatches.put(cand, nearestTrack); } // nearest track, not the last one iterated
+        }
+        if (pendingMatches.isEmpty()) {
+            if (overallTrack==null || overallCluster==null) { throw new AssertionError("stuck"); }
+            pendingMatches.put(overallCluster, overallTrack); // fallback: just the single closest pair
+        }
+        for (Map.Entry<Cluster,Track> match : pendingMatches.entrySet()) {
+            unmatchedClusters.remove(match.getKey());
+            fullMap.get(match.getValue()).add(match.getKey());
+        }
+        return pendingMatches.size();
+    }
+
+    /**
+     * Flattens the particle's clusters into their leaf clusters (those without
+     * sub-clusters). Any hit reported by a cluster in the tree but not contained in a
+     * leaf is wrapped in a new single-hit BasicCluster so that no hit is dropped.
+     *
+     * @return a new, modifiable list: the leaf clusters followed by the single-hit clusters
+     */
+    List<Cluster> makeFlatClusterList(ReconstructedParticle part) {
+        // Insertion-ordered sets keep the order of the result independent of hash codes.
+        Set<CalorimeterHit> allHits = new LinkedHashSet<CalorimeterHit>();
+        Set<Cluster> containedClusters = new LinkedHashSet<Cluster>();
+        for (Cluster clus : part.getClusters()) {
+            subMakeFlatClusterList(clus, allHits, containedClusters);
+        }
+        List<Cluster> allClusters = new Vector<Cluster>(containedClusters);
+        Set<CalorimeterHit> countedHits = new HashSet<CalorimeterHit>();
+        for (Cluster clus : containedClusters) {
+            countedHits.addAll(clus.getCalorimeterHits());
+        }
+        for (CalorimeterHit hit : allHits) {
+            if (countedHits.add(hit)) {
+                BasicCluster newClus = new BasicCluster();
+                newClus.addHit(hit);
+                allClusters.add(newClus);
+            }
+        }
+        return allClusters;
+    }
+    /** Records clus's hits in allHits and adds every leaf cluster at or below clus to containedClusters. */
+    void subMakeFlatClusterList(Cluster clus, Set<CalorimeterHit> allHits, Set<Cluster> containedClusters) {
+        allHits.addAll(clus.getCalorimeterHits());
+        if (clus.getClusters().isEmpty()) {
+            // A cluster without sub-clusters is a leaf.
+            containedClusters.add(clus);
+            return;
+        }
+        for (Cluster child : clus.getClusters()) { subMakeFlatClusterList(child, allHits, containedClusters); }
+    }
+
+    /** Smallest hit-to-hit distance between the two clusters; throws AssertionError if either has no hits. */
+    protected double proximity(Cluster clus1, Cluster clus2) {
+        if (clus1.getCalorimeterHits().size() < 1 || clus2.getCalorimeterHits().size() < 1) { throw new AssertionError("Empty cluster"); }
+        // Brute force over every pair of hits.
+        boolean haveValue = false;
+        double closest = 0;
+        for (CalorimeterHit a : clus1.getCalorimeterHits()) {
+            Hep3Vector posA = new BasicHep3Vector(a.getPosition());
+            for (CalorimeterHit b : clus2.getCalorimeterHits()) {
+                double gap = VecOp.sub(posA, new BasicHep3Vector(b.getPosition())).magnitude();
+                if (!haveValue || gap < closest) {
+                    haveValue = true;
+                    closest = gap;
+                }
+            }
+        }
+        return closest;
+    }
+
+    /** Returns the unweighted mean of the cluster's hit positions (all zeros if it has no hits). */
+    protected Hep3Vector estimatePosition(Cluster clus) {
+        double[] mean = new double[3];
+        double nHits = clus.getCalorimeterHits().size();
+        for (CalorimeterHit hit : clus.getCalorimeterHits()) {
+            double[] p = hit.getPosition();
+            // Each term is divided by the hit count before it is added to the sum.
+            for (int axis = 0; axis < 3; axis++) { mean[axis] += p[axis] / nHits; }
+        }
+        return new BasicHep3Vector(mean);
+    }
+    
+}
CVSspam 0.2.8