lcsim/src/org/lcsim/contrib/uiowa
diff -N HandleMultiTrackClusters.java
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ HandleMultiTrackClusters.java 10 May 2007 02:04:25 -0000 1.1
@@ -0,0 +1,271 @@
+package org.lcsim.contrib.uiowa;
+
+import java.util.*;
+import org.lcsim.util.*;
+import org.lcsim.event.*;
+import org.lcsim.event.base.*;
+import org.lcsim.event.util.*;
+import org.lcsim.recon.cluster.util.*;
+import org.lcsim.mc.fast.tracking.ReconTrack;
+import hep.physics.vec.*;
+
+/**
+ * This module takes as input from the event a list of
+ * ReconstructedParticle objects. It scans these for cases where
+ * a single particle has more than one track attached to it.
+ * For those cases, it splits them into new particles, one
+ * per track.
+ *
+ * The output is a list of ReconstructedParticle objects,
+ * written to the event. This list includes (unmodified) all of
+ * the input particles with 0 or 1 track, plus the new, split-up
+ * particles.
+ */
+
+public class HandleMultiTrackClusters extends Driver
+{
+ /** Event collection name of the input ReconstructedParticle list. */
+ protected String m_inputParticleListName;
+ /** Event collection name under which the output particle list is written. */
+ protected String m_outputParticleListName;
+
+ /** Convenience constructor. */
+ public HandleMultiTrackClusters(String inputParticleList, String outputParticleList) {
+ m_inputParticleListName = inputParticleList;
+ m_outputParticleListName = outputParticleList;
+ }
+
+ /**
+  * Per-event entry point. Particles with 0 or 1 track are copied to the
+  * output unchanged; each multi-track particle is split into one new
+  * particle per track, with the original particle's clusters divided
+  * among the tracks by the matching logic below.
+  */
+ public void process(EventHeader event) {
+ super.process(event);
+ // Read in list of particles
+ List<ReconstructedParticle> inputParticleList = event.get(ReconstructedParticle.class, m_inputParticleListName);
+ List<ReconstructedParticle> outputParticleList = new Vector<ReconstructedParticle>();
+ // Loop over them, looking for multi-track particles
+ for (ReconstructedParticle part : inputParticleList) {
+ if (part.getTracks().size()<2) {
+ // Don't change
+ outputParticleList.add(part);
+ } else {
+ // Split into one particle per track
+ // Want to split like this:
+ // * Find where the clusters start (track calorimeter entry points)
+ // * Add pieces to each cluster
+ // For now do it in a fairly dumb way.
+ // NOTE(review): a fresh matcher is constructed and process()ed for
+ // every multi-track particle in the event; presumably this redoes the
+ // same extrapolation work each time -- confirm whether it could be
+ // hoisted out of the particle loop.
+ LocalHelixExtrapolationTrackClusterMatcher tmpExtrap = new LocalHelixExtrapolationTrackClusterMatcher();
+ tmpExtrap.process(event);
+ // Flatten this particle's cluster tree, then seed each track with its
+ // best-matched "core" cluster. coreMap keeps only the seed; fullMap
+ // accumulates every cluster eventually assigned to the track.
+ List<Cluster> unmatchedClusters = makeFlatClusterList(part);
+ Map<Track, Cluster> coreMap = new HashMap<Track, Cluster>();
+ Map<Track, List<Cluster>> fullMap = new HashMap<Track, List<Cluster>>();
+ for (Track tr : part.getTracks()) {
+ Cluster core = tmpExtrap.matchTrackToCluster(tr, unmatchedClusters);
+ if (core == null) { throw new AssertionError("Help: didn't find a cluster match"); }
+ coreMap.put(tr, core);
+ fullMap.put(tr, new Vector<Cluster>());
+ fullMap.get(tr).add(core);
+ unmatchedClusters.remove(core);
+ }
+ // Iteratively attach every remaining cluster fragment to some track;
+ // each call is guaranteed to remove at least one cluster (or throw).
+ while (unmatchedClusters.size() > 0) {
+ findNextMatch(coreMap, fullMap, unmatchedClusters);
+ }
+ // Build one output particle per track from its assigned clusters.
+ for (Track tr : part.getTracks()) {
+ BasicCluster masterCluster = new BasicCluster();
+ for (Cluster clus : fullMap.get(tr)) {
+ masterCluster.addCluster(clus);
+ }
+ // NOTE(review): Monte-Carlo truth cheat -- the mass comes from the
+ // associated MCParticle via a cast to the fast-MC ReconTrack type.
+ // This throws ClassCastException for any other Track implementation;
+ // confirm this module is only meant to run on fast-MC tracks.
+ double trueMass = ((ReconTrack)(tr)).getMCParticle().getMass();
+ Hep3Vector trackMomentum = new BasicHep3Vector(tr.getMomentum());
+ // Energy from the track momentum and the (truth) mass: E^2 = p^2 + m^2.
+ double energy = Math.sqrt(trackMomentum.magnitudeSquared() + trueMass*trueMass);
+ BaseReconstructedParticle outputParticle = new BaseReconstructedParticle(energy, trackMomentum);
+ outputParticle.addTrack(tr);
+ outputParticle.addCluster(masterCluster);
+ outputParticleList.add(outputParticle);
+ }
+ }
+ }
+ event.put(m_outputParticleListName, outputParticleList, ReconstructedParticle.class, 0);
+ }
+
+ /**
+  * Assigns at least one cluster from unmatchedClusters to some track,
+  * progressively loosening the cuts until progress is made. Mutates
+  * fullMap and unmatchedClusters in place.
+  * NOTE(review): cutDotProduct is never loosened, only the displacement
+  * and radius cuts are. Also, a radius pass that removes 1 or 2 clusters
+  * still falls through to the loosening step (the test is &gt; 2) -- confirm
+  * that asymmetry with the cone pass (&gt; 0) is intended.
+  */
+ void findNextMatch(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters) {
+ // Go round iteratively, first with tight cuts and then progressively looser...
+ double cutDotProduct = 0.8;
+ double cutDisplacementMagnitude = 30.0;
+ double cutRadius = 13.0;
+ while (unmatchedClusters.size() > 0) {
+ int numberRemoved = findNextMatch_cone(coreMap, fullMap, unmatchedClusters, cutDotProduct, cutDisplacementMagnitude);
+ if (numberRemoved > 0) { continue; }
+ numberRemoved = findNextMatch_radius(coreMap, fullMap, unmatchedClusters, cutRadius);
+ if (numberRemoved > 2) { continue; }
+ // OK, need to back off.
+ cutDisplacementMagnitude += 10.0;
+ cutRadius += 5.0;
+ }
+ }
+
+ /**
+  * Cone-style pass: a candidate cluster matches a track if, for some
+  * cluster already assigned to that track, the candidate lies roughly
+  * "outward" of it (dot product of the seed direction from the origin
+  * with the seed-to-candidate displacement above cutDotProduct) and
+  * within cutDisplacementMagnitude. Uniquely-matched candidates are
+  * assigned first; candidates matched by several tracks go to the track
+  * with the closest hit-to-hit distance.
+  *
+  * @return the number of clusters removed from unmatchedClusters
+  */
+ int findNextMatch_cone(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters, double cutDotProduct, double cutDisplacementMagnitude) {
+
+ // Collect, per candidate cluster, every track whose assigned clusters
+ // "point at" it within the cuts.
+ Map<Cluster, List<Track>> provisionalAssignments = new HashMap<Cluster, List<Track>>();
+ for (Track tr : coreMap.keySet()) {
+ Collection<Cluster> existingClusters = fullMap.get(tr);
+ for (Cluster seed : existingClusters) {
+ Hep3Vector seedPosition = estimatePosition(seed);
+ for (Cluster cand : unmatchedClusters) {
+ Hep3Vector candPosition = estimatePosition(cand);
+ Hep3Vector displacement = VecOp.sub(candPosition, seedPosition);
+ // Alignment of the candidate with the seed's radial direction
+ // (seedPosition is measured from the detector origin).
+ double dotProduct = VecOp.dot(VecOp.unit(seedPosition), VecOp.unit(displacement));
+ double displacementMagnitude = displacement.magnitude();
+ if (dotProduct > cutDotProduct && displacementMagnitude < cutDisplacementMagnitude) {
+ // Match!
+ if ( ! provisionalAssignments.keySet().contains(cand) ) {
+ provisionalAssignments.put(cand, new Vector<Track>());
+ }
+ provisionalAssignments.get(cand).add(tr);
+ }
+ }
+ }
+ }
+
+ // OK. Any uniquely matched clusters can be added.
+ // First, handle non-ambiguous cases
+ int countMatches = 0;
+ for (Cluster cand : provisionalAssignments.keySet()) {
+ List<Track> trackMatches = provisionalAssignments.get(cand);
+ if (trackMatches.size()<1) {
+ throw new AssertionError("bug");
+ } else if (trackMatches.size()==1) {
+ Track tr = trackMatches.get(0);
+ fullMap.get(tr).add(cand);
+ unmatchedClusters.remove(cand);
+ countMatches++;
+ }
+ }
+ // Now handle ambiguous cases:
+ for (Cluster cand : provisionalAssignments.keySet()) {
+ List<Track> trackMatches = provisionalAssignments.get(cand);
+ if (trackMatches.size()>1) {
+ // Choose closest hit-hit match
+ Track bestMatchTrack = null;
+ double bestMatchDistance = 0;
+ for (Track trackCand : trackMatches) {
+ for (Cluster trackClus : fullMap.get(trackCand)) {
+ double currentDist = proximity(trackClus, cand);
+ if (bestMatchTrack == null || currentDist < bestMatchDistance) {
+ bestMatchTrack = trackCand;
+ bestMatchDistance = currentDist;
+ }
+ }
+ }
+ fullMap.get(bestMatchTrack).add(cand);
+ unmatchedClusters.remove(cand);
+ countMatches++;
+ }
+ }
+
+ return countMatches;
+ }
+
+ /**
+  * Fallback pass: assigns every candidate cluster lying within cutRadius
+  * (closest hit-to-hit distance) of a cluster already owned by some
+  * track. If nothing is that close, force-assigns the single globally
+  * closest candidate/track pair so the caller always makes progress.
+  * NOTE(review): a candidate within cutRadius of clusters from several
+  * tracks keeps only the last HashMap.put winner, and coreMap.keySet()
+  * iteration order is unspecified -- the tie-break is effectively
+  * arbitrary rather than "closest". Confirm whether that matters.
+  *
+  * @return the number of clusters removed from unmatchedClusters (&gt;= 1)
+  */
+ int findNextMatch_radius(Map<Track, Cluster> coreMap, Map<Track, List<Cluster>> fullMap, List<Cluster> unmatchedClusters, double cutRadius) {
+ // Assign all hits within a short radius.
+ // If that fails to produce anything, look for single next match.
+ Track bestMatchTrack = null;
+ Cluster bestMatchCluster = null;
+ double bestMatchDistance = 0;
+ Map<Cluster,Track> pendingMatches = new HashMap<Cluster,Track>();
+ for (Cluster cand : unmatchedClusters) {
+ for (Track tr : coreMap.keySet()) {
+ for (Cluster trackClus : fullMap.get(tr)) {
+ double currentDist = proximity(trackClus, cand);
+ // Track the overall closest pair as the forced fallback.
+ if (bestMatchTrack == null || currentDist < bestMatchDistance) {
+ bestMatchTrack = tr;
+ bestMatchCluster = cand;
+ bestMatchDistance = currentDist;
+ }
+ if (currentDist < cutRadius) {
+ pendingMatches.put(cand,tr);
+ }
+ }
+ }
+ }
+ if (pendingMatches.size()>0) {
+ for (Cluster clus : pendingMatches.keySet()) {
+ Track matchedTrack = pendingMatches.get(clus);
+ unmatchedClusters.remove(clus);
+ fullMap.get(matchedTrack).add(clus);
+ }
+ return pendingMatches.keySet().size();
+ } else {
+ if (bestMatchTrack==null || bestMatchCluster==null) { throw new AssertionError("stuck"); }
+ unmatchedClusters.remove(bestMatchCluster);
+ fullMap.get(bestMatchTrack).add(bestMatchCluster);
+ return 1;
+ }
+ }
+
+ /**
+  * Flattens the particle's (possibly nested) cluster structure into a
+  * flat list of leaf clusters, plus one new single-hit cluster for every
+  * hit that appears in a parent cluster but in none of its leaves.
+  */
+ List<Cluster> makeFlatClusterList(ReconstructedParticle part) {
+ Set<CalorimeterHit> allHits = new HashSet<CalorimeterHit>();
+ Set<Cluster> containedClusters = new HashSet<Cluster>();
+ for (Cluster clus : part.getClusters()) {
+ subMakeFlatClusterList(clus, allHits, containedClusters);
+ }
+ List<Cluster> allClusters = new Vector<Cluster>();
+ Set<CalorimeterHit> countedHits = new HashSet<CalorimeterHit>();
+ for (Cluster clus : containedClusters) {
+ allClusters.add(clus);
+ countedHits.addAll(clus.getCalorimeterHits());
+ }
+ // Any excess?
+ // Hits seen in the tree but not owned by any leaf become one-hit clusters.
+ for (CalorimeterHit hit : allHits) {
+ if ( ! countedHits.contains(hit) ) {
+ BasicCluster newClus = new BasicCluster();
+ newClus.addHit(hit);
+ allClusters.add(newClus);
+ countedHits.add(hit);
+ }
+ }
+ // Crosscheck
+ // NOTE(review): checkHits is built but never compared against anything;
+ // this "crosscheck" is dead code -- either assert against allHits or drop it.
+ List<CalorimeterHit> checkHits = new Vector<CalorimeterHit>();
+ for (Cluster clus : allClusters) {
+ checkHits.addAll(clus.getCalorimeterHits());
+ }
+ return allClusters;
+ }
+ /**
+  * Recursive helper for makeFlatClusterList: accumulates every hit seen
+  * anywhere in the tree into allHits, and every leaf (childless) cluster
+  * into containedClusters.
+  */
+ void subMakeFlatClusterList(Cluster clus, Set<CalorimeterHit> allHits, Set<Cluster> containedClusters) {
+ allHits.addAll(clus.getCalorimeterHits());
+ if (clus.getClusters().size() == 0) {
+ containedClusters.add(clus);
+ } else {
+ for (Cluster subClus : clus.getClusters()) {
+ subMakeFlatClusterList(subClus, allHits, containedClusters);
+ }
+ }
+ }
+
+ /**
+  * Returns the smallest hit-to-hit Euclidean distance between two
+  * clusters. O(n*m) in the hit counts; throws if either cluster is empty.
+  */
+ protected double proximity(Cluster clus1, Cluster clus2) {
+ if (clus1.getCalorimeterHits().size()<1) { throw new AssertionError("Empty cluster"); }
+ if (clus2.getCalorimeterHits().size()<1) { throw new AssertionError("Empty cluster"); }
+ double minDist = 0;
+ boolean found = false;
+ for (CalorimeterHit hit1 : clus1.getCalorimeterHits()) {
+ Hep3Vector hitPosition1 = new BasicHep3Vector(hit1.getPosition());
+ for (CalorimeterHit hit2 : clus2.getCalorimeterHits()) {
+ Hep3Vector hitPosition2 = new BasicHep3Vector(hit2.getPosition());
+ double distance = VecOp.sub(hitPosition1,hitPosition2).magnitude();
+ if (distance<minDist || found==false) {
+ found = true;
+ minDist = distance;
+ }
+ }
+ }
+ return minDist;
+ }
+
+ /**
+  * Returns the unweighted centroid (arithmetic mean of hit positions) of
+  * a cluster. An empty cluster yields the origin (count == 0 would give
+  * division by zero per hit, but the loop body never runs).
+  */
+ protected Hep3Vector estimatePosition(Cluster clus) {
+ double[] runningSum = new double[3];
+ double count = clus.getCalorimeterHits().size();
+ for (CalorimeterHit hit : clus.getCalorimeterHits()) {
+ double[] hitPos = hit.getPosition();
+ for (int i=0; i<3; i++) {
+ runningSum[i] += (hitPos[i] / count);
+ }
+ }
+ return new BasicHep3Vector(runningSum);
+ }
+
+}