Commit in lcsim/src/org/lcsim/recon/util on MAIN
McTruthLinker.java +755 added 1.1
A driver to provide truth links for tracks, clusters and reconstructed particles.
Based on the Marlin processor: MCTruthLinker

lcsim/src/org/lcsim/recon/util
McTruthLinker.java added at 1.1
diff -N McTruthLinker.java
--- /dev/null	1 Jan 1970 00:00:00 -0000
+++ McTruthLinker.java	13 Apr 2011 17:35:41 -0000	1.1
@@ -0,0 +1,755 @@
+package org.lcsim.recon.util;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.lcsim.event.CalorimeterHit;
+import org.lcsim.event.Cluster;
+import org.lcsim.event.EventHeader;
+import org.lcsim.event.LCRelation;
+import org.lcsim.event.MCParticle;
+import org.lcsim.event.ReconstructedParticle;
+import org.lcsim.event.RelationalTable;
+import org.lcsim.event.SimCalorimeterHit;
+import org.lcsim.event.Track;
+import org.lcsim.event.TrackerHit;
+import org.lcsim.event.MCParticle.SimulatorStatus;
+import org.lcsim.event.base.BaseLCRelation;
+import org.lcsim.event.base.BaseRelationalTable;
+import org.lcsim.util.Driver;
+import org.lcsim.util.lcio.LCIOConstants;
+import org.lcsim.util.lcio.LCIOUtil;
+
+/**
+ * A Driver to create several LCRelations between high level reconstructed
+ * objects and their true mc particles.
+ * <p>
+ * By default three LCRelations are created
+ * <ul><li><b>Tracks to MCParticles:</b><br>
+ * This requires the name of the track collection and an LCRelation between
+ * TrackerHits and MCParticles.</li>
+ * <li><b>Clusters to MCParticles:</b><br>
+ * This requires the name of the cluster collection and an LCRelation between
+ * CalorimeterHits and SimCalorimeterHits.</li>
+ * <li><b>PFOs to MCParticles:</b><br>
+ * This requires the name of the PFO collection and the creation of both
+ * LCRelations described above.</li></ul>
+ * Putting any of these into the event can be prevented by setting the
+ * respective name to an empty String.
+ * <p>
+ * The weights of these relations are based on fractions contributed by
+ * the mc particle. For the track relation it is based on the fraction of hits
+ * and for clusters it is based on the fraction of energy. Non-charged PFOs
+ * use the weight of the clusters, while charged PFOs use a combined weight
+ * of the tracks and the clusters, based on a global track to cluster weight.
+ * As a default only the track relations are used for charged PFOs.
+ * <p>
+ * Instead of the simple fraction-based weights described above, a weight
+ * based on a Tanimoto metric can be used for tracks. Then the hits produced
+ * by a mc particle which are not part of the reconstructed object are also
+ * taken into account. This leads to a lower weight for missed hits.
+ * <p>
+ * For all objects, only the relation to the MCParticle which has the highest
+ * weight is kept. Setting the fullRecoRelation switch to true will keep all
+ * relations instead.
+ * <p>
+ * By default a reduced set of MCParticles is created. This skimmed list
+ * contains only those MCParticles created by the generator (intermediate
+ * and final state particles) and emitted bremsstrahlung photons, together
+ * with all of their ancestors. In addition, the direct daughters (above an
+ * energy cut) of particles of a pre-defined set of types which decay in
+ * flight in the tracking system are kept (currently only for particles that
+ * are not final state generator particles). By default these particle types
+ * are gamma, pi0 and K0s.
+ * All relations which would point to an MCParticle not contained in this
+ * reduced list point to their closest ancestor which is in this list instead.
+ * Again, this behavior can be switched off by setting the name of the
+ * skimmed mc particle collection to an empty String.
+ * 
+ * @author <a href="mailto:[log in to unmask]">Christian Grefe</a>
+ */
+public class McTruthLinker extends Driver {
+
+	protected String trackHitMcRelationName = "HelicalTrackMCRelations";
+	protected String trackCollectionName = EventHeader.TRACKS;
+	protected String trackMcRelationName = "TrackMCTruthLink";
+	protected String caloHitsimHitRelationName = "CalorimeterHitRelations";
+	protected String clusterCollectionName = "ReconClusters";
+	protected String clusterMcRelationName = "ClusterMCTruthLink";
+	protected String pfoCollectionName = "PandoraPFOCollection";
+	protected String pfoMcRelationName = "RecoMCTruthLink";
+	protected String mcParticleCollectionName = EventHeader.MC_PARTICLES;
+	protected String mcParticlesSkimmedName = "MCParticlesSkimmed";
+	protected double pfoTrackWeight = 1.0;
+	protected double pfoClusterWeight = 0.0;
+	protected boolean fullRecoRelation = false;
+	protected boolean useTanimotoDistance = false;
+	protected boolean useSkimmedMcParticles = true;
+	// per-event state, reset at the start of process()
+	protected List<MCParticle> mcParticlesSkimmed;
+	protected List<Integer> keepDaughtersPDGID = new ArrayList<Integer>();
+	// per-event state: mc particle -> closest ancestor in the skimmed list (may map to null)
+	protected Map<MCParticle, MCParticle> mcParticleToSkimmed;
+	protected double daughterEnergyCut = 0.01;
+	
+	// -------------------- Constructors --------------------
+	
+	/**
+	 * Creates the driver with the default list of particle types whose
+	 * in-flight decay daughters are kept: gamma, pi0 and K0s.
+	 */
+	public McTruthLinker() {
+		keepDaughtersPDGID.add(22);  // gamma
+		keepDaughtersPDGID.add(111); // pi0
+		keepDaughtersPDGID.add(310); // K0s
+	}
+	
+	
+	
+	// -------------------- Driver Interface Methods --------------------
+	
+	/**
+	 * Disables the creation of the skimmed mc particle list if its
+	 * collection name has been set to an empty String.
+	 */
+	@Override
+	protected void startOfData() {
+		if (mcParticlesSkimmedName.equals("")) {
+			this.useSkimmedMcParticles = false;
+		}
+	}
+	
+	/**
+	 * Creates the configured truth relations for one event and adds them to
+	 * the event. Each step that fails with an IllegalArgumentException (e.g.
+	 * a missing input collection) is skipped and a warning including the
+	 * exception message is written to the error stream.
+	 */
+	@Override
+	protected void process(EventHeader event) {
+		
+		List<LCRelation> trackMcRelation = null;
+		List<LCRelation> caloHitMcRelation = null;
+		List<LCRelation> clusterMcRelation = null;
+		List<LCRelation> pfoMcRelation = null;
+		mcParticlesSkimmed = null;
+		mcParticleToSkimmed = null;
+		
+		// skimmed mc particles
+		if (useSkimmedMcParticles) {
+			try {
+				List<MCParticle> mcParticles = event.get(MCParticle.class, mcParticleCollectionName);
+				mcParticlesSkimmed = createSkimmedMcParticleList(mcParticles);
+				mcParticleToSkimmed = fillMcParticleToSkimmedMap(mcParticles, mcParticlesSkimmed);
+				// the skimmed list only references particles owned by the original collection
+				int flags = event.getMetaData(mcParticles).getFlags();
+				flags = LCIOUtil.bitSet(flags, LCIOConstants.BITSubset, true);
+				event.put(mcParticlesSkimmedName, mcParticlesSkimmed, MCParticle.class, flags);
+				print(HLEVEL_NORMAL, "Added skimmed mc particles \""+mcParticlesSkimmedName+"\" to the event.");
+			} catch (IllegalArgumentException e) {
+				print(HLEVEL_DEFAULT,
+						"WARNING: no skimmed mc particle collection created.\n" +
+						e.getMessage(),
+						true);
+			}
+		}
+		
+		// track to mc particle relation
+		if (!trackHitMcRelationName.equals("") && !trackCollectionName.equals("")) {
+			try {
+				List<Track> tracks = event.get(Track.class, trackCollectionName);
+				List<LCRelation> trackHitMcRelation = event.get(LCRelation.class, trackHitMcRelationName);
+				trackMcRelation = createTrackMcRelation(tracks, trackHitMcRelation);
+				if (!trackMcRelationName.equals("")) {
+					event.put(trackMcRelationName, trackMcRelation, LCRelation.class, 0);
+					print(HLEVEL_NORMAL, "Added track to mc particle relations \""+trackMcRelationName+"\" to the event.");
+				}
+			} catch (IllegalArgumentException e) {
+				print(HLEVEL_DEFAULT,
+						"WARNING: no track to mc particle relation created.\n" +
+						e.getMessage(),
+						true);
+			}
+		}
+		
+		// calorimeter hit to mc particle relation (only used internally, not added to the event)
+		if (!caloHitsimHitRelationName.equals("")) { 
+			try {
+				caloHitMcRelation = createCaloHitMcRelation(event.get(LCRelation.class, caloHitsimHitRelationName));
+			} catch (IllegalArgumentException e) {
+				print(HLEVEL_DEFAULT,
+						"WARNING: no calorimeter hit to mc particle relation created.\n" +
+						e.getMessage(),
+						true);
+			}
+		}
+		
+		// cluster to mc particle relation
+		if (!clusterCollectionName.equals("")) {
+			try {
+				List<Cluster> clusters = event.get(Cluster.class, clusterCollectionName);
+				clusterMcRelation = createClusterMcRelation(clusters, caloHitMcRelation);
+				if (!clusterMcRelationName.equals("")) {
+					event.put(clusterMcRelationName, clusterMcRelation, LCRelation.class, 0);
+					print(HLEVEL_NORMAL, "Added cluster to mc particle relations \""+clusterMcRelationName+"\" to the event.");
+				}
+			} catch (IllegalArgumentException e) {
+				print(HLEVEL_DEFAULT,
+						"WARNING: no cluster to mc particle relation created.\n" +
+						e.getMessage(),
+						true);
+			}
+		}
+		
+		// PFO to mc particle relation
+		if (!pfoCollectionName.equals("")) {
+			try {
+				List<ReconstructedParticle> PFOs = event.get(ReconstructedParticle.class, pfoCollectionName);
+				pfoMcRelation = createPfoMcRelation(PFOs, trackMcRelation, clusterMcRelation);
+				if (!pfoMcRelationName.equals("")) {
+					event.put(pfoMcRelationName, pfoMcRelation, LCRelation.class, 0);
+					print(HLEVEL_NORMAL, "Added PFO to mc particle relations \""+pfoMcRelationName+"\" to the event.");
+				}
+			} catch (IllegalArgumentException e) {
+				print(HLEVEL_DEFAULT,
+						"WARNING: no PFO to mc particle relation created.\n" +
+						e.getMessage(),
+						true);
+			}
+		}
+	}
+	
+	
+	
+	// -------------------- Setter Methods --------------------
+	
+	/** If true, all relations are kept instead of only the one with the highest weight. */
+	public void setFullRecoRelation(boolean fullRecoRelation) {
+		this.fullRecoRelation = fullRecoRelation;
+	}
+	
+	/** If true, track relations are weighted with a Tanimoto metric. */
+	public void setUseTanimotoDistance(boolean useTanimotoDistance) {
+		this.useTanimotoDistance = useTanimotoDistance;
+	}
+	
+	/**
+	 * Sets the global weight of the track contribution for PFO relations.
+	 * @throws IllegalArgumentException if the weight is negative
+	 */
+	public void setPfoTrackWeight(double pfoTrackWeight) throws IllegalArgumentException {
+		if (pfoTrackWeight < 0) throw new IllegalArgumentException("PFO track weight can not be negative.");
+		this.pfoTrackWeight = pfoTrackWeight;
+	}
+	
+	/**
+	 * Sets the global weight of the cluster contribution for PFO relations.
+	 * @throws IllegalArgumentException if the weight is negative
+	 */
+	public void setPfoClusterWeight(double pfoClusterWeight) throws IllegalArgumentException {
+		// validate the new cluster weight (previously the track weight was checked by mistake)
+		if (pfoClusterWeight < 0) throw new IllegalArgumentException("PFO cluster weight can not be negative.");
+		this.pfoClusterWeight = pfoClusterWeight;
+	}
+	
+	public void setTrackHitMcRelationName(String trackHitMcRelationName) {
+		this.trackHitMcRelationName = trackHitMcRelationName;
+	}
+	
+	public void setTrackCollectionName(String trackCollectionName) {
+		this.trackCollectionName = trackCollectionName;
+	}
+	
+	public void setTrackMcRelationName(String trackMcRelationName) {
+		this.trackMcRelationName = trackMcRelationName;
+	}
+	
+	public void setCaloHitsimHitRelationName(String caloHitsimHitRelationName) {
+		this.caloHitsimHitRelationName = caloHitsimHitRelationName;
+	}
+	
+	public void setClusterCollectionName(String clusterCollectionName) {
+		this.clusterCollectionName = clusterCollectionName;
+	}
+	
+	public void setClusterMcRelationName(String clusterMcRelationName) {
+		this.clusterMcRelationName = clusterMcRelationName;
+	}
+	
+	public void setPfoCollectionName(String pfoCollectionName) {
+		this.pfoCollectionName = pfoCollectionName;
+	}
+	
+	public void setPfoMcRelationName(String pfoMcRelationName) {
+		this.pfoMcRelationName = pfoMcRelationName;
+	}
+	
+	public void setMcParticleCollectionName(String mcParticleCollectionName) {
+		this.mcParticleCollectionName = mcParticleCollectionName;
+	}
+	
+	public void setMcParticlesSkimmedName(String mcParticlesSkimmedName) {
+		this.mcParticlesSkimmedName = mcParticlesSkimmedName;
+	}
+	
+	/** Replaces the list of PDG IDs whose in-flight decay daughters are kept in the skimmed list. */
+	public void setKeepDaughtersPDGID(Integer[] keepDaughtersPDGID) {
+		this.keepDaughtersPDGID.clear();
+		this.keepDaughtersPDGID.addAll(Arrays.asList(keepDaughtersPDGID));
+	}
+	
+	/** Sets the minimum energy for daughters to be kept in the skimmed list. */
+	public void setDaughterEnergyCut(double daughterEnergyCut) {
+		this.daughterEnergyCut = daughterEnergyCut;
+	}
+	
+	
+	
+	// -------------------- Protected Methods --------------------
+	
+	/**
+	 * Creates a list of skimmed mc particles which are kept together
+	 * with all their ancestors. First of all, all the particles that
+	 * are created by the generator (IntermediateState or FinalState)
+	 * are kept. In addition bremsstrahlung photons created by final
+	 * state particles are kept. Finally the direct daughters (above the
+	 * energy cut, not backscattered) of particles from a given list
+	 * (default: gamma, pi0, K0s) that decayed in the tracker are kept.
+	 * @param mcParticles The list of all mc particles
+	 * @return The skimmed list of mc particles
+	 */
+	protected List<MCParticle> createSkimmedMcParticleList(List<MCParticle> mcParticles) {
+		
+		List<MCParticle> skimmedMcParticles = new ArrayList<MCParticle>();
+		
+		for (MCParticle mcParticle : mcParticles) {
+			SimulatorStatus simStatus = mcParticle.getSimulatorStatus();
+			if (mcParticle.getGeneratorStatus() == MCParticle.INTERMEDIATE ) {
+				// first add all intermediate particles
+				addMcParticleWithParents(mcParticle, skimmedMcParticles);
+			}
+			if (mcParticle.getGeneratorStatus() == MCParticle.FINAL_STATE) {
+				// add all mc particles created by the generator
+				addMcParticleWithParents(mcParticle, skimmedMcParticles);
+				// only for particles flagged as decayed in the calorimeter: keep their
+				// non-backscattered photon daughters above the cut as bremsstrahlung
+				if (simStatus.isDecayedInCalorimeter()) {
+					for (MCParticle daughter : mcParticle.getDaughters()) {
+						if (daughter.getPDGID() == 22 && daughter.getEnergy() > daughterEnergyCut && !daughter.getSimulatorStatus().isBackscatter()) {
+							addMcParticleWithParents(daughter, skimmedMcParticles);
+						}
+					}
+				}
+			} else if (simStatus.isDecayedInTracker()) {
+				// now add all daughters of the particles that decayed in flight and should be kept
+				// NOTE(review): because of the 'else', final state generator particles (e.g. a
+				// converting generator photon) never get their daughters kept here - confirm intent.
+				if (keepDaughtersPDGID.contains(mcParticle.getPDGID())) {
+					for (MCParticle daughter : mcParticle.getDaughters()) {
+						if (daughter.getEnergy() > daughterEnergyCut && !daughter.getSimulatorStatus().isBackscatter()) {
+							addMcParticleWithParents(daughter, skimmedMcParticles);
+						}
+					}
+				}
+			}
+		}
+		
+		print(HLEVEL_NORMAL, "Keeping "+skimmedMcParticles.size()+" of "+mcParticles.size()+" mc particles in skimmed list.");
+		
+		return skimmedMcParticles;
+	}
+	
+	/**
+	 * Fills a map connecting an mc particle with its closest ancestor
+	 * that is present in the skimmed mc particle list.
+	 * If no suitable ancestor is found the map is filled with null for
+	 * that mc particle.
+	 * @param mcParticles The list of all mc particles
+	 * @param skimmedMcParticles A subset of the mc particles
+	 * @return A mapping between all mc particles and their closest ancestor present in the skimmed mc particles
+	 */
+	protected Map<MCParticle, MCParticle> fillMcParticleToSkimmedMap(List<MCParticle> mcParticles, List<MCParticle> skimmedMcParticles) {
+		
+		Map<MCParticle, MCParticle> mcParticleToSkimmedMap = new HashMap<MCParticle, MCParticle>();
+		// hash based lookup instead of List.contains for every step up the ancestry
+		Set<MCParticle> skimmedSet = new HashSet<MCParticle>(skimmedMcParticles);
+		
+		for (MCParticle mcParticle : mcParticles) {
+			MCParticle ancestor = findAncestorIn(mcParticle, skimmedSet);
+			mcParticleToSkimmedMap.put(mcParticle, ancestor);
+			// only build the long message if it is printed; also guards against particles without parents
+			if (mcParticle != ancestor && getHistogramLevel() >= HLEVEL_FULL) {
+				List<MCParticle> parents = mcParticle.getParents();
+				String mother = parents.isEmpty() ? "none" : String.valueOf(parents.get(0).getPDGID());
+				print(HLEVEL_FULL,
+						"Warning: Rejecting mc particle.\n" +
+						"\tEnergy: "+mcParticle.getEnergy()+"\n" +
+						"\tCharge: "+mcParticle.getCharge()+"\n" +
+						"\tPDGID: "+mcParticle.getPDGID()+"\n" +
+						"\tGenStatus: "+mcParticle.getGeneratorStatus()+"\n" +
+						"\tCreated in simulation: "+mcParticle.getSimulatorStatus().isCreatedInSimulation()+"\n" +
+						"\tBackscatter: "+mcParticle.getSimulatorStatus().isBackscatter()+"\n" +
+						"\tDecay in calorimeter: "+mcParticle.getSimulatorStatus().isDecayedInCalorimeter()+"\n" +
+						"\tDecay in tracker: "+mcParticle.getSimulatorStatus().isDecayedInTracker()+"\n" +
+						"\tStopped: "+mcParticle.getSimulatorStatus().isStopped()+"\n" +
+						"\tMother: "+mother,
+						true);
+			}
+		}
+		return mcParticleToSkimmedMap;
+	}
+	
+	/**
+	 * Creates the relations from tracks to mc particles by using a list
+	 * of LCRelations from hits to mc particles.
+	 * In case of a skimmed mc particle list the relations are pointing
+	 * to the closest ancestor present in the skimmed list.
+	 * <p>
+	 * The relations are weighted by the fraction of hits belonging to
+	 * a certain mc particle (N_{match}/N_{track}). Each hit is counted at
+	 * most once per (skimmed) mc particle.
+	 * <p>
+	 * In case of Tanimoto distance also the total number of hits produced
+	 * by the mc particle (including all particles mapped onto it by the
+	 * skimming) are taken into account. It gives less weight to tracks
+	 * that miss true hits. The weight is then calculated as
+	 * 1 - (N_{track}+N_{mc}-2*N_{match})/(N_{track}+N_{mc}-N_{match}).
+	 * @param tracks The list of tracks to be truth linked
+	 * @param trackHitMcRelation The LCRelations between track hits and mc particles
+	 * @return The weighted LCRelations between tracks and mc particles
+	 * @throws IllegalArgumentException if no tracker hit relations are given
+	 */
+	protected List<LCRelation> createTrackMcRelation(List<Track> tracks, List<LCRelation> trackHitMcRelation) {
+		
+		if (trackHitMcRelation == null) {
+			throw new IllegalArgumentException("No tracker hit to mc relations given.");
+		}
+		
+		RelationalTable<TrackerHit, MCParticle> trackHitMcRelationTable = createRelationalTable(trackHitMcRelation);
+		List<LCRelation> trackMcRelation = new ArrayList<LCRelation>();
+		
+		// With skimming, the true hits of a skimmed particle include the hits of all
+		// particles mapped onto it; otherwise its own hits are looked up directly.
+		Map<MCParticle, Integer> skimmedTrueHits = null;
+		if (useTanimotoDistance && isSkimmingActive()) {
+			skimmedTrueHits = countTrueHitsPerSkimmedParticle(trackHitMcRelationTable);
+		}
+		
+		for (Track track : tracks) {
+			// Store number of hits contributed by each mc particle
+			Map<MCParticle, Integer> mcParticleContribution = new HashMap<MCParticle, Integer>();
+			List<TrackerHit> trackHitsList = track.getTrackerHits();
+			double trackHits = trackHitsList.size();
+			double sumOfWeights = 0;
+			for (TrackerHit trackHit : trackHitsList) {
+				// a hit counts only once even if several contributors map to the same skimmed particle
+				Set<MCParticle> hitContributors = new HashSet<MCParticle>();
+				for (MCParticle mcParticle : trackHitMcRelationTable.allFrom(trackHit)) {
+					hitContributors.add(toSkimmed(mcParticle));
+				}
+				for (MCParticle mcParticle : hitContributors) {
+					Integer count = mcParticleContribution.get(mcParticle);
+					mcParticleContribution.put(mcParticle, count == null ? 1 : count + 1);
+				}
+			}
+			mcParticleContribution = sortMapByHighestValue(mcParticleContribution);
+			for (MCParticle mcParticle : mcParticleContribution.keySet()) {
+				double weight = 0.0;
+				double recoHits = mcParticleContribution.get(mcParticle);
+				if (useTanimotoDistance) {
+					double trueHits;
+					if (skimmedTrueHits != null) {
+						Integer count = skimmedTrueHits.get(mcParticle);
+						trueHits = count == null ? 0 : count;
+					} else {
+						trueHits = trackHitMcRelationTable.allTo(mcParticle).size();
+					}
+					weight = 1 - (trackHits+trueHits-2*recoHits)/(trackHits+trueHits-recoHits);
+				} else {
+					weight = recoHits/trackHits;
+				}
+				sumOfWeights += weight;
+				trackMcRelation.add(new BaseLCRelation(track, mcParticle, weight));
+				print(HLEVEL_FULL, "Added a track to mc particle relation with weight "+weight+".");
+				if (!fullRecoRelation) break;
+			}
+			print(HLEVEL_HIGH, "Total weight of track contributions is "+sumOfWeights+".");
+		}
+		
+		print(HLEVEL_NORMAL, "Created "+trackMcRelation.size()+" track to mc particle relations.");
+		
+		return trackMcRelation;
+	}
+	
+	/**
+	 * Creates the relations from calorimeter hits to mc particles
+	 * by using a list of LCRelations from CalorimeterHits to
+	 * SimCalorimeterHits and the intrinsic link to mc particles of
+	 * the sim hits.
+	 * <p>
+	 * The produced relations are weighted by the energy fraction
+	 * contributed by the mc particle to the SimCalorimeterHit
+	 * (E_{MC,Hit}/E_{Hit})
+	 * @param caloHitSimHitRelation The relations between CalorimeterHits and SimCalorimeterHits
+	 * @return The weighted LCRelations between CalorimeterHits and MCParticles
+	 * @throws IllegalArgumentException if no relations are given
+	 */
+	protected List<LCRelation> createCaloHitMcRelation(List<LCRelation> caloHitSimHitRelation) {
+		
+		// consistent with the other relation builders: a missing input is reported as IllegalArgumentException
+		if (caloHitSimHitRelation == null) {
+			throw new IllegalArgumentException("No calorimeter hit to sim hit relations given.");
+		}
+		
+		List<LCRelation> caloHitMcRelation = new ArrayList<LCRelation>();
+		
+		for (LCRelation relation : caloHitSimHitRelation) {
+			CalorimeterHit digiHit = (CalorimeterHit) relation.getFrom();
+			SimCalorimeterHit simHit = (SimCalorimeterHit) relation.getTo();
+			double hitEnergy = simHit.getRawEnergy();
+			double sumOfWeights = 0;
+			for (int i = 0; i < simHit.getMCParticleCount(); i++) {
+				double weight = simHit.getContributedEnergy(i)/hitEnergy;
+				sumOfWeights += weight;
+				caloHitMcRelation.add(new BaseLCRelation(digiHit, simHit.getMCParticle(i), weight));
+				print(HLEVEL_FULL, "Added a calorimeter hit to mc particle relation with weight "+weight+".");
+			}
+			print(HLEVEL_FULL, "Total weight of calorimeter hit contributions is "+sumOfWeights+".");
+		}
+		
+		print(HLEVEL_NORMAL, "Created "+caloHitMcRelation.size()+" calorimeter hit to mc particle relations.");
+		
+		return caloHitMcRelation;
+	}
+	
+	/**
+	 * Creates the relations from Clusters to MCParticles by using
+	 * a list of LCRelations from CalorimeterHits to MCParticles.
+	 * <p>
+	 * The produced relations are weighted by the energy fraction
+	 * contributed by the MCParticle to the Cluster
+	 * (E_{MC,Cluster}/E_{Cluster})
+	 * @param clusters The list of clusters to be truth linked
+	 * @param caloHitMcRelation The relations between CalorimeterHits and MCParticles
+	 * @return The weighted LCRelations between Clusters and MCParticles
+	 * @throws IllegalArgumentException if no calorimeter hit relations are given
+	 */
+	protected List<LCRelation> createClusterMcRelation(List<Cluster> clusters, List<LCRelation> caloHitMcRelation) throws IllegalArgumentException {
+		
+		if (caloHitMcRelation == null) {
+			throw new IllegalArgumentException("No calorimeter hit to mc relations given.");
+		}
+		
+		RelationalTable<CalorimeterHit, MCParticle> caloHitMcRelationTable = createRelationalTable(caloHitMcRelation);
+		List<LCRelation> clusterMcRelation = new ArrayList<LCRelation>();
+		
+		for (Cluster cluster : clusters) {
+			double sumOfWeights = 0;
+			double clusterEnergy = cluster.getEnergy();
+			Map<MCParticle,Double> mcParticlesWeight = new HashMap<MCParticle, Double>();
+			for (CalorimeterHit hit : cluster.getCalorimeterHits()) {
+				double hitEnergy = hit.getCorrectedEnergy();
+				double hitWeight = hitEnergy/clusterEnergy;
+				Map<MCParticle,Double> hitMcParticlesWeight = caloHitMcRelationTable.allFromWithWeights(hit);
+				for (Map.Entry<MCParticle, Double> contribution : hitMcParticlesWeight.entrySet()) {
+					// TODO implement optional use of Tanimoto distance 
+					double weight = hitWeight*contribution.getValue();
+					// energy fractions of particles mapped onto the same skimmed particle add up
+					MCParticle mcParticle = toSkimmed(contribution.getKey());
+					Double previous = mcParticlesWeight.get(mcParticle);
+					mcParticlesWeight.put(mcParticle, previous == null ? weight : previous + weight);
+				}
+			}
+			mcParticlesWeight = sortMapByHighestValue(mcParticlesWeight);
+			for (MCParticle mcParticle : mcParticlesWeight.keySet()) {
+				double weight = mcParticlesWeight.get(mcParticle);
+				sumOfWeights += weight;
+				clusterMcRelation.add(new BaseLCRelation(cluster, mcParticle, weight));
+				print(HLEVEL_FULL, "Added a cluster to mc particle relation with weight "+weight+".");
+				if (!fullRecoRelation) break;
+			}
+			print(HLEVEL_HIGH, "Total weight of cluster contributions is "+sumOfWeights+".");
+		}
+		
+		print(HLEVEL_NORMAL, "Created "+clusterMcRelation.size()+" cluster to mc particle relations.");
+		
+		return clusterMcRelation;
+	}
+	
+	/**
+	 * Creates the relations from PFOs to MCParticles by using
+	 * a list of LCRelations from Tracks to MCParticles and a
+	 * second list of LCRelations from Clusters to MCParticles.
+	 * <p>
+	 * In case of a PFO without tracks (or if the track weight is 0) the
+	 * relation is weighted using only the weights from the contributing
+	 * Cluster to MCParticle relations.
+	 * <p>
+	 * For charged PFOs the weight of the relations are calculated
+	 * separately for tracks and clusters and then combined depending
+	 * on a global track to cluster weight. By default the track
+	 * weight is 1 and the cluster weight is 0. Thus, only the
+	 * relation via track is taken into account.
+	 * @param recoParticles The list of PFOs to be truth linked
+	 * @param trackMcRelation The relations between Tracks and MCParticles
+	 * @param clusterMcRelation The relations between Clusters and MCParticles
+	 * @return The weighted LCRelations between PFOs and MCParticles
+	 * @throws IllegalArgumentException if the track or cluster relations are missing
+	 */
+	protected List<LCRelation> createPfoMcRelation(List<ReconstructedParticle> recoParticles, List<LCRelation> trackMcRelation, List<LCRelation> clusterMcRelation) throws IllegalArgumentException {
+		
+		if (trackMcRelation == null) {
+			throw new IllegalArgumentException("No track to mc relations given.");
+		}
+		if (clusterMcRelation == null) {
+			throw new IllegalArgumentException("No cluster to mc relations given.");
+		}
+		
+		RelationalTable<Track, MCParticle> trackMcRelationTable = createRelationalTable(trackMcRelation);
+		RelationalTable<Cluster, MCParticle> clusterMcRelationTable = createRelationalTable(clusterMcRelation);
+		List<LCRelation> pfoMcRelation = new ArrayList<LCRelation>();
+		
+		for (ReconstructedParticle recoParticle : recoParticles) {
+			double sumOfWeights = 0;
+			int pfoTrackHits = 0;
+			double pfoEnergy = recoParticle.getEnergy();
+			double thisPfoClusterWeight = pfoClusterWeight;
+			double trackClusterNormalization = pfoTrackWeight+pfoClusterWeight;
+			Map<MCParticle, Double> mcParticlesWeight = new HashMap<MCParticle, Double>();
+			// track contribution, each track weighted by its share of the PFO's track hits
+			if (pfoTrackWeight != 0) {
+				for (Track track : recoParticle.getTracks()) {
+					pfoTrackHits += track.getTrackerHits().size();
+				}
+				for (Track track : recoParticle.getTracks()) {
+					Map<MCParticle, Double> trackMcParticlesWeight = trackMcRelationTable.allFromWithWeights(track);
+					double trackWeight = track.getTrackerHits().size()/(double)pfoTrackHits;
+					// weigh the contribution by the global track to cluster weight
+					trackWeight *= pfoTrackWeight/trackClusterNormalization;
+					for (Map.Entry<MCParticle, Double> contribution : trackMcParticlesWeight.entrySet()) {
+						double weight = trackWeight * contribution.getValue();
+						MCParticle mcParticle = toSkimmed(contribution.getKey());
+						Double previous = mcParticlesWeight.get(mcParticle);
+						mcParticlesWeight.put(mcParticle, previous == null ? weight : previous + weight);
+					}
+				}
+			}
+			// If no track hits were used, only use clusters
+			if (pfoTrackHits == 0) {
+				thisPfoClusterWeight = 1.0;
+				trackClusterNormalization = 1.0;
+			}
+			// cluster contribution, each cluster weighted by its share of the PFO energy
+			if (thisPfoClusterWeight != 0) {
+				for (Cluster cluster : recoParticle.getClusters()) {
+					Map<MCParticle, Double> clusterMcParticlesWeight = clusterMcRelationTable.allFromWithWeights(cluster);
+					double clusterWeight = cluster.getEnergy()/pfoEnergy;
+					// weigh the contribution by the global cluster to track weight
+					clusterWeight *= thisPfoClusterWeight/trackClusterNormalization;
+					for (Map.Entry<MCParticle, Double> contribution : clusterMcParticlesWeight.entrySet()) {
+						double weight = clusterWeight*contribution.getValue();
+						MCParticle mcParticle = toSkimmed(contribution.getKey());
+						Double previous = mcParticlesWeight.get(mcParticle);
+						mcParticlesWeight.put(mcParticle, previous == null ? weight : previous + weight);
+					}
+				}
+			}
+			mcParticlesWeight = sortMapByHighestValue(mcParticlesWeight);
+			for (MCParticle mcParticle : mcParticlesWeight.keySet()) {
+				double weight = mcParticlesWeight.get(mcParticle);
+				sumOfWeights += weight;
+				pfoMcRelation.add(new BaseLCRelation(recoParticle, mcParticle, weight));
+				// the particle can be null if it has no ancestor in the skimmed list
+				if (mcParticle != null) {
+					print(HLEVEL_FULL, "Added a PFO to mc particle relation with weight "+weight+".\n" +
+							"\tEnergy: "+mcParticle.getEnergy()+"\n" +
+							"\tCharge: "+mcParticle.getCharge()+"\n" +
+							"\tPDGID: "+mcParticle.getPDGID());
+				} else {
+					print(HLEVEL_FULL, "Added a PFO to mc particle relation with weight "+weight+" (no mc particle).");
+				}
+				if (!fullRecoRelation) break;
+			}
+			print(HLEVEL_HIGH, "Total weight of PFO contributions is "+sumOfWeights+".");
+		}
+		
+		print(HLEVEL_NORMAL, "Created "+pfoMcRelation.size()+" PFO to mc particle relations.");
+		
+		return pfoMcRelation;
+	}
+	
+	
+	
+	// -------------------- Static Methods --------------------
+	// TODO These can most likely move to a more general class
+	
+	/**
+	 * Helper method to write a message to the output stream if the
+	 * histogram level set for the driver is equal or higher than
+	 * the given value.
+	 * @param histogramLevel The level at which the message is printed
+	 * @param message The message, which will be printed to the stream
+	 */
+	protected void print(int histogramLevel, String message) {
+		print(histogramLevel, message, false);
+	}
+	
+	/**
+	 * Helper method to write a message to the output stream if the
+	 * histogram level set for the driver is equal or higher than
+	 * the given value. The message is prefixed with the driver name.
+	 * @param histogramLevel The level at which the message is printed
+	 * @param message The message, which will be printed to the stream
+	 * @param error If true, writes to error stream instead of standard
+	 */
+	protected void print(int histogramLevel, String message, boolean error) {
+		if (getHistogramLevel() >= histogramLevel) {
+			message = getName()+": "+message;
+			if (error) {
+				System.err.println(message);
+			} else {
+				System.out.println(message);
+			}
+		}
+	}
+	
+	/**
+	 * Adds an mc particle to a list of mc particles if it is not yet in the list.
+	 * Also adds all its ancestors recursively to the same list.
+	 * @param mcParticle The mc particle to be added to the list
+	 * @param mcParticles The list to add the mc particle to
+	 */
+	protected void addMcParticleWithParents(MCParticle mcParticle, List<MCParticle> mcParticles) {
+		if (!mcParticles.contains(mcParticle)) {
+			mcParticles.add(mcParticle);
+			print(HLEVEL_FULL, "Adding mc particle to skimmed list.\n" +
+					"\tEnergy: "+mcParticle.getEnergy()+"\n" +
+					"\tCharge: "+mcParticle.getCharge()+"\n" +
+					"\tPDGID: "+mcParticle.getPDGID()+"\n" +
+					"\tGenStatus: "+mcParticle.getGeneratorStatus()+"\n" +
+					"\tSimStatus: "+mcParticle.getSimulatorStatus().getValue() );
+			for (MCParticle parent : mcParticle.getParents()) {
+				addMcParticleWithParents(parent, mcParticles);
+			}
+		}
+	}
+	
+	/**
+	 * Finds the first ancestor of a given mc particle within a list of mc particles.
+	 * Used to find the relevant particle in the skimmed list, when trying to find out
+	 * which true particle caused a hit. Only the first parent is followed at each step.
+	 * @param mcParticle The mc particle
+	 * @param mcParticles The list of mc particles containing possible ancestors
+	 * @return The mc particle ancestor (the particle itself if contained). Null if none is found.
+	 */
+	protected MCParticle findMcParticleAncestor(MCParticle mcParticle, List<MCParticle> mcParticles) {
+		return findAncestorIn(mcParticle, mcParticles);
+	}
+	
+	/**
+	 * Converts a List of LCRelations (one to one relations with weights) into
+	 * a RelationalTable (many to many relations with weights). This improves
+	 * access if the relations in the LCRelation are actually many to many
+	 * or one to many relations described with multiple one to one relations.
+	 * @param relations A list of LCRelations
+	 * @return A RelationalTable with the same content as the given list
+	 */
+	@SuppressWarnings("unchecked") // the caller chooses F and T to match the relation content
+	public static <F,T> RelationalTable<F,T> createRelationalTable (List<LCRelation> relations) {
+		RelationalTable<F, T> relationalTable = new BaseRelationalTable<F, T>();
+		for (LCRelation relation : relations) {
+			relationalTable.add((F)relation.getFrom(), (T)relation.getTo(), relation.getWeight());
+		}
+		return relationalTable;
+	}
+	
+	/**
+	 * Creates a map with its keys sorted by its values in descending order
+	 * from an existing map. Entries with equal values keep the iteration
+	 * order of the original map (the sort is stable).
+	 * @param map The original map which should be sorted 
+	 * @return A new map with keys sorted by values
+	 */
+	public static <K, V extends Comparable<? super V>> Map<K, V> sortMapByHighestValue(Map<K, V> map) {
+		List<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
+		Collections.sort(entries, new Comparator<Map.Entry<K, V>>() {
+			public int compare(Map.Entry<K, V> e1, Map.Entry<K, V> e2) {
+				// arguments swapped for descending order (avoids negating compareTo)
+				return e2.getValue().compareTo(e1.getValue());
+			}
+		});
+		
+		Map<K, V> result = new LinkedHashMap<K, V>();
+		for (Map.Entry<K, V> entry : entries) {
+			result.put(entry.getKey(), entry.getValue());
+		}
+		return result;
+	}
+	
+	
+	
+	// -------------------- Private Helpers --------------------
+	
+	/**
+	 * Returns true if the skimmed mapping is enabled and available for the
+	 * current event (it is missing if the skimming step failed).
+	 */
+	private boolean isSkimmingActive() {
+		return useSkimmedMcParticles && mcParticleToSkimmed != null;
+	}
+	
+	/**
+	 * Maps an mc particle onto its representative in the skimmed list if
+	 * skimming is active, otherwise returns it unchanged. Returns null for
+	 * particles without an ancestor in the skimmed list or unknown to the map.
+	 */
+	private MCParticle toSkimmed(MCParticle mcParticle) {
+		return isSkimmingActive() ? mcParticleToSkimmed.get(mcParticle) : mcParticle;
+	}
+	
+	/**
+	 * Walks up the first-parent chain of an mc particle until a particle
+	 * contained in the given candidates is found. Prints a single warning if
+	 * no such ancestor exists.
+	 * @return The particle itself, its closest contained ancestor, or null
+	 */
+	private MCParticle findAncestorIn(MCParticle mcParticle, Collection<MCParticle> candidates) {
+		MCParticle current = mcParticle;
+		while (current != null && !candidates.contains(current)) {
+			List<MCParticle> parents = current.getParents();
+			// just look for the first ancestor here if multiple are present
+			current = parents.isEmpty() ? null : parents.get(0);
+		}
+		if (current == null) {
+			print(HLEVEL_DEFAULT,
+					"Warning: no ancestor found in mc particle list.\n" +
+					"\tEnergy: "+mcParticle.getEnergy()+"\n" +
+					"\tCharge: "+mcParticle.getCharge()+"\n" +
+					"\tPDGID: "+mcParticle.getPDGID()+"\n" +
+					"\tGenStatus: "+mcParticle.getGeneratorStatus()+"\n" +
+					"\tSimStatus: "+mcParticle.getSimulatorStatus().getValue(),
+					true);
+		}
+		return current;
+	}
+	
+	/**
+	 * Counts, for every skimmed particle, the distinct tracker hits produced
+	 * by any mc particle that is mapped onto it. Used for the Tanimoto
+	 * weight so that true and matched hits are counted consistently.
+	 * @param trackHitMcRelationTable The relations between tracker hits and mc particles
+	 * @return Number of distinct true hits per skimmed particle (absent means 0)
+	 */
+	private Map<MCParticle, Integer> countTrueHitsPerSkimmedParticle(RelationalTable<TrackerHit, MCParticle> trackHitMcRelationTable) {
+		Map<MCParticle, Set<TrackerHit>> hitsPerParticle = new HashMap<MCParticle, Set<TrackerHit>>();
+		for (Map.Entry<MCParticle, MCParticle> mapping : mcParticleToSkimmed.entrySet()) {
+			Collection<TrackerHit> hits = trackHitMcRelationTable.allTo(mapping.getKey());
+			if (hits == null || hits.isEmpty()) continue;
+			Set<TrackerHit> skimmedHits = hitsPerParticle.get(mapping.getValue());
+			if (skimmedHits == null) {
+				skimmedHits = new HashSet<TrackerHit>();
+				hitsPerParticle.put(mapping.getValue(), skimmedHits);
+			}
+			skimmedHits.addAll(hits);
+		}
+		Map<MCParticle, Integer> counts = new HashMap<MCParticle, Integer>();
+		for (Map.Entry<MCParticle, Set<TrackerHit>> entry : hitsPerParticle.entrySet()) {
+			counts.put(entry.getKey(), entry.getValue().size());
+		}
+		return counts;
+	}
+}
CVSspam 0.2.8